跳转到主要内容

Jraph:Jax中图神经网络的库

项目描述

logo

Jraph - Jax中图神经网络的库。

新功能!Pmap示例和数据加载。

我们添加了一个pmap 示例

我们的朋友instadeep、Jama Hussein Mohamud和Tom Makkink编写了一篇关于使用PyTorch数据加载的精美指南。您可以在这里找到它 这里

新功能!支持大型分布式MPNN

我们发布了一个分布式图网络实现,允许您在多个设备上跨多个设备分发一个非常大的(数百万条边)图网络,具有显式的边消息。 查看它!

新功能!交互式Jraph Colabs

我们有两个新的colabs来帮助您掌握Jraph。

首先是与教育相关的colab,以令人惊叹的介绍开始了图神经网络和图理论,展示了如何使用Jraph解决一系列问题。查看这里

第二项是一个完整工作的示例,展示了使用OGBG-MOLPCBA的Jraph最佳实践,并附有精美的可视化。查看这里

感谢Lisa Wang、Nikola Jovanović和Ameya Daigavane。

快速入门

快速入门|文档

Jraph(发音为“giraffe”)是一个轻量级的库,用于在jax中处理图神经网络。它提供了一种图的数据结构,一组用于处理图的实用工具,以及一组可分叉的图神经网络模型。

安装

pip安装jraph

或者您可以使用以下命令直接从GitHub安装Jraph

pip install git+git://github.com/deepmind/jraph.git

示例需要额外的依赖项。要安装它们,请运行

pip install "jraph[examples, ogb_examples] @ git+git://github.com/deepmind/jraph.git"

概述

Jraph旨在为在jax中处理图提供实用工具,但不规定编写或开发图神经网络的方式。

  • graph.py提供了一种轻量级的数据结构,GraphsTuple,用于处理图。
  • utils.py提供了一组用于在jax中处理GraphsTuples的实用工具。
    • 用于批处理GraphsTuples数据集的实用工具。
    • 通过填充和掩码支持可变形状图的jit编译的实用工具。
    • 用于定义输入分区上的损失的实用工具。
  • models.py提供了不同类型的图神经网络消息传递的示例。这些设计轻量级、易于分叉和适应。它们不为您管理参数 - 对于此目的,请考虑使用haikuflax。请参阅示例以获取更多详细信息。

快速入门

Jraph在定义GraphsTuple数据结构时受到了Tensorflow graph_nets库的启发,它是一个包含一个或多个有向图的namedtuple。

表示图 - GraphsTuple

import jraph
import jax.numpy as jnp

# Define a three node graph, each node has an integer as its feature.
node_features = jnp.array([[0.], [1.], [2.]])

# We will construct a graph for which there is a directed edge between each node
# and its successor. We define this with `senders` (source nodes) and `receivers`
# (destination nodes).
senders = jnp.array([0, 1, 2])
receivers = jnp.array([1, 2, 0])

# You can optionally add edge attributes.
edges = jnp.array([[5.], [6.], [7.]])

# We then save the number of nodes and the number of edges.
# This information is used to make running GNNs over multiple graphs
# in a GraphsTuple possible.
n_node = jnp.array([3])
n_edge = jnp.array([3])

# Optionally you can add `global` information, such as a graph label.

global_context = jnp.array([[1]])
graph = jraph.GraphsTuple(nodes=node_features, senders=senders, receivers=receivers,
edges=edges, n_node=n_node, n_edge=n_edge, globals=global_context)

GraphsTuple可以包含多个图。

two_graph_graphstuple = jraph.batch([graph, graph])

节点和边特征堆叠在主轴上。

jraph.batch([graph, graph]).nodes
>>> DeviceArray([[0.],
             [1.],
             [2.],
             [0.],
             [1.],
             [2.]], dtype=float32)

您可以通过查看n_node来判断哪些节点来自哪个图。

jraph.batch([graph, graph]).n_node
>>> DeviceArray([3, 3], dtype=int32)

您可以在nodesedgesglobals中存储特征嵌套。这使得为每个节点、边或图存储多组特征成为可能,这些特征可能具有不同类型和不同的语义意义(例如“训练”和“测试”节点)。唯一的要求是,每个嵌套中的所有数组都必须具有共同的领先维度大小,与Graphstuple中节点、边或图的总数相匹配。

node_targets = jnp.array([[True], [False], [True]])
graph = graph._replace(nodes={'inputs': graph.nodes, 'targets': node_targets})

使用模型库

Jraph提供了一套实现的参考模型供您使用。

Jraph模型定义了图节点、边和全局属性之间的消息传递算法。用户定义update函数来更新图特征,这些通常是神经网络,但也可以是任意的jax函数。

让我们通过一个GraphNetwork (论文)示例来了解。GraphNet的第一个更新函数使用edge特征、senderreceiver的节点特征以及global特征来更新边。

# As one example, we just pass the edge features straight through.
def update_edge_fn(edge, sender, receiver, globals_):
  return edge

我们通常使用这些特征的连接,而jraph提供了一个使用concatenated_args装饰器的简单方法来完成这项工作。

@jraph.concatenated_args
def update_edge_fn(concatenated_features):
  return concatenated_features

通常,在更新函数中,我们会使用如多层感知器这样的学习模型。

用户可以像定义更新节点和全局变量的函数一样定义这些函数。然后,使用这些函数来配置一个GraphNetwork。要查看节点和全局update_fns的参数,请参阅模型库。

net = jraph.GraphNetwork(update_edge_fn=update_edge_fn,
                         update_node_fn=update_node_fn,
                         update_global_fn=update_global_fn)

net是一个根据GraphNetwork算法发送消息并应用update_fn的函数。它接收一个图,并返回一个图。

updated_graph = net(graph)

示例

对于更深入的了解,最好从示例开始。特别是

  • examples/basic.py提供了对库功能的介绍。
  • ogb_examples/train.py提供了一个在molhiv Open Graph Benchmark数据集上训练GraphNet的端到端示例。请注意,您需要下载该数据集才能运行此示例。

其余的示例是短脚本,展示了如何使用我们的模型库中的各种模型,以及如何使用jax.jit让模型运行得更快,以及如何处理Jax的静态形状要求。

引用Jraph

要引用此存储库

@software{jraph2020github,
  author = {Jonathan Godwin* and Thomas Keck* and Peter Battaglia and Victor Bapst and Thomas Kipf and Yujia Li and Kimberly Stachenfeld and Petar Veli\v{c}kovi\'{c} and Alvaro Sanchez-Gonzalez},
  title = {{J}raph: {A} library for graph neural networks in jax.},
  url = {http://github.com/deepmind/jraph},
  version = {0.0.1.dev},
  year = {2020},
}

项目详情


下载文件

下载适合您平台的文件。如果您不确定选择哪个,请了解更多关于安装包的信息。

源分发

jraph-0.0.6.dev0.tar.gz (68.4 kB 查看散列)

上传时间

构建分发

jraph-0.0.6.dev0-py3-none-any.whl (90.6 kB 查看散列)

上传时间 Python 3

支持者

AWS AWS 云计算和安全赞助商 Datadog Datadog 监控 Fastly Fastly CDN Google Google 下载分析 Microsoft Microsoft PSF 赞助商 Pingdom Pingdom 监控 Sentry Sentry 错误记录 StatusPage StatusPage 状态页面