CLRS算法推理基准。
项目描述
CLRS算法推理基准
算法表示学习是机器学习的一个新兴领域,旨在将神经网络的概念与经典算法相连接。CLRS算法推理基准(CLRS)通过提供一系列经典算法的实现,巩固并扩展了先前评估算法推理的工作。这些算法选自Cormen、Leiserson、Rivest和Stein所著的标准教材《算法导论》的第三版。
入门指南
可以使用pip安装CLRS算法推理基准,无论是从PyPI
pip install dm-clrs
还是直接从GitHub(更新更频繁)
pip install git+https://github.com/google-deepmind/clrs.git
如果Python安装中存在冲突,您可能更喜欢在虚拟环境中安装它
python3 -m venv clrs_env
source clrs_env/bin/activate
pip install git+https://github.com/google-deepmind/clrs.git
安装后,您可以运行我们的示例基线模型
python3 -m clrs.examples.run
如果这是示例的第一次运行,数据集将被下载并存储在--dataset_path
(默认'/tmp/CLRS30')。或者,您也可以下载并解压缩https://storage.googleapis.com/dm-clrs/CLRS30_v1.0.0.tar.gz
作为图的算法
CLRS以惯用的方式实现所选算法,尽可能接近原始CLRS 3版伪代码。通过控制输入数据分布以符合先决条件,我们可以自动生成输入/输出对。我们还提供了“提示”轨迹,以揭示每个算法的内部状态,这既可以简化学习挑战,也可以区分解决相同整体任务(例如排序)的不同算法。
在最通用的情况下,算法可以被视为操作对象集合,以及它们之间的任何关系(这些关系本身可以分解为二元关系)。因此,我们使用图表示法研究本基准中的所有算法。如果对象遵循更严格的有序结构(例如数组或根树),我们通过包含前驱链接来施加这种顺序。
它的工作原理
对于每个算法,我们提供一组标准train、eval和test轨迹,以基准测试分布外泛化。
轨迹 | 问题规模 | |
---|---|---|
Train | 1000 | 16 |
Eval | 32 x 乘数 | 16 |
Test | 32 x 乘数 | 64 |
在这里,“问题规模”指的是例如数组的长度或图中的节点数,具体取决于算法。“乘数”是算法特定的因子,它增加可用eval和test轨迹的数量,以补偿评估信号的不足。“乘数”对所有算法都为1,除了
- 最大子数组(Kadane),其“乘数”为32。
- 快速选择、最小值、二分查找、字符串匹配器(包括朴素和KMP)以及段交集,其“乘数”为64。
轨迹可以这样使用
train_ds, num_samples, spec = clrs.create_dataset(
folder='/tmp/CLRS30', algorithm='bfs',
split='train', batch_size=32)
for i, feedback in enumerate(train_ds.as_numpy_iterator()):
if i == 0:
model.init(feedback.features, initial_seed)
loss = model.feedback(rng_key, feedback)
在这里,feedback
是一个具有以下结构的namedtuple
Feedback = collections.namedtuple('Feedback', ['features', 'outputs'])
Features = collections.namedtuple('Features', ['inputs', 'hints', 'lengths'])
其中Features
的内容可用于训练,而outputs
则保留用于评估。元组的每个字段都是一个具有先导批维度的ndarray
。因为提供了完整的算法轨迹的hints
,所以这些包含一个额外的填充到数据集中任何轨迹的最大长度max(T)
的时间维度。`lengths`字段指定每个轨迹的真实长度`t <= max(T)`,可用于例如损失掩码。
examples
目录包含一个使用JAX和DeepMind JAX生态系统库的完整工作图神经网络(GNN)示例。它允许在单个处理器上训练多个算法,如"A Generalist Neural Algorithmic Learner"中所述。
我们提供的内容
算法
我们的初始CLRS-30基准包括以下30个算法。我们旨在在未来支持更多算法。
- 排序
- 插入排序
- 冒泡排序
- 堆排序(Williams,1964)
- 快速排序(Hoare,1962)
- 搜索
- 最小值
- 二分查找
- 快速选择算法(霍华,1961年)
- 分而治之
- 最大子数组(Kadane变体)(Bentley,1984年)
- 贪婪算法
- 活动选择(加夫里尔,1972年)
- 任务调度(劳勒,1985年)
- 动态规划
- 矩阵链乘法
- 最长公共子序列
- 最优二叉搜索树(Aho等人,1974年)
- 图
- 深度优先搜索(摩尔,1959年)
- 广度优先搜索(摩尔,1959年)
- 拓扑排序(克努特,1973年)
- 割点
- 桥
- Kosaraju的强连通分量算法(Aho等人,1974年)
- Kruskal的最小生成树算法(Kruskal,1956年)
- Prim的最小生成树算法(Prim,1957年)
- 单源最短路径的Bellman-Ford算法(Bellman,1958年)
- 单源最短路径的Dijkstra算法(Dijkstra等人,1959年)
- 有向无环图单源最短路径
- 所有对最短路径的Floyd-Warshall算法(Floyd,1962年)
- 字符串
- 朴素字符串匹配
- Knuth-Morris-Pratt(KMP)字符串匹配器(Knuth等人,1977年)
- 几何学
- 线段相交
- Graham扫描凸包算法(Graham,1972年)
- Jarvis步凸包算法(Jarvis,1973年)
基线
模型由一个处理器以及多个编码器和解码器组成。我们提供了以下GNN基线处理器的JAX实现:
- Deep Sets(Zaheer等人,NIPS 2017年)
- 端到端记忆网络(Sukhbaatar等人,NIPS 2015年)
- 图注意力网络(Veličković等人,ICLR 2018年)
- 图注意力网络 v2(Brody等人,ICLR 2022年)
- 消息传递神经网络(Gilmer等人,ICML 2017年)
- 指针图网络(Veličković等人,NeurIPS 2020年)
如果您想实现一个新的处理器,最简单的方法是将它添加到processors.py
文件中,并通过那里的get_processor_factory
方法使其可用。处理器应该有一个类似于以下的__call__
方法
__call__(self,
node_fts, edge_fts, graph_fts,
adj_mat, hidden,
nb_nodes, batch_size)
其中node_fts
,edge_fts
和graph_fts
将分别是以节点、边和图编码特征的float数组,形状为batch_size
x nb_nodes
x H,batch_size
x nb_nodes
x nb_nodes
x H,和batch_size
x H,adj_mat
是一个batch_size
x nb_nodes
x nb_nodes
的布尔数组,它是由提示和输入构建的连接性,hidden
是一个batch_size
x nb_nodes
x H的float数组,其中包含处理器的上一输出步的输出。该方法应返回一个batch_size
x nb_nodes
x H的float数组。
对于更多根本不同的基线,有必要创建一个新的类,该类扩展了Model API(如clrs/_src/model.py
中找到的)。clrs/_src/baselines.py
提供了一个如何做到这一点的示例。
创建自己的数据集
我们在dataset.py
中提供了一个tensorflow_dataset
生成器类。此文件可以修改以生成可用的不同版本的算法,并可以通过遵循https://tensorflowcn.cn/datasets上的安装说明后使用tfds build
来构建。
或者,您可以通过使用clrs/_src/samplers.py
中的build_sampler
方法实例化采样器来生成样本,而不需要通过tfds
,如下所示
sampler, spec = clrs.build_sampler(
name='bfs',
seed=42,
num_samples=1000,
length=16)
def _iterate_sampler(batch_size):
while True:
yield sampler.next(batch_size)
for feedback in _iterate_sampler(batch_size=32):
...
最近,我们提供了CLRS-Text,这是基准的基于文本的变体,适用于训练和评估语言模型的算法推理能力。请参阅相关子文件夹中的专用README文件。
您还可以查看有关CLRS-Text的配套论文。
添加新算法
将新算法添加到任务套件需要以下步骤
- 确定您算法的输入/提示/输出规范,并将其包含在
clrs/_src/specs.py
的SPECS
字典中。 - 以抽象化的形式实现所需的算法。这种实现的例子可以在
clrs/_src/algorithms/
文件夹中找到。
- 在算法执行过程中选择适当的时刻创建探针,以捕获输入、输出和所有中间状态(使用
probing.push
函数)。 - 一旦生成,探针必须使用
probing.finalize
方法进行格式化,并应与算法输出一起返回。
- 为您的算法实现适当的输入数据采样器,并将其包含在
clrs/_src/samplers.py
的SAMPLERS
字典中。
以这种方式添加算法后,可以使用 build_sampler
方法访问它,并且如果使用 dataset.py
中的生成类重新生成,它还将被包含在数据集中,如上所述。
引用
引用 CLRS 算法推理基准
@article{deepmind2022clrs,
title={The CLRS Algorithmic Reasoning Benchmark},
author={Petar Veli\v{c}kovi\'{c} and Adri\`{a} Puigdom\`{e}nech Badia and
David Budden and Razvan Pascanu and Andrea Banino and Misha Dashevskiy and
Raia Hadsell and Charles Blundell},
journal={arXiv preprint arXiv:2205.15659},
year={2022}
}
引用 CLRS-Text 算法推理语言基准
@article{deepmind2024clrstext,
title={The CLRS-Text Algorithmic Reasoning Language Benchmark},
author={Larisa Markeeva and Sean McLeish and Borja Ibarz and Wilfried Bounsi
and Olga Kozlova and Alex Vitvitskyi and Charles Blundell and
Tom Goldstein and Avi Schwarzschild and Petar Veli\v{c}kovi\'{c}},
journal={arXiv preprint arXiv:2406.04229},
year={2024}
}
项目详情
下载文件
下载您平台上的文件。如果您不确定选择哪个,请了解有关 安装包 的更多信息。
源分发
构建分发
dm_clrs-2.0.1.tar.gz 的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 07ac6f75a46f817334aab78067178568820afa11033fea7b50193749aa38be70 |
|
MD5 | 6ca328fa6545bd5f189637939fd0c6c3 |
|
BLAKE2b-256 | bd8f557763dcca30b1f35bd1b5634dc834b2cb2b8cab4b19c25e36ea3bcbd438 |
dm_clrs-2.0.1-py3-none-any.whl 的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 4678e1d7e8c77966f2ce5b662fa35ee8e937e504924525cdac5f151d631cee6a |
|
MD5 | 4cf270fa4c03863e43f97bb1c32e24d0 |
|
BLAKE2b-256 | d852eb541ba10157a5b8d6b0268c77650f9c271b2bde50405d7b3e079f216d79 |