Sonnet是一个用于在TensorFlow中构建神经网络的库。
项目描述
Sonnet
Sonnet是一个基于TensorFlow 2构建的库,旨在为机器学习研究提供简单、可组合的抽象。
简介
Sonnet是由DeepMind的研究人员设计和构建的。它可以用于构建各种目的的神经网络(无/监督学习、强化学习等)。我们发现它对我们组织来说是一个成功的抽象,您也可能如此!
更具体地说,Sonnet提供了一种简单但强大的编程模型,其核心概念是snt.Module
。模块可以持有对参数、其他模块以及应用某些函数于用户输入的方法的引用。Sonnet附带了许多预定义模块(例如snt.Linear
、snt.Conv2D
、snt.BatchNorm
)和一些预定义的模块网络(例如snt.nets.MLP
),但同时也鼓励用户构建自己的模块。
与许多框架不同,Sonnet在如何使用您的模块方面几乎没有偏见。模块被设计成自包含且完全解耦。Sonnet没有附带训练框架,并鼓励用户构建自己的或采用他人的框架。
Sonnet还设计得易于理解,我们的代码(希望!)清晰且专注。在我们选择默认值的地方(例如,初始化参数值的默认值),我们尝试指出原因。
入门指南
示例
尝试Sonnet的最简单方法是使用Google Colab,它提供了一个免费连接GPU或TPU的Python笔记本。
安装
要开始,请安装TensorFlow 2.0和Sonnet 2
$ pip install tensorflow tensorflow-probability
$ pip install dm-sonnet
您可以使用以下命令验证是否正确安装
import tensorflow as tf
import sonnet as snt
print("TensorFlow version {}".format(tf.__version__))
print("Sonnet version {}".format(snt.__version__))
使用现有模块
Sonnet附带了一些内置模块,您可以轻松使用。例如,要定义一个MLP,我们可以使用snt.Sequential
模块调用一系列模块,将给定模块的输出作为下一个模块的输入。我们可以使用snt.Linear
和tf.nn.relu
来实际定义我们的计算
mlp = snt.Sequential([
snt.Linear(1024),
tf.nn.relu,
snt.Linear(10),
])
要使用我们的模块,我们需要“调用”它。Sequential
模块(以及大多数模块)定义了一个__call__
方法,这意味着您可以通过名称调用它们
logits = mlp(tf.random.normal([batch_size, input_size]))
通常,您还会要求获取您模块的所有参数。大多数Sonnet模块在第一次调用某个输入时创建其参数(因为在大多数情况下,参数的形状是输入的一个函数)。Sonnet模块提供了两个属性来访问参数。
variables
属性返回给定模块引用的所有tf.Variable
all_variables = mlp.variables
值得注意的是,tf.Variable
不仅用于模型的参数。例如,它们用于在snt.BatchNorm
中使用的度量中保留状态。在大多数情况下,用户检索模块变量,将它们传递给优化器以进行更新。在这种情况下,非可训练变量通常不应在该列表中,因为它们通过不同的机制进行更新。TensorFlow有一个内置机制来标记变量为“可训练”(模型的参数)或“不可训练”(其他变量)。Sonnet提供了一个机制来收集您模块中的所有可训练变量,这可能是您想要传递给优化器的变量
model_parameters = mlp.trainable_variables
构建您自己的模块
Sonnet强烈鼓励用户通过子类化snt.Module
来定义自己的模块。让我们首先创建一个简单的名为MyLinear
的Linear
层
class MyLinear(snt.Module):
def __init__(self, output_size, name=None):
super(MyLinear, self).__init__(name=name)
self.output_size = output_size
@snt.once
def _initialize(self, x):
initial_w = tf.random.normal([x.shape[1], self.output_size])
self.w = tf.Variable(initial_w, name="w")
self.b = tf.Variable(tf.zeros([self.output_size]), name="b")
def __call__(self, x):
self._initialize(x)
return tf.matmul(x, self.w) + self.b
使用此模块很简单
mod = MyLinear(32)
mod(tf.ones([batch_size, input_size]))
通过子类化snt.Module
,您将获得许多有用的特性。例如,一个默认的__repr__
实现,该实现显示构造函数参数(这对于调试和自省非常有用)
>>> print(repr(mod))
MyLinear(output_size=10)
您还将获得variables
和trainable_variables
属性
>>> mod.variables
(<tf.Variable 'my_linear/b:0' shape=(10,) ...)>,
<tf.Variable 'my_linear/w:0' shape=(1, 10) ...)>)
您可能会注意到上面变量中的my_linear
前缀。这是因为Sonnet模块在调用方法时也会进入模块名称作用域。通过进入模块名称作用域,我们为像TensorBoard这样的工具提供了更有用的图,例如,所有在my_linear内部发生的操作都将在一个名为my_linear的组中。
此外,您的模块现在将支持TensorFlow检查点和保存模型等高级功能,这些功能将在稍后介绍。
序列化
Sonnet支持多种序列化格式。我们支持的最简单格式是Python的pickle
,所有内置模块都已测试以确保它们可以在同一个Python进程中通过pickle进行保存/加载。一般来说,我们不鼓励使用pickle,它在TensorFlow的许多部分中支持不佳,并且根据我们的经验可能相当脆弱。
TensorFlow检查点
参考: https://tensorflowcn.cn/alpha/guide/checkpoints
TensorFlow检查点可以在训练过程中定期保存参数的值。这可以在程序崩溃或停止时保存训练进度。Sonnet设计用于与TensorFlow检查点干净地协同工作
checkpoint_root = "/tmp/checkpoints"
checkpoint_name = "example"
save_prefix = os.path.join(checkpoint_root, checkpoint_name)
my_module = create_my_sonnet_module() # Can be anything extending snt.Module.
# A `Checkpoint` object manages checkpointing of the TensorFlow state associated
# with the objects passed to it's constructor. Note that Checkpoint supports
# restore on create, meaning that the variables of `my_module` do **not** need
# to be created before you restore from a checkpoint (their value will be
# restored when they are created).
checkpoint = tf.train.Checkpoint(module=my_module)
# Most training scripts will want to restore from a checkpoint if one exists. This
# would be the case if you interrupted your training (e.g. to use your GPU for
# something else, or in a cloud environment if your instance is preempted).
latest = tf.train.latest_checkpoint(checkpoint_root)
if latest is not None:
checkpoint.restore(latest)
for step_num in range(num_steps):
train(my_module)
# During training we will occasionally save the values of weights. Note that
# this is a blocking call and can be slow (typically we are writing to the
# slowest storage on the machine). If you have a more reliable setup it might be
# appropriate to save less frequently.
if step_num and not step_num % 1000:
checkpoint.save(save_prefix)
# Make sure to save your final values!!
checkpoint.save(save_prefix)
TensorFlow保存模型
参考: https://tensorflowcn.cn/alpha/guide/saved_model
TensorFlow保存模型可以用来保存一个与Python源代码分离的网络的副本。这是通过保存描述计算的TensorFlow图和包含权重值的检查点来实现的。
创建保存模型的第一步是创建一个要保存的snt.Module
。
my_module = snt.nets.MLP([1024, 1024, 10])
my_module(tf.ones([1, input_size]))
接下来,我们需要创建另一个模块来描述我们想要导出的模型的特定部分。我们建议这样做(而不是就地修改原始模型),以便您可以精细控制实际导出的内容。这通常很重要,可以避免创建非常大的保存模型,并且仅共享您想要共享的模型部分(例如,您只想共享GAN的生成器,但保留判别器为私有)。
@tf.function(input_signature=[tf.TensorSpec([None, input_size])])
def inference(x):
return my_module(x)
to_save = snt.Module()
to_save.inference = inference
to_save.all_variables = list(my_module.variables)
tf.saved_model.save(to_save, "/tmp/example_saved_model")
现在我们在/tmp/example_saved_model
文件夹中有一个保存模型
$ ls -lh /tmp/example_saved_model
total 24K
drwxrwsr-t 2 tomhennigan 154432098 4.0K Apr 28 00:14 assets
-rw-rw-r-- 1 tomhennigan 154432098 14K Apr 28 00:15 saved_model.pb
drwxrwsr-t 2 tomhennigan 154432098 4.0K Apr 28 00:15 variables
加载此模型很简单,可以在不同的机器上完成,而无需任何构建保存模型的Python代码
loaded = tf.saved_model.load("/tmp/example_saved_model")
# Use the inference method. Note this doesn't run the Python code from `to_save`
# but instead uses the TensorFlow Graph that is part of the saved model.
loaded.inference(tf.ones([1, input_size]))
# The all_variables property can be used to retrieve the restored variables.
assert len(loaded.all_variables) > 0
请注意,加载的对象不是一个Sonnet模块,而是一个容器对象,它具有我们在之前部分中添加的特定方法(例如inference
)和属性(例如all_variables
)。
分布式训练
示例: https://github.com/deepmind/sonnet/blob/v2/examples/distributed_cifar10.ipynb
Sonnet支持使用自定义TensorFlow分布策略进行分布式训练。
与使用tf.keras
进行的分布式训练相比,Sonnet模块和优化器在分布策略下运行时表现不同(例如,我们不会平均您的梯度或同步您的批归一化统计信息)。我们相信用户应该完全控制他们训练的这些方面,并且它们不应该被嵌入到库中。这里的权衡是您需要在训练脚本中实现这些功能(通常只需2行代码即可在应用优化器之前对梯度进行全量归约)或交换显式具有分布意识的模块(例如snt.distribute.CrossReplicaBatchNorm
)。
我们的分布式Cifar-10示例介绍了如何使用Sonnet进行多GPU训练。
项目详情
下载文件
下载适用于您平台的文件。如果您不确定选择哪个,请了解有关安装包的更多信息。