跳转到主要内容

Sonnet是一个用于在TensorFlow中构建神经网络的库。

项目描述

Sonnet

Sonnet

文档 | 示例

Sonnet是一个基于TensorFlow 2构建的库,旨在为机器学习研究提供简单、可组合的抽象。

简介

Sonnet是由DeepMind的研究人员设计和构建的。它可以用于构建各种目的的神经网络(无/监督学习、强化学习等)。我们发现它对我们组织来说是一个成功的抽象,您也可能如此!

更具体地说,Sonnet提供了一种简单但强大的编程模型,其核心概念是snt.Module。模块可以持有对参数、其他模块以及应用某些函数于用户输入的方法的引用。Sonnet附带了许多预定义模块(例如snt.Linearsnt.Conv2Dsnt.BatchNorm)和一些预定义的模块网络(例如snt.nets.MLP),但同时也鼓励用户构建自己的模块。

与许多框架不同,Sonnet在如何使用您的模块方面几乎没有偏见。模块被设计成自包含且完全解耦。Sonnet没有附带训练框架,并鼓励用户构建自己的或采用他人的框架。

Sonnet还设计得易于理解,我们的代码(希望!)清晰且专注。在我们选择默认值的地方(例如,初始化参数值的默认值),我们尝试指出原因。

入门指南

示例

尝试Sonnet的最简单方法是使用Google Colab,它提供了一个免费连接GPU或TPU的Python笔记本。

安装

要开始,请安装TensorFlow 2.0和Sonnet 2

$ pip install tensorflow tensorflow-probability
$ pip install dm-sonnet

您可以使用以下命令验证是否正确安装

import tensorflow as tf
import sonnet as snt

print("TensorFlow version {}".format(tf.__version__))
print("Sonnet version {}".format(snt.__version__))

使用现有模块

Sonnet附带了一些内置模块,您可以轻松使用。例如,要定义一个MLP,我们可以使用snt.Sequential模块调用一系列模块,将给定模块的输出作为下一个模块的输入。我们可以使用snt.Lineartf.nn.relu来实际定义我们的计算

mlp = snt.Sequential([
    snt.Linear(1024),
    tf.nn.relu,
    snt.Linear(10),
])

要使用我们的模块,我们需要“调用”它。Sequential模块(以及大多数模块)定义了一个__call__方法,这意味着您可以通过名称调用它们

logits = mlp(tf.random.normal([batch_size, input_size]))

通常,您还会要求获取您模块的所有参数。大多数Sonnet模块在第一次调用某个输入时创建其参数(因为在大多数情况下,参数的形状是输入的一个函数)。Sonnet模块提供了两个属性来访问参数。

variables属性返回给定模块引用的所有tf.Variable

all_variables = mlp.variables

值得注意的是,tf.Variable不仅用于模型的参数。例如,它们用于在snt.BatchNorm中使用的度量中保留状态。在大多数情况下,用户检索模块变量,将它们传递给优化器以进行更新。在这种情况下,非可训练变量通常不应在该列表中,因为它们通过不同的机制进行更新。TensorFlow有一个内置机制来标记变量为“可训练”(模型的参数)或“不可训练”(其他变量)。Sonnet提供了一个机制来收集您模块中的所有可训练变量,这可能是您想要传递给优化器的变量

model_parameters = mlp.trainable_variables

构建您自己的模块

Sonnet强烈鼓励用户通过子类化snt.Module来定义自己的模块。让我们首先创建一个简单的名为MyLinearLinear

class MyLinear(snt.Module):

  def __init__(self, output_size, name=None):
    super(MyLinear, self).__init__(name=name)
    self.output_size = output_size

  @snt.once
  def _initialize(self, x):
    initial_w = tf.random.normal([x.shape[1], self.output_size])
    self.w = tf.Variable(initial_w, name="w")
    self.b = tf.Variable(tf.zeros([self.output_size]), name="b")

  def __call__(self, x):
    self._initialize(x)
    return tf.matmul(x, self.w) + self.b

使用此模块很简单

mod = MyLinear(32)
mod(tf.ones([batch_size, input_size]))

通过子类化snt.Module,您将获得许多有用的特性。例如,一个默认的__repr__实现,该实现显示构造函数参数(这对于调试和自省非常有用)

>>> print(repr(mod))
MyLinear(output_size=10)

您还将获得variablestrainable_variables属性

>>> mod.variables
(<tf.Variable 'my_linear/b:0' shape=(10,) ...)>,
 <tf.Variable 'my_linear/w:0' shape=(1, 10) ...)>)

您可能会注意到上面变量中的my_linear前缀。这是因为Sonnet模块在调用方法时也会进入模块名称作用域。通过进入模块名称作用域,我们为像TensorBoard这样的工具提供了更有用的图,例如,所有在my_linear内部发生的操作都将在一个名为my_linear的组中。

此外,您的模块现在将支持TensorFlow检查点和保存模型等高级功能,这些功能将在稍后介绍。

序列化

Sonnet支持多种序列化格式。我们支持的最简单格式是Python的pickle,所有内置模块都已测试以确保它们可以在同一个Python进程中通过pickle进行保存/加载。一般来说,我们不鼓励使用pickle,它在TensorFlow的许多部分中支持不佳,并且根据我们的经验可能相当脆弱。

