跳转到主要内容

用于处理ARC挑战的实用工具。

项目描述

arc-py - 用于处理Python中ARC挑战的类型和实用工具

Build Status Coverage PyPI version Downloads Code style: black

arc-py 包提供处理ARC挑战的类型和实用工具。

它帮助您将原始.json文件转换为numpy数组,使用matplotlib查看它们(以及您可能进行的任何编辑或生成)。它还提供了一个在ARC挑战中竞争的代理的接口,并提供评估函数。

通过运行pip install arc-py来安装它。

访问数据

安装arc-py后,您将本地获取原始.json文件。您可以在from arc import train_data_dir, eval_data_dir中找到包含它们的文件夹。然而,如果您打算使用python,您可能发现将数据作为逻辑上分为两组问题的2D numpy数组获取更为方便,这两组问题是训练和验证。这些可以通过from arc import train_problems, validation_problems获取。

数据类型

arc-py引入了一些类型来结构化挑战中的数据

  • arc.types.ArcProblem代表一个任务 - 由演示(prob.train_pairs)和测试(prob.test_pairs)组成。
  • arc.types.ArcIOPair是构成问题的演示/测试的单位 - 它是输入网格(pair.x)和相应的输出网格(pair.y)。
  • arc.types.ArcGridnp.ndarray的别名。具体来说,ArcGrid必须是2D的,每个维度在1到30之间,且只有整数[0, 9]范围内的整数作为其元素。为了验证某些numpy数组是否符合此规范,提供了一个检查函数:arc.types.verify_is_arc_grid

使用matplotlib查看数据

arc-py提供了一些基本功能,用于使用matplotlib查看网格(与原始repo中的网页视图一致)

  • arc.plot.plot_grid可以绘制单个2D网格
  • arc.types.ArcIOPair.plot()将显示输入/输出对
import numpy as np
from matplotlib import pyplot as plt
from arc.plot import plot_grid


grid = np.zeros([4,4], dtype=np.uint8)

for i in range(4):
    grid[i,i] = 3

plot_grid(grid)
plt.show()

构建代理和评估结果

代理API

代理的概念如下 - 它是您通过查看训练问题开发/训练的程序,然后对其进行验证问题评分。根据原始kaggle竞赛的规则

  1. 一个问题可以包含多个测试网格。
  2. 对于给定的测试网格的答案,可以猜测最多3次。

由于代理不能看到测试网格的答案,我们给出了以下签名

    def predict(
        self, demo_pairs: List[ArcIOPair], test_grids: List[ArcGrid]
    ) -> List[ArcPrediction]:

在这里,ArcPredictionlist[ArcGrid] - 包含1到3个猜测。输出列表的第一个元素对应于第一个测试网格,依此类推(这意味着对于返回值result,您可以使用x, y in zip(test_grids, result)来匹配输入和输出)。

评估

arc-py提供了一个类,可以跟踪您的准确率和一些辅助目标指标:arc.evaluation.ArcEvaluationResult。最好用示例来说明

def evaluate_agent(
    agent: ArcAgent, problems: List[ArcProblem] = validation_problems
) -> ArcEvaluationResult:

    result = ArcEvaluationResult()
    for prob in problems:
        pred = agent.predict(prob.train_pairs, prob.test_inputs)
        result.add_answer(prob, pred)

    return result

只需打印ArcEvaluationResult对象即可查看结果。示例输出

ARC results for 400 problems. Stats:
Accuracy                 : 2.2%
Accuracy(at least one)   : 2.5%
Correct answer shape     : 52.4%

示例

有关使用arc-py的项目示例,请参阅https://github.com/ikamensh/solve_arc

查看训练示例

from arc import train_problems, validation_problems, describe_task_group

describe_task_group(train_problems)
describe_task_group(validation_problems)

for n, task in enumerate(train_problems, start=1):
    for i, pair in enumerate(task.train_pairs, start=1):
        pair.plot(show=True, title=f"Task {n}: Demo {i}")

    for i, pair in enumerate(task.test_pairs, start=1):
        pair.plot(show=True, title=f"Task {n}: Test {i}")

Alt Task 1 example1 Alt Task 1 example2 Alt Task 1 example3

注意:ARC挑战是为不知道测试问题的开发者设计的。如果您知道这些问题,您将过度拟合它们。我们建议不要查看评估集。

查看随机代理的输出

from typing import List
import numpy as np
from arc.types import ArcIOPair, ArcGrid, ArcPrediction
from arc.agents import ArcAgent


class RandomAgent(ArcAgent):
    """Makes random predicitons. Low chance of success. """

    def predict(
            self, demo_pairs: List[ArcIOPair], test_grids: List[ArcGrid]
    ) -> List[ArcPrediction]:
        """We are allowed to make up to 3 guesses per challange rules. """
        outputs = []
        for tg in test_grids:
            out_shape = tg.shape
            out1 = np.random.randint(0, 9, out_shape)
            out2 = np.random.randint(0, 9, out_shape)
            out3 = np.random.randint(0, 9, out_shape)
            outputs.append([out1, out2, out3])
        return outputs


from arc import train_problems

p1 = next(iter(train_problems))  # problem #1
agent = RandomAgent()
outs = agent.predict(p1.train_pairs, p1.test_inputs)

for test_pair, predicitons in zip(p1.test_pairs, outs):
    for p in predicitons:
        prediction = ArcIOPair(test_pair.x, p)
        prediction.plot(show=True)

评估随机代理

from arc.agents import RandomAgent, CheatingAgent
from arc.evaluation import evaluate_agent


agent = RandomAgent()
results = evaluate_agent(agent)
print(results)
assert results.accuracy < 0.5
assert results.accuracy_any < 0.5
assert results.shape_accuracy > 0  # random agent guesses the shape of some outputs correctly

400个问题的ARC结果。统计数据
准确率:0.0%
(至少一个)准确率:0.0%
正确答案形状:67.5%

项目详情


下载文件

下载适合您平台的文件。如果您不确定选择哪个,请了解更多关于安装包的信息。

源分发

arc-py-0.12.tar.gz (285.1 kB 查看哈希值)

上传时间

构建分发

arc_py-0.12-py3-none-any.whl (477.5 kB 查看哈希值)

上传时间 Python 3

由以下组织支持