JAX中的梯度处理和优化库。
项目描述
Optax
简介
Optax是JAX的梯度处理和优化库。
Optax旨在通过提供易于以自定义方式重新组合的构建块,以促进研究。
我们的目标是
- 提供简单、经过良好测试、高效的核组件实现。
- 通过允许轻松将低级成分组合成自定义优化器(或其他梯度处理组件)来提高研究生产力。
- 通过使任何人都能轻松贡献,加速新想法的采用。
我们倾向于关注可以有效地组合成自定义解决方案的小型可组合构建块。其他人可以在这些基本组件上构建更复杂的抽象。在合理的情况下,实现优先考虑可读性和将代码结构化以匹配标准方程,而不是代码重用。
此库的初始原型作为jax.experimental.optix
在JAX的实验文件夹中提供。鉴于optix
在DeepMind中的广泛应用,以及API经过几轮迭代后,optix
最终从experimental
中移出,作为一个独立的开源库,并更名为optax
。
Optax的文档可以在optax.readthedocs.io找到。
安装
您可以通过以下方式从PyPI安装Optax的最新发布版本:
pip install optax
或者您可以从GitHub安装最新开发版本
pip install git+https://github.com/google-deepmind/optax.git
快速入门
Optax包含许多流行优化器的实现,例如许多流行优化器和损失函数。例如,以下代码片段使用来自optax.adam
的Adam优化器和来自optax.l2_loss
的均方误差。我们使用模型的init
函数和params
初始化优化器状态。
optimizer = optax.adam(learning_rate)
# Obtain the `opt_state` that contains statistics for the optimizer.
params = {'w': jnp.ones((num_weights,))}
opt_state = optimizer.init(params)
要编写更新循环,我们需要一个可以由Jax(在本例中使用jax.grad
)求导的损失函数以获得梯度。
compute_loss = lambda params, x, y: optax.l2_loss(params['w'].dot(x), y)
grads = jax.grad(compute_loss)(params, xs, ys)
然后,通过optimizer.update
将梯度转换为应用于当前参数以获得新参数的更新。optax.apply_updates
是一个方便的实用工具来完成此操作。
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
您可以在Optax 🚀 快速入门笔记本中继续快速入门。
开发
我们欢迎新的贡献者。
源代码
您可以使用以下命令检查最新源代码。
git clone https://github.com/google-deepmind/optax.git
测试
要运行测试,请执行以下脚本。
sh ./test.sh
文档
要构建文档,首先确保已安装所有依赖项。
pip install -e ".[docs]"
然后,执行以下操作。
cd docs/
make html
基准测试
如果您在深度学习的众多优化器中感到迷茫,存在一些广泛的基准测试
《神经网络训练算法基准测试》,Dahl G. et al,2023,
《通过拥挤山谷下降——深度学习优化器的基准测试》,Schmidt R. et al,2021.
如果您有兴趣为某些任务开发自己的基准测试,请考虑以下框架
Benchopt:可重复、高效且协作的优化基准测试,Moreau T. et al,2022.
最后,如果您正在寻找有关调整优化器的建议,请考虑查看以下内容
《深度学习调整手册》,Godbole V. et al,2023.
引用Optax
此存储库是DeepMind JAX生态系统的一部分,要引用Optax,请使用以下引用
@software{deepmind2020jax,
title = {The {D}eep{M}ind {JAX} {E}cosystem},
author = {DeepMind and Babuschkin, Igor and Baumli, Kate and Bell, Alison and Bhupatiraju, Surya and Bruce, Jake and Buchlovsky, Peter and Budden, David and Cai, Trevor and Clark, Aidan and Danihelka, Ivo and Dedieu, Antoine and Fantacci, Claudio and Godwin, Jonathan and Jones, Chris and Hemsley, Ross and Hennigan, Tom and Hessel, Matteo and Hou, Shaobo and Kapturowski, Steven and Keck, Thomas and Kemaev, Iurii and King, Michael and Kunesch, Markus and Martens, Lena and Merzic, Hamza and Mikulik, Vladimir and Norman, Tamara and Papamakarios, George and Quan, John and Ring, Roman and Ruiz, Francisco and Sanchez, Alvaro and Sartran, Laurent and Schneider, Rosalia and Sezener, Eren and Spencer, Stephen and Srinivasan, Srivatsan and Stanojevi\'{c}, Milo\v{s} and Stokowiec, Wojciech and Wang, Luyu and Zhou, Guangyao and Viola, Fabio},
url = {http://github.com/google-deepmind},
year = {2020},
}
项目详细信息
下载文件
下载适用于您平台的文件。如果您不确定选择哪一个,请了解更多关于 安装包 的信息。
源代码分布
构建分布
optax-0.2.3.tar.gz 的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | ec7ab925440b0c5a512e1f24fba0fb3e7d760a7fd5d2496d7a691e9d37da01d9 |
|
MD5 | 42ef1896a4646ec2a0d4b3e4ed32b807 |
|
BLAKE2b-256 | d65fe8b09028b37a8c1c159359e59469f3504b550910d472d8ee59543b1735d9 |
optax-0.2.3-py3-none-any.whl 的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 083e603dcd731d7e74d99f71c12f77937dd53f79001b4c09c290e4f47dd2e94f |
|
MD5 | d66a5772aa1cae304539fdf08be8d3b1 |
|
BLAKE2b-256 | a38b7032a6788205e9da398a8a33e1030ee9a22bd9289126e5afed9aac33bcde |