Jraph:Jax中图神经网络的库
项目描述
Jraph - Jax中图神经网络的库。
新功能!Pmap示例和数据加载。
我们添加了一个pmap 示例。
我们的朋友instadeep、Jama Hussein Mohamud和Tom Makkink编写了一篇关于使用PyTorch数据加载的精美指南。您可以在这里找到它 这里。
新功能!支持大型分布式MPNN
我们发布了一个分布式图网络实现,允许您在多个设备上跨多个设备分发一个非常大的(数百万条边)图网络,具有显式的边消息。 查看它!
新功能!交互式Jraph Colabs
我们有两个新的colabs来帮助您掌握Jraph。
首先是与教育相关的colab,以令人惊叹的介绍开始了图神经网络和图理论,展示了如何使用Jraph解决一系列问题。查看这里。
第二项是一个完整工作的示例,展示了使用OGBG-MOLPCBA的Jraph最佳实践,并附有精美的可视化。查看这里。
感谢Lisa Wang、Nikola Jovanović和Ameya Daigavane。
快速入门
Jraph(发音为“giraffe”)是一个轻量级的库,用于在jax中处理图神经网络。它提供了一种图的数据结构,一组用于处理图的实用工具,以及一组可分叉的图神经网络模型。
安装
pip安装jraph
或者您可以使用以下命令直接从GitHub安装Jraph
pip install git+git://github.com/deepmind/jraph.git
示例需要额外的依赖项。要安装它们,请运行
pip install "jraph[examples, ogb_examples] @ git+git://github.com/deepmind/jraph.git"
概述
Jraph旨在为在jax中处理图提供实用工具,但不规定编写或开发图神经网络的方式。
graph.py
提供了一种轻量级的数据结构,GraphsTuple
,用于处理图。utils.py
提供了一组用于在jax中处理GraphsTuples
的实用工具。- 用于批处理
GraphsTuples
数据集的实用工具。 - 通过填充和掩码支持可变形状图的jit编译的实用工具。
- 用于定义输入分区上的损失的实用工具。
- 用于批处理
models.py
提供了不同类型的图神经网络消息传递的示例。这些设计轻量级、易于分叉和适应。它们不为您管理参数 - 对于此目的,请考虑使用haiku
或flax
。请参阅示例以获取更多详细信息。
快速入门
Jraph在定义GraphsTuple
数据结构时受到了Tensorflow graph_nets库的启发,它是一个包含一个或多个有向图的namedtuple。
表示图 - GraphsTuple
import jraph
import jax.numpy as jnp
# Define a three node graph, each node has an integer as its feature.
node_features = jnp.array([[0.], [1.], [2.]])
# We will construct a graph for which there is a directed edge between each node
# and its successor. We define this with `senders` (source nodes) and `receivers`
# (destination nodes).
senders = jnp.array([0, 1, 2])
receivers = jnp.array([1, 2, 0])
# You can optionally add edge attributes.
edges = jnp.array([[5.], [6.], [7.]])
# We then save the number of nodes and the number of edges.
# This information is used to make running GNNs over multiple graphs
# in a GraphsTuple possible.
n_node = jnp.array([3])
n_edge = jnp.array([3])
# Optionally you can add `global` information, such as a graph label.
global_context = jnp.array([[1]])
graph = jraph.GraphsTuple(nodes=node_features, senders=senders, receivers=receivers,
edges=edges, n_node=n_node, n_edge=n_edge, globals=global_context)
GraphsTuple
可以包含多个图。
two_graph_graphstuple = jraph.batch([graph, graph])
节点和边特征堆叠在主轴上。
jraph.batch([graph, graph]).nodes
>>> DeviceArray([[0.],
[1.],
[2.],
[0.],
[1.],
[2.]], dtype=float32)
您可以通过查看n_node
来判断哪些节点来自哪个图。
jraph.batch([graph, graph]).n_node
>>> DeviceArray([3, 3], dtype=int32)
您可以在nodes
、edges
和globals
中存储特征嵌套。这使得为每个节点、边或图存储多组特征成为可能,这些特征可能具有不同类型和不同的语义意义(例如“训练”和“测试”节点)。唯一的要求是,每个嵌套中的所有数组都必须具有共同的领先维度大小,与Graphstuple
中节点、边或图的总数相匹配。
node_targets = jnp.array([[True], [False], [True]])
graph = graph._replace(nodes={'inputs': graph.nodes, 'targets': node_targets})
使用模型库
Jraph提供了一套实现的参考模型供您使用。
Jraph模型定义了图节点、边和全局属性之间的消息传递算法。用户定义update
函数来更新图特征,这些通常是神经网络,但也可以是任意的jax函数。
让我们通过一个GraphNetwork
(论文)示例来了解。GraphNet的第一个更新函数使用edge
特征、sender
和receiver
的节点特征以及global
特征来更新边。
# As one example, we just pass the edge features straight through.
def update_edge_fn(edge, sender, receiver, globals_):
return edge
我们通常使用这些特征的连接,而jraph
提供了一个使用concatenated_args
装饰器的简单方法来完成这项工作。
@jraph.concatenated_args
def update_edge_fn(concatenated_features):
return concatenated_features
通常,在更新函数中,我们会使用如多层感知器这样的学习模型。
用户可以像定义更新节点和全局变量的函数一样定义这些函数。然后,使用这些函数来配置一个GraphNetwork
。要查看节点和全局update_fns
的参数,请参阅模型库。
net = jraph.GraphNetwork(update_edge_fn=update_edge_fn,
update_node_fn=update_node_fn,
update_global_fn=update_global_fn)
net
是一个根据GraphNetwork
算法发送消息并应用update_fn
的函数。它接收一个图,并返回一个图。
updated_graph = net(graph)
示例
对于更深入的了解,最好从示例开始。特别是
examples/basic.py
提供了对库功能的介绍。ogb_examples/train.py
提供了一个在molhiv
Open Graph Benchmark数据集上训练GraphNet
的端到端示例。请注意,您需要下载该数据集才能运行此示例。
其余的示例是短脚本,展示了如何使用我们的模型库中的各种模型,以及如何使用jax.jit
让模型运行得更快,以及如何处理Jax的静态形状要求。
引用Jraph
要引用此存储库
@software{jraph2020github,
author = {Jonathan Godwin* and Thomas Keck* and Peter Battaglia and Victor Bapst and Thomas Kipf and Yujia Li and Kimberly Stachenfeld and Petar Veli\v{c}kovi\'{c} and Alvaro Sanchez-Gonzalez},
title = {{J}raph: {A} library for graph neural networks in jax.},
url = {http://github.com/deepmind/jraph},
version = {0.0.1.dev},
year = {2020},
}
项目详情
下载文件
下载适合您平台的文件。如果您不确定选择哪个,请了解更多关于安装包的信息。
源分发
构建分发
jraph-0.0.6.dev0.tar.gz的散列
算法 | 散列摘要 | |
---|---|---|
SHA256 | c3ac3a0b224b344eb6d367e8bc312d95ea41bf825d01ea31b80dd8c22c0dd8b8 |
|
MD5 | b8687f1e7abf09cf468a760e4eb59d45 |
|
BLAKE2b-256 | 8c7768d90dbc44c0b51aaa774b47b51f8568062452adf91e67b5e50f6711c981 |
jraph-0.0.6.dev0-py3-none-any.whl的散列
算法 | 散列摘要 | |
---|---|---|
SHA256 | 350fe37bf717f934f1f84fd3370a480b3178bfcb61dfa217c738971308c57625 |
|
MD5 | c0424577ffffb84dc8e5e90304beed1b |
|
BLAKE2b-256 | 2ae2f799edeb39a154560b52134cdb3a3359e2de965c76886949966e46d5c42b |