跳转到主要内容

将张量从Jax转换为Torch以及相反方向的工具

项目描述

Torch <-> Jax 互操作工具

嘿,你在那里!

  • 你使用PyTorch,但好奇Jax(或反之亦然)吗?你愿意逐渐将一些(Jax/PyTorch)添加到你的项目中,而不是从头开始吗?
  • 想要避免将现有PyTorch代码库中的模型重写为Jax(或反之亦然)的痛苦吗?
  • 你喜欢Jax的性能优势,但不想放弃你喜欢的PyTorch软件框架(例如 Lightning)吗?

那么,我有一些好消息要告诉你! 你可以拥有一切:来自Jax的甜美的jit函数和自动微分,以及来自PyTorch软件生态系统的成熟、广泛使用的框架。

这个工具的作用

此软件包包含一些简化Jax和Torch之间互操作性的实用函数:torch_to_jaxjax_to_torchWrappedJaxFunctiontorch_module_to_jax

此存储库包含将PyTorch张量转换为JAX数组和相反方向的工具。这种转换是通过dlpack格式实现的,这是在不同深度学习框架之间交换张量的通用格式。关键的是,该格式允许PyTorch和JAX之间进行零拷贝的*张量共享。

* 注意:对于某些具有特定内存布局的torch张量(例如通道优先的图像张量),Jax将拒绝从dlpack中读取数组,因此在转换时需要对数据进行扁平化和展开,这可能涉及复制。这目前在命令行上显示为警告。

安装

pip install torch-jax-interop

类似的项目

用法

import torch
import jax.numpy as jnp
from torch_jax_interop import jax_to_torch, torch_to_jax

torch.Tensor转换为jax.Array

import jax
import torch

tensors = {
    "x": torch.randn(5),
    "y": torch.arange(5),
}

jax_arrays = jax.tree.map(torch_to_jax, tensors)
torch_tensors = jax.tree.map(jax_to_torch, jax_arrays)

将torch.Tensors传递给Jax函数

@jax_to_torch
def some_jax_function(x: jnp.ndarray) -> jnp.ndarray:
    return x + jnp.ones_like(x)

torch_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
some_torch_tensor = torch.arange(5, device=device)

torch_output = some_jax_function(some_torch_tensor)


some_jax_array = jnp.arange(5)

@torch_to_jax
def some_torch_function(x: torch.Tensor) -> torch.Tensor:
    return x + torch.ones_like(x)

print(some_torch_function(some_jax_array))

示例

Jax到Torch nn.Module

假设我们有一些希望在PyTorch模型中使用的jax函数

import jax
import jax.numpy as jnp
def some_jax_function(params: jax.Array, x: jax.Array):
    '''Some toy function that takes in some parameters and an input vector.'''
    return jnp.dot(x, params)

通过导入这个

from torch_jax_interop import WrappedJaxFunction

我们可以将这个jax函数包装成一个具有可学习参数的torch.nn.Module

import torch
import torch.nn
module = WrappedJaxFunction(some_jax_function, jax.random.normal(jax.random.key(0), (2, 1)))
module = module.to("cpu")  # jax arrays are on GPU by default, moving them to CPU for this example.

这些参数现在是模块参数的可学习参数

dict(module.state_dict())
{'params.0': tensor([[-0.7848],
        [ 0.8564]])}

您可以使用它就像其他任何torch.nn.Module一样

x, y = torch.randn(2), torch.rand(1)
output = module(x)
loss = torch.nn.functional.mse_loss(output, y)
loss.backward()

model = torch.nn.Sequential(
    torch.nn.Linear(123, 2),
    module,
)

同样,对于flax.linen.Module,您现在可以在torch forward / backward过程中使用它们

import flax.linen

class Classifier(flax.linen.Module):
    num_classes: int = 10

    @flax.linen.compact
    def __call__(self, x: jax.Array):
        x = x.reshape((x.shape[0], -1))  # flatten
        x = flax.linen.Dense(features=256)(x)
        x = flax.linen.relu(x)
        x = flax.linen.Dense(features=self.num_classes)(x)
        return x

jax_module = Classifier(num_classes=10)
jax_params = jax_module.init(jax.random.key(0), x)

from torch_jax_interop import WrappedJaxFunction

torch_module = WrappedJaxFunction(jax.jit(jax_module.apply), jax_params)

Torch nn.Module到jax函数

>>> import torch
>>> import jax

>>> model = torch.nn.Linear(3, 2, device="cuda")
>>> apply_fn, params = torch_module_to_jax(model)


>>> def loss_function(params, x: jax.Array, y: jax.Array) -> jax.Array:
...     y_pred = apply_fn(params, x)
...     return jax.numpy.mean((y - y_pred) ** 2)


>>> x = jax.random.uniform(key=jax.random.key(0), shape=(1, 3))
>>> y = jax.random.uniform(key=jax.random.key(1), shape=(1, 1))

>>> loss, grad = jax.value_and_grad(loss_function)(params, x, y)
>>> loss
Array(0.3944674, dtype=float32)
>>> grad
(Array([[-0.46541408, -0.15171866, -0.30520514],
        [-0.7201077 , -0.23474531, -0.47222584]], dtype=float32), Array([-0.4821338, -0.7459771], dtype=float32))

要在模型上使用jax.jit,您需要传递一个输出示例,这样JIT编译器就可以告诉我们预期的输出形状和dtypes

>>> # here we reuse the same model as before:
>>> apply, params = torch_module_to_jax(model, example_output=torch.zeros(1, 2, device="cuda"))
>>> def loss_function(params, x: jax.Array, y: jax.Array) -> jax.Array:
...     y_pred = apply(params, x)
...     return jax.numpy.mean((y - y_pred) ** 2)
>>> loss, grad = jax.jit(jax.value_and_grad(loss_function))(params, x, y)
>>> loss
Array(0.3944674, dtype=float32)
>>> grad
(Array([[-0.46541408, -0.15171866, -0.30520514],
        [-0.7201077 , -0.23474531, -0.47222584]], dtype=float32), Array([-0.4821338, -0.7459771], dtype=float32))

项目详情


下载文件

下载适合您平台的文件。如果您不确定该选择哪个,请了解更多关于安装包的信息。

源分布

torch_jax_interop-0.0.7.tar.gz (26.6 kB 查看哈希值)

上传时间

构建分布

torch_jax_interop-0.0.7-py3-none-any.whl (30.8 kB 查看哈希值)

上传时间 Python 3

由以下机构支持

AWS AWS 云计算和安全赞助商 Datadog Datadog 监控 Fastly Fastly CDN Google Google 下载分析 Microsoft Microsoft PSF 赞助商 Pingdom Pingdom 监控 Sentry Sentry 错误记录 StatusPage StatusPage 状态页面