跳转到主要内容

像numpy数组一样的Dataclasses(具有索引、切片、向量化)。

项目描述

Dataclass Array

Unittests PyPI version Documentation Status

DataclassArray 是类似numpy的数据类(可以批量处理、重塑、切片等),与Jax、TensorFlow和numpy兼容(计划支持torch)。

这减少了样板代码并提高了可读性。请参阅下面的 激励示例 部分。

要查看实际使用中的dataclass数组示例,请参阅 visu3d

文档

定义

要创建一个 dca.DataclassArray,取一个冻结的数据类并

  • dca.DataclassArray 继承
  • 使用 dataclass_array.typing 注释字段以指定数组的内部形状和数据类型(请参见下面的静态或嵌套数据类字段)。数组类型是来自 etils.array_types 的别名。
import dataclass_array as dca
from dataclass_array.typing import FloatArray


class Ray(dca.DataclassArray):
  pos: FloatArray['*batch_shape 3']
  dir: FloatArray['*batch_shape 3']

用法

之后,可以将数据类用作numpy数组

ray = Ray(pos=jnp.zeros((3, 3)), dir=jnp.eye(3))


ray.shape == (3,)  # 3 rays batched together
ray.pos.shape == (3, 3)  # Individual fields still available

# Numpy slicing/indexing/masking
ray = ray[..., 1:2]
ray = ray[norm(ray.dir) > 1e-7]

# Shape transformation
ray = ray.reshape((1, 3))
ray = ray.reshape('h w -> w h')  # Native einops support
ray = ray.flatten()

# Stack multiple dataclass arrays together
ray = dca.stack([ray0, ray1, ...])

# Supports TF, Jax, Numpy (torch planned) and can be easily converted
ray = ray.as_jax()  # as_np(), as_tf()
ray.xnp == jax.numpy  # `numpy`, `jax.numpy`, `tf.experimental.numpy`

# Compatibility `with jax.tree_util`, `jax.vmap`,..
ray = jax.tree_util.tree_map(lambda x: x+1, ray)

DataclassArray 有两种类型的字段

  • 数组字段:类似于numpy数组的批量字段,支持重塑、切片等。可以是 xnp.ndarray 或嵌套的 dca.DataclassArray
  • 静态字段:其他非numpy字段。不会被重塑等操作修改。在 jax.tree_map 中也会忽略静态字段。
class MyArray(dca.DataclassArray):
  # Array fields
  a: FloatArray['*batch_shape 3']  # Defined by `etils.array_types`
  b: FloatArray['*batch_shape _ _']  # Dynamic shape
  c: Ray  # Nested DataclassArray (equivalent to `Ray['*batch_shape']`)
  d: Ray['*batch_shape 6']

  # Array fields explicitly defined
  e: Any = dca.field(shape=(3,), dtype=np.float32)
  f: Any = dca.field(shape=(None,  None), dtype=np.float32)  # Dynamic shape
  g: Ray = dca.field(shape=(3,), dtype=Ray)  # Nested DataclassArray

  # Static field (everything not defined as above)
  static0: float
  static1: np.array

向量化

@dca.vectorize_method 允许您的数据类方法自动支持批量处理

  1. 将方法实现为 self.shape == () 形式
  2. 使用 dca.vectorize_method 装饰该方法
class Camera(dca.DataclassArray):
  K: FloatArray['*batch_shape 4 4']
  resolution = tuple[int, int]

  @dca.vectorize_method
  def rays(self) -> Ray:
    # Inside `@dca.vectorize_method` shape is always guarantee to be `()`
    assert self.shape == ()
    assert self.K.shape == (4, 4)

    # Compute the ray as if there was only a single camera
    return Ray(pos=..., dir=...)

之后,我们可以为多个一起批处理的相机生成射线

cams = Camera(K=K)  # K.shape == (num_cams, 4, 4)
rays = cams.rays()  # Generate the rays for all the cameras

cams.shape == (num_cams,)
rays.shape == (num_cams, h, w)

@dca.vectorize_methodjax.vmap 类似,但

  • 仅适用于 dca.DataclassArray 方法
  • @dca.vectorize_method 不会仅对单轴进行向量化,而是会对 *self.shape 进行向量化(而不仅仅是 self.shape[0])。这类似于如果将 vmap 应用到 self.flatten()
  • 当有多个参数时,维度为 1 的轴会被广播。

例如,__matmul__(self, x: T) -> T

() @ (*x,) -> (*x,)
(b,) @ (b, *x) -> (b, *x)
(b,) @ (1, *x) -> (b, *x)
(1,) @ (b, *x) -> (b, *x)
(b, h, w) @ (b, h, w, *x) -> (b, h, w, *x)
(1, h, w) @ (b, 1, 1, *x) -> (b, h, w, *x)
(a, *x) @ (b, *x) -> Error: Incompatible a != b

要测试Colab,请参阅 visu3d 数据类的 Colab教程

动机示例

dca.DataclassArray 通过简化常见模式来提高可读性

  • 重塑数据类的所有字段

    之前(rays 是简单的 dataclass

    num_rays = math.prod(rays.origins.shape[:-1])
    rays = jax.tree_map(lambda r: r.reshape((num_rays, -1)), rays)
    

    之后(raysDataclassArray

    rays = rays.flatten()  # (b, h, w) -> (b*h*w,)
    
  • 渲染视频

    之前(cams: list[Camera]

    img = cams[0].render(scene)
    imgs = np.stack([cam.render(scene) for cam in cams[::2]])
    imgs = np.stack([cam.render(scene) for cam in cams])
    

    之后(cams: Cameracams.shape == (num_cams,)}

    img = cams[0].render(scene)  # Render only the first camera (to debug)
    imgs = cams[::2].render(scene)  # Render 1/2 frames (for quicker iteration)
    imgs = cams.render(scene)  # Render all cameras at once
    

安装

pip install dataclass_array

这不是一个官方的Google产品

项目详情


下载文件

下载适合您平台文件。如果您不确定选择哪个,请了解更多关于 安装包 的信息。

源分布

dataclass_array-1.5.2.tar.gz (34.8 kB 查看哈希值)

上传时间

构建分布

dataclass_array-1.5.2-py3-none-any.whl (43.6 kB 查看哈希值)

上传时间 Python 3

由以下支持

AWS AWS 云计算和安全赞助商 Datadog Datadog 监控 Fastly Fastly CDN Google Google 下载分析 Microsoft Microsoft PSF 赞助商 Pingdom Pingdom 监控 Sentry Sentry 错误日志 StatusPage StatusPage 状态页面