用于处理ARC挑战的实用工具。
项目描述
arc-py - 用于处理Python中ARC挑战的类型和实用工具
arc-py
包提供处理ARC挑战的类型和实用工具。
它帮助您将原始.json文件转换为numpy数组,使用matplotlib查看它们(以及您可能进行的任何编辑或生成)。它还提供了一个在ARC挑战中竞争的代理的接口,并提供评估函数。
通过运行pip install arc-py
来安装它。
访问数据
安装arc-py
后,您将本地获取原始.json文件。您可以在from arc import train_data_dir, eval_data_dir
中找到包含它们的文件夹。然而,如果您打算使用python,您可能发现将数据作为逻辑上分为两组问题的2D numpy数组获取更为方便,这两组问题是训练和验证。这些可以通过from arc import train_problems, validation_problems
获取。
数据类型
arc-py
引入了一些类型来结构化挑战中的数据
arc.types.ArcProblem
代表一个任务 - 由演示(prob.train_pairs
)和测试(prob.test_pairs
)组成。arc.types.ArcIOPair
是构成问题的演示/测试的单位 - 它是输入网格(pair.x
)和相应的输出网格(pair.y
)。arc.types.ArcGrid
是np.ndarray
的别名。具体来说,ArcGrid必须是2D的,每个维度在1到30之间,且只有整数[0, 9]范围内的整数作为其元素。为了验证某些numpy数组是否符合此规范,提供了一个检查函数:arc.types.verify_is_arc_grid
。
使用matplotlib查看数据
arc-py提供了一些基本功能,用于使用matplotlib查看网格(与原始repo中的网页视图一致)
arc.plot.plot_grid
可以绘制单个2D网格arc.types.ArcIOPair.plot()
将显示输入/输出对
import numpy as np
from matplotlib import pyplot as plt
from arc.plot import plot_grid
grid = np.zeros([4,4], dtype=np.uint8)
for i in range(4):
grid[i,i] = 3
plot_grid(grid)
plt.show()
构建代理和评估结果
代理API
代理的概念如下 - 它是您通过查看训练问题开发/训练的程序,然后对其进行验证问题评分。根据原始kaggle竞赛的规则
- 一个问题可以包含多个测试网格。
- 对于给定的测试网格的答案,可以猜测最多3次。
由于代理不能看到测试网格的答案,我们给出了以下签名
def predict(
self, demo_pairs: List[ArcIOPair], test_grids: List[ArcGrid]
) -> List[ArcPrediction]:
在这里,ArcPrediction
是list[ArcGrid]
- 包含1到3个猜测。输出列表的第一个元素对应于第一个测试网格,依此类推(这意味着对于返回值result
,您可以使用x, y in zip(test_grids, result)
来匹配输入和输出)。
评估
arc-py
提供了一个类,可以跟踪您的准确率和一些辅助目标指标:arc.evaluation.ArcEvaluationResult
。最好用示例来说明
def evaluate_agent(
agent: ArcAgent, problems: List[ArcProblem] = validation_problems
) -> ArcEvaluationResult:
result = ArcEvaluationResult()
for prob in problems:
pred = agent.predict(prob.train_pairs, prob.test_inputs)
result.add_answer(prob, pred)
return result
只需打印ArcEvaluationResult
对象即可查看结果。示例输出
ARC results for 400 problems. Stats:
Accuracy : 2.2%
Accuracy(at least one) : 2.5%
Correct answer shape : 52.4%
示例
有关使用arc-py的项目示例,请参阅https://github.com/ikamensh/solve_arc
查看训练示例
from arc import train_problems, validation_problems, describe_task_group
describe_task_group(train_problems)
describe_task_group(validation_problems)
for n, task in enumerate(train_problems, start=1):
for i, pair in enumerate(task.train_pairs, start=1):
pair.plot(show=True, title=f"Task {n}: Demo {i}")
for i, pair in enumerate(task.test_pairs, start=1):
pair.plot(show=True, title=f"Task {n}: Test {i}")
注意:ARC挑战是为不知道测试问题的开发者设计的。如果您知道这些问题,您将过度拟合它们。我们建议不要查看评估集。
查看随机代理的输出
from typing import List
import numpy as np
from arc.types import ArcIOPair, ArcGrid, ArcPrediction
from arc.agents import ArcAgent
class RandomAgent(ArcAgent):
"""Makes random predicitons. Low chance of success. """
def predict(
self, demo_pairs: List[ArcIOPair], test_grids: List[ArcGrid]
) -> List[ArcPrediction]:
"""We are allowed to make up to 3 guesses per challange rules. """
outputs = []
for tg in test_grids:
out_shape = tg.shape
out1 = np.random.randint(0, 9, out_shape)
out2 = np.random.randint(0, 9, out_shape)
out3 = np.random.randint(0, 9, out_shape)
outputs.append([out1, out2, out3])
return outputs
from arc import train_problems
p1 = next(iter(train_problems)) # problem #1
agent = RandomAgent()
outs = agent.predict(p1.train_pairs, p1.test_inputs)
for test_pair, predicitons in zip(p1.test_pairs, outs):
for p in predicitons:
prediction = ArcIOPair(test_pair.x, p)
prediction.plot(show=True)
评估随机代理
from arc.agents import RandomAgent, CheatingAgent
from arc.evaluation import evaluate_agent
agent = RandomAgent()
results = evaluate_agent(agent)
print(results)
assert results.accuracy < 0.5
assert results.accuracy_any < 0.5
assert results.shape_accuracy > 0 # random agent guesses the shape of some outputs correctly
400个问题的ARC结果。统计数据
准确率:0.0%
(至少一个)准确率:0.0%
正确答案形状:67.5%
项目详情
下载文件
下载适合您平台的文件。如果您不确定选择哪个,请了解更多关于安装包的信息。
源分发
构建分发
arc-py-0.12.tar.gz的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 3e3435fac59ac9e753bd95bdde5c449bf93d5be4f15efa991694c61241648f37 |
|
MD5 | 3a4b943411bdc15f45c78b90558dd2b8 |
|
BLAKE2b-256 | dffd850bbbefa10675f90c29dc3f592ff07f5ecd640a1aa41002cb4799ced262 |
arc_py-0.12-py3-none-any.whl的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 0db1c208ece47fc91a356f0dc51113f65c62ba5d18839c9c10a4012c3bf0a2e1 |
|
MD5 | 技术文档标识:d67ea1897841e83c44e1f9a7462ef7c1 |
|
BLAKE2b-256 | 技术文档标识:be9f52595eab658e8d60adef9775d538318a9593b5022332684df00d1f1e6624 |