TensorFlow检查点

参考: https://tensorflowcn.cn/alpha/guide/checkpoints

TensorFlow检查点可以在训练过程中定期保存参数的值。这可以在程序崩溃或停止时保存训练进度。Sonnet设计用于与TensorFlow检查点干净地协同工作

checkpoint_root = "/tmp/checkpoints"
checkpoint_name = "example"
save_prefix = os.path.join(checkpoint_root, checkpoint_name)

my_module = create_my_sonnet_module()  # Can be anything extending snt.Module.

# A `Checkpoint` object manages checkpointing of the TensorFlow state associated
# with the objects passed to it's constructor. Note that Checkpoint supports
# restore on create, meaning that the variables of `my_module` do **not** need
# to be created before you restore from a checkpoint (their value will be
# restored when they are created).
checkpoint = tf.train.Checkpoint(module=my_module)

# Most training scripts will want to restore from a checkpoint if one exists. This
# would be the case if you interrupted your training (e.g. to use your GPU for
# something else, or in a cloud environment if your instance is preempted).
latest = tf.train.latest_checkpoint(checkpoint_root)
if latest is not None:
  checkpoint.restore(latest)

for step_num in range(num_steps):
  train(my_module)

  # During training we will occasionally save the values of weights. Note that
  # this is a blocking call and can be slow (typically we are writing to the
  # slowest storage on the machine). If you have a more reliable setup it might be
  # appropriate to save less frequently.
  if step_num and not step_num % 1000:
    checkpoint.save(save_prefix)

# Make sure to save your final values!!
checkpoint.save(save_prefix)

TensorFlow保存模型

参考: https://tensorflowcn.cn/alpha/guide/saved_model

TensorFlow保存模型可以用来保存一个与Python源代码分离的网络的副本。这是通过保存描述计算的TensorFlow图和包含权重值的检查点来实现的。

创建保存模型的第一步是创建一个要保存的snt.Module

my_module = snt.nets.MLP([1024, 1024, 10])
my_module(tf.ones([1, input_size]))

接下来,我们需要创建另一个模块来描述我们想要导出的模型的特定部分。我们建议这样做(而不是就地修改原始模型),以便您可以精细控制实际导出的内容。这通常很重要,可以避免创建非常大的保存模型,并且仅共享您想要共享的模型部分(例如,您只想共享GAN的生成器,但保留判别器为私有)。

@tf.function(input_signature=[tf.TensorSpec([None, input_size])])
def inference(x):
  return my_module(x)

to_save = snt.Module()
to_save.inference = inference
to_save.all_variables = list(my_module.variables)
tf.saved_model.save(to_save, "/tmp/example_saved_model")

现在我们在/tmp/example_saved_model文件夹中有一个保存模型

$ ls -lh /tmp/example_saved_model
total 24K
drwxrwsr-t 2 tomhennigan 154432098 4.0K Apr 28 00:14 assets
-rw-rw-r-- 1 tomhennigan 154432098  14K Apr 28 00:15 saved_model.pb
drwxrwsr-t 2 tomhennigan 154432098 4.0K Apr 28 00:15 variables

加载此模型很简单,可以在不同的机器上完成,而无需任何构建保存模型的Python代码

loaded = tf.saved_model.load("/tmp/example_saved_model")

# Use the inference method. Note this doesn't run the Python code from `to_save`
# but instead uses the TensorFlow Graph that is part of the saved model.
loaded.inference(tf.ones([1, input_size]))

# The all_variables property can be used to retrieve the restored variables.
assert len(loaded.all_variables) > 0

请注意,加载的对象不是一个Sonnet模块,而是一个容器对象,它具有我们在之前部分中添加的特定方法(例如inference)和属性(例如all_variables)。

分布式训练

示例: https://github.com/deepmind/sonnet/blob/v2/examples/distributed_cifar10.ipynb

Sonnet支持使用自定义TensorFlow分布策略进行分布式训练。

与使用tf.keras进行的分布式训练相比,Sonnet模块和优化器在分布策略下运行时表现不同(例如,我们不会平均您的梯度或同步您的批归一化统计信息)。我们相信用户应该完全控制他们训练的这些方面,并且它们不应该被嵌入到库中。这里的权衡是您需要在训练脚本中实现这些功能(通常只需2行代码即可在应用优化器之前对梯度进行全量归约)或交换显式具有分布意识的模块(例如snt.distribute.CrossReplicaBatchNorm)。

我们的分布式Cifar-10示例介绍了如何使用Sonnet进行多GPU训练。

项目详情


下载文件

下载适用于您平台的文件。如果您不确定选择哪个,请了解有关安装包的更多信息。

源代码分发

dm-sonnet-2.0.2.tar.gz (165.1 kB 查看哈希值)

上传时间 源代码

构建分发

dm_sonnet-2.0.2-py3-none-any.whl (268.4 kB 查看哈希值)

上传时间 Python 3

支持

AWSAWS 云计算和安全赞助商 DatadogDatadog 监控 FastlyFastly CDN GoogleGoogle 下载分析 MicrosoftMicrosoft PSF赞助商 PingdomPingdom 监控 SentrySentry 错误日志 StatusPageStatusPage 状态页面