将张量从Jax转换为Torch以及相反方向的工具
项目描述
Torch <-> Jax 互操作工具
嘿,你在那里!
- 你使用PyTorch,但好奇Jax(或反之亦然)吗?你愿意逐渐将一些(Jax/PyTorch)添加到你的项目中,而不是从头开始吗?
- 想要避免将现有PyTorch代码库中的模型重写为Jax(或反之亦然)的痛苦吗?
- 你喜欢Jax的性能优势,但不想放弃你喜欢的PyTorch软件框架(例如 Lightning)吗?
那么,我有一些好消息要告诉你! 你可以拥有一切:来自Jax的甜美的jit函数和自动微分,以及来自PyTorch软件生态系统的成熟、广泛使用的框架。
这个工具的作用
此软件包包含一些简化Jax和Torch之间互操作性的实用函数:torch_to_jax
、jax_to_torch
、WrappedJaxFunction
、torch_module_to_jax
。
此存储库包含将PyTorch张量转换为JAX数组和相反方向的工具。这种转换是通过dlpack
格式实现的,这是在不同深度学习框架之间交换张量的通用格式。关键的是,该格式允许PyTorch和JAX之间进行零拷贝的*张量共享。
* 注意:对于某些具有特定内存布局的torch张量(例如通道优先的图像张量),Jax将拒绝从dlpack中读取数组,因此在转换时需要对数据进行扁平化和展开,这可能涉及复制。这目前在命令行上显示为警告。
安装
pip install torch-jax-interop
类似的项目
- https://github.com/lucidrains/jax2torch:这似乎是这个类型的第一版最小原型。支持jax2torch函数,但不支持相反的操作。
- https://github.com/subho406/pytorch2jax:非常相似。我们将
torch.nn.Module
转换为jax.custom_vjp
的方式实际上是基于它们的实现,增加了一些功能(支持jitting,以及更灵活的输入/输出签名)。 - https://github.com/samuela/torch2jax:采用不同的方法:使用
torch.Tensor
子类和__torch_fuction__
。 - https://github.com/rdyro/torch2jax:刚刚找到这个,看起来对torch到jax的转换支持很好,但反过来不行。具有一些附加功能,如指定深度(导数的级别)。
用法
import torch
import jax.numpy as jnp
from torch_jax_interop import jax_to_torch, torch_to_jax
将torch.Tensor
转换为jax.Array
import jax
import torch
tensors = {
"x": torch.randn(5),
"y": torch.arange(5),
}
jax_arrays = jax.tree.map(torch_to_jax, tensors)
torch_tensors = jax.tree.map(jax_to_torch, jax_arrays)
将torch.Tensors传递给Jax函数
@jax_to_torch
def some_jax_function(x: jnp.ndarray) -> jnp.ndarray:
return x + jnp.ones_like(x)
torch_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
some_torch_tensor = torch.arange(5, device=device)
torch_output = some_jax_function(some_torch_tensor)
some_jax_array = jnp.arange(5)
@torch_to_jax
def some_torch_function(x: torch.Tensor) -> torch.Tensor:
return x + torch.ones_like(x)
print(some_torch_function(some_jax_array))
示例
Jax到Torch nn.Module
假设我们有一些希望在PyTorch模型中使用的jax函数
import jax
import jax.numpy as jnp
def some_jax_function(params: jax.Array, x: jax.Array):
'''Some toy function that takes in some parameters and an input vector.'''
return jnp.dot(x, params)
通过导入这个
from torch_jax_interop import WrappedJaxFunction
我们可以将这个jax函数包装成一个具有可学习参数的torch.nn.Module
import torch
import torch.nn
module = WrappedJaxFunction(some_jax_function, jax.random.normal(jax.random.key(0), (2, 1)))
module = module.to("cpu") # jax arrays are on GPU by default, moving them to CPU for this example.
这些参数现在是模块参数的可学习参数
dict(module.state_dict())
{'params.0': tensor([[-0.7848],
[ 0.8564]])}
您可以使用它就像其他任何torch.nn.Module一样
x, y = torch.randn(2), torch.rand(1)
output = module(x)
loss = torch.nn.functional.mse_loss(output, y)
loss.backward()
model = torch.nn.Sequential(
torch.nn.Linear(123, 2),
module,
)
同样,对于flax.linen.Module
,您现在可以在torch forward / backward过程中使用它们
import flax.linen
class Classifier(flax.linen.Module):
num_classes: int = 10
@flax.linen.compact
def __call__(self, x: jax.Array):
x = x.reshape((x.shape[0], -1)) # flatten
x = flax.linen.Dense(features=256)(x)
x = flax.linen.relu(x)
x = flax.linen.Dense(features=self.num_classes)(x)
return x
jax_module = Classifier(num_classes=10)
jax_params = jax_module.init(jax.random.key(0), x)
from torch_jax_interop import WrappedJaxFunction
torch_module = WrappedJaxFunction(jax.jit(jax_module.apply), jax_params)
Torch nn.Module到jax函数
>>> import torch
>>> import jax
>>> model = torch.nn.Linear(3, 2, device="cuda")
>>> apply_fn, params = torch_module_to_jax(model)
>>> def loss_function(params, x: jax.Array, y: jax.Array) -> jax.Array:
... y_pred = apply_fn(params, x)
... return jax.numpy.mean((y - y_pred) ** 2)
>>> x = jax.random.uniform(key=jax.random.key(0), shape=(1, 3))
>>> y = jax.random.uniform(key=jax.random.key(1), shape=(1, 1))
>>> loss, grad = jax.value_and_grad(loss_function)(params, x, y)
>>> loss
Array(0.3944674, dtype=float32)
>>> grad
(Array([[-0.46541408, -0.15171866, -0.30520514],
[-0.7201077 , -0.23474531, -0.47222584]], dtype=float32), Array([-0.4821338, -0.7459771], dtype=float32))
要在模型上使用jax.jit
,您需要传递一个输出示例,这样JIT编译器就可以告诉我们预期的输出形状和dtypes
>>> # here we reuse the same model as before:
>>> apply, params = torch_module_to_jax(model, example_output=torch.zeros(1, 2, device="cuda"))
>>> def loss_function(params, x: jax.Array, y: jax.Array) -> jax.Array:
... y_pred = apply(params, x)
... return jax.numpy.mean((y - y_pred) ** 2)
>>> loss, grad = jax.jit(jax.value_and_grad(loss_function))(params, x, y)
>>> loss
Array(0.3944674, dtype=float32)
>>> grad
(Array([[-0.46541408, -0.15171866, -0.30520514],
[-0.7201077 , -0.23474531, -0.47222584]], dtype=float32), Array([-0.4821338, -0.7459771], dtype=float32))
项目详情
关闭
torch_jax_interop-0.0.7.tar.gz的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 0547a05d15c68c3b3e316b76b10bf55a2f686eb23f6fc3c06c0c679b0ea90bf1 |
|
MD5 | e9c0021101a730c903596a0e8f169935 |
|
BLAKE2b-256 | ad40f156eb5c20ade2ba7d535f5585de62a52b333db56b3946d14823eebf3ee7 |
关闭
torch_jax_interop-0.0.7-py3-none-any.whl的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 71f65e035f468b7f1dc30a357f02c70e727990fa5924f6d117dd238295912d1a |
|
MD5 | 23366fe86f47a3ce3f19670d49179c90 |
|
BLAKE2b-256 | e5b8b7934c4b37120e7a69289921dde4055488f9521bc3d264c6aeaa1862e0bb |