JAX + OpenAI Triton集成
项目描述
jax-triton
jax-triton
存储库包含JAX和Triton之间的集成。
文档可以在这里找到。
这不是一个官方支持的产品。
快速入门
我们感兴趣的主要功能是应用于JAX数组的Triton函数的jax_triton.triton_call
,包括在jax.jit
编译的函数内部。例如,我们可以定义Triton教程中的一个内核
import triton
import triton.language as tl
@triton.jit
def add_kernel(
x_ptr,
y_ptr,
length,
output_ptr,
block_size: tl.constexpr,
):
"""Adds two vectors."""
pid = tl.program_id(axis=0)
block_start = pid * block_size
offsets = block_start + tl.arange(0, block_size)
mask = offsets < length
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
output = x + y
tl.store(output_ptr + offsets, output, mask=mask)
然后我们可以使用jax_triton.triton_call
将其应用于JAX数组
import jax
import jax.numpy as jnp
import jax_triton as jt
def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
block_size = 8
return jt.triton_call(
x,
y,
x.size,
kernel=add_kernel,
out_shape=out_shape,
grid=(x.size // block_size,),
block_size=block_size)
x_val = jnp.arange(8)
y_val = jnp.arange(8, 16)
print(add(x_val, y_val))
print(jax.jit(add)(x_val, y_val))
请参阅示例目录,特别是fused_attention.py和fused attention ipynb。
安装
$ pip install jax-triton
确保您已安装兼容CUDA的jaxlib
。例如,您可以运行
$ pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
开发
要开发jax-triton
,您可以使用以下命令克隆存储库
$ git clone https://github.com/jax-ml/jax-triton.git
并使用以下命令进行可编辑安装
$ cd jax-triton
$ pip install -e .
要运行jax-triton
测试,您需要pytest
和absl-py
$ pip install pytest absl-py
$ pytest tests/
项目详情
关闭
jax-triton-0.1.3.tar.gz的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 80fb1db3ea0af2818194f68acff3a105e3d6783d8699a474fd3104d6d79c7fac |
|
MD5 | f838b2ee3b82e529bd68cd1b93c92989 |
|
BLAKE2b-256 | f7ed7d2cf14270777cf8b0476f862e8b72d736ef86eb9dc44e7009a93f784e1f |