跳转到主要内容

JAX + OpenAI Triton集成

项目描述

jax-triton

PyPI version

jax-triton存储库包含JAXTriton之间的集成。

文档可以在这里找到。

这不是一个官方支持的产品。

快速入门

我们感兴趣的主要功能是应用于JAX数组的Triton函数的jax_triton.triton_call,包括在jax.jit编译的函数内部。例如,我们可以定义Triton教程中的一个内核

import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    length,
    output_ptr,
    block_size: tl.constexpr,
):
  """Adds two vectors."""
  pid = tl.program_id(axis=0)
  block_start = pid * block_size
  offsets = block_start + tl.arange(0, block_size)
  mask = offsets < length
  x = tl.load(x_ptr + offsets, mask=mask)
  y = tl.load(y_ptr + offsets, mask=mask)
  output = x + y
  tl.store(output_ptr + offsets, output, mask=mask)

然后我们可以使用jax_triton.triton_call将其应用于JAX数组

import jax
import jax.numpy as jnp
import jax_triton as jt

def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
  block_size = 8
  return jt.triton_call(
      x,
      y,
      x.size,
      kernel=add_kernel,
      out_shape=out_shape,
      grid=(x.size // block_size,),
      block_size=block_size)

x_val = jnp.arange(8)
y_val = jnp.arange(8, 16)
print(add(x_val, y_val))
print(jax.jit(add)(x_val, y_val))

请参阅示例目录,特别是fused_attention.pyfused attention ipynb

安装

$ pip install jax-triton

确保您已安装兼容CUDA的jaxlib。例如,您可以运行

$ pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

开发

要开发jax-triton,您可以使用以下命令克隆存储库

$ git clone https://github.com/jax-ml/jax-triton.git

并使用以下命令进行可编辑安装

$ cd jax-triton
$ pip install -e .

要运行jax-triton测试,您需要pytestabsl-py

$ pip install pytest absl-py
$ pytest tests/

项目详情


下载文件

下载适用于您平台上的文件。如果您不确定该选择哪个,请了解更多有关安装包的信息。

源分发

jax-triton-0.1.3.tar.gz (44.1 kB 查看哈希值)

上传时间

由以下支持