像numpy数组一样的Dataclasses(具有索引、切片、向量化)。
项目描述
Dataclass Array
DataclassArray
是类似numpy的数据类(可以批量处理、重塑、切片等),与Jax、TensorFlow和numpy兼容(计划支持torch)。
这减少了样板代码并提高了可读性。请参阅下面的 激励示例 部分。
要查看实际使用中的dataclass数组示例,请参阅 visu3d。
文档
定义
要创建一个 dca.DataclassArray
,取一个冻结的数据类并
- 从
dca.DataclassArray
继承 - 使用
dataclass_array.typing
注释字段以指定数组的内部形状和数据类型(请参见下面的静态或嵌套数据类字段)。数组类型是来自etils.array_types
的别名。
import dataclass_array as dca
from dataclass_array.typing import FloatArray
class Ray(dca.DataclassArray):
pos: FloatArray['*batch_shape 3']
dir: FloatArray['*batch_shape 3']
用法
之后,可以将数据类用作numpy数组
ray = Ray(pos=jnp.zeros((3, 3)), dir=jnp.eye(3))
ray.shape == (3,) # 3 rays batched together
ray.pos.shape == (3, 3) # Individual fields still available
# Numpy slicing/indexing/masking
ray = ray[..., 1:2]
ray = ray[norm(ray.dir) > 1e-7]
# Shape transformation
ray = ray.reshape((1, 3))
ray = ray.reshape('h w -> w h') # Native einops support
ray = ray.flatten()
# Stack multiple dataclass arrays together
ray = dca.stack([ray0, ray1, ...])
# Supports TF, Jax, Numpy (torch planned) and can be easily converted
ray = ray.as_jax() # as_np(), as_tf()
ray.xnp == jax.numpy # `numpy`, `jax.numpy`, `tf.experimental.numpy`
# Compatibility `with jax.tree_util`, `jax.vmap`,..
ray = jax.tree_util.tree_map(lambda x: x+1, ray)
DataclassArray
有两种类型的字段
- 数组字段:类似于numpy数组的批量字段,支持重塑、切片等。可以是
xnp.ndarray
或嵌套的dca.DataclassArray
。 - 静态字段:其他非numpy字段。不会被重塑等操作修改。在
jax.tree_map
中也会忽略静态字段。
class MyArray(dca.DataclassArray):
# Array fields
a: FloatArray['*batch_shape 3'] # Defined by `etils.array_types`
b: FloatArray['*batch_shape _ _'] # Dynamic shape
c: Ray # Nested DataclassArray (equivalent to `Ray['*batch_shape']`)
d: Ray['*batch_shape 6']
# Array fields explicitly defined
e: Any = dca.field(shape=(3,), dtype=np.float32)
f: Any = dca.field(shape=(None, None), dtype=np.float32) # Dynamic shape
g: Ray = dca.field(shape=(3,), dtype=Ray) # Nested DataclassArray
# Static field (everything not defined as above)
static0: float
static1: np.array
向量化
@dca.vectorize_method
允许您的数据类方法自动支持批量处理
- 将方法实现为
self.shape == ()
形式 - 使用
dca.vectorize_method
装饰该方法
class Camera(dca.DataclassArray):
K: FloatArray['*batch_shape 4 4']
resolution = tuple[int, int]
@dca.vectorize_method
def rays(self) -> Ray:
# Inside `@dca.vectorize_method` shape is always guarantee to be `()`
assert self.shape == ()
assert self.K.shape == (4, 4)
# Compute the ray as if there was only a single camera
return Ray(pos=..., dir=...)
之后,我们可以为多个一起批处理的相机生成射线
cams = Camera(K=K) # K.shape == (num_cams, 4, 4)
rays = cams.rays() # Generate the rays for all the cameras
cams.shape == (num_cams,)
rays.shape == (num_cams, h, w)
@dca.vectorize_method
与 jax.vmap
类似,但
- 仅适用于
dca.DataclassArray
方法 @dca.vectorize_method
不会仅对单轴进行向量化,而是会对*self.shape
进行向量化(而不仅仅是self.shape[0]
)。这类似于如果将vmap
应用到self.flatten()
上- 当有多个参数时,维度为
1
的轴会被广播。
例如,__matmul__(self, x: T) -> T
() @ (*x,) -> (*x,)
(b,) @ (b, *x) -> (b, *x)
(b,) @ (1, *x) -> (b, *x)
(1,) @ (b, *x) -> (b, *x)
(b, h, w) @ (b, h, w, *x) -> (b, h, w, *x)
(1, h, w) @ (b, 1, 1, *x) -> (b, h, w, *x)
(a, *x) @ (b, *x) -> Error: Incompatible a != b
要测试Colab,请参阅 visu3d
数据类的 Colab教程。
动机示例
dca.DataclassArray
通过简化常见模式来提高可读性
-
重塑数据类的所有字段
之前(
rays
是简单的dataclass
)num_rays = math.prod(rays.origins.shape[:-1]) rays = jax.tree_map(lambda r: r.reshape((num_rays, -1)), rays)
之后(
rays
是DataclassArray
)rays = rays.flatten() # (b, h, w) -> (b*h*w,)
-
渲染视频
之前(
cams: list[Camera]
)img = cams[0].render(scene) imgs = np.stack([cam.render(scene) for cam in cams[::2]]) imgs = np.stack([cam.render(scene) for cam in cams])
之后(
cams: Camera
且cams.shape == (num_cams,)}
)img = cams[0].render(scene) # Render only the first camera (to debug) imgs = cams[::2].render(scene) # Render 1/2 frames (for quicker iteration) imgs = cams.render(scene) # Render all cameras at once
安装
pip install dataclass_array
这不是一个官方的Google产品
项目详情
下载文件
下载适合您平台文件。如果您不确定选择哪个,请了解更多关于 安装包 的信息。
源分布
dataclass_array-1.5.2.tar.gz (34.8 kB 查看哈希值)
构建分布
dataclass_array-1.5.2-py3-none-any.whl (43.6 kB 查看哈希值)