JAX绑定Flatiron研究所非均匀快速傅里叶变换库
项目描述
JAX到FINUFFT的绑定
此包提供了一个JAX接口,用于(子集)Flatiron研究所非均匀快速傅里叶变换(FINUFFT)库。请参阅FINUFFT文档以获取所有必要的定义、约定以及有关算法及其实现的更多信息。此包使用低级接口直接将FINUFFT库暴露给JAX的XLA后端,并实现了变换的微分规则。
包含的功能
此库包括CPU和GPU(CUDA)支持。GPU支持通过FINUFFT库的cuFINUFFT接口实现。
支持1维、2维和3维中的类型1和2变换。所有这些函数都支持正向、反向和更高阶微分,以及使用vmap
进行批处理。
安装
目前仅支持源构建。
对于构建,您只需要一个较新的Python版本(>3.6)和FFTW。启用GPU的构建还需要一个工作的CUDA编译器(即CUDA Toolkit)、CUDA >= 11.8以及兼容的cuDNN(较旧的CUDA版本可能工作,但未经过测试)。在运行时,您需要numpy
和jax
。
首先,克隆仓库并进入仓库根目录(别忘了使用 --recursive
标志,因为 FINUFFT 被包含为子模块)
git clone --recursive https://github.com/flatironinstitute/jax-finufft
cd jax-finufft
然后,您可以使用 conda
来设置构建环境(但您当然可以使用任何适合您的流程!)。例如,对于 CPU 构建,您可以使用
conda create -n jax-finufft -c conda-forge python=3.10 numpy scipy fftw cxx-compiler
conda activate jax-finufft
export CPATH=$CONDA_PREFIX/include:$CPATH
python -m pip install "jax[cpu]"
python -m pip install .
需要导出 CPATH
以确保构建可以找到通过 conda 安装的库(如 FFTW)的头文件。
对于 GPU 构建,尽管 CUDA 库和编译器名义上可以通过 conda 获取,但我们的经验表明,直接从 NVIDIA 获取 CUDA Toolkit 的“传统”方式可能效果最佳(参见 Horovod 的相关建议)。安装 CUDA Toolkit 后,可以使用以下命令设置其余的依赖项
conda create -n gpu-jax-finufft -c conda-forge python=3.10 numpy scipy fftw 'gxx<12'
conda activate gpu-jax-finufft
export CPATH=$CONDA_PREFIX/include:$CPATH
export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=70 -DJAX_FINUFFT_USE_CUDA=ON"
python -m pip install "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
python -m pip install .
JAX 网站上提供了安装 JAX 的其他方法;对于 jax-finufft,建议使用 “本地 CUDA” 安装方法,因为这确保了 CUDA 扩展与 CUDA 运行时使用相同的 Toolkit 版本编译。
在上面的 CMAKE_ARGS
行中,您需要选择您希望编译的 CUDA 架构。要查询您的 GPU 的 CUDA 架构(计算能力),您可以运行
$ nvidia-smi --query-gpu=compute_cap --format=csv,noheader
7.0
这对应于 CMAKE_CUDA_ARCHITECTURES=70
,即
export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=70 -DJAX_FINUFFT_USE_CUDA=ON"
请注意,pip 安装正在运行 CMake,因此必须在运行之前设置 CMAKE_ARGS
,但在运行时不需要。
在运行时,您可能还需要
export LD_LIBRARY_PATH="$CUDA_PATH/extras/CUPTI/lib64:$LD_LIBRARY_PATH"
如果未设置 CUDA_PATH
,您需要在上面的行中将它替换为您 CUDA 安装路径,通常像 /usr/local/cuda
这样的路径。
对于 Flatiron 用户,可以使用以下环境设置脚本代替 conda
环境脚本
ml modules/2.2
ml gcc
ml python/3.11
ml fftw
ml cuda/11
ml cudnn
ml nccl
export LD_LIBRARY_PATH=$CUDA_HOME/extras/CUPTI/lib64:$LD_LIBRARY_PATH
export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=60;70;80;90 -DJAX_FINUFFT_USE_CUDA=ON"
用法
此库提供两个高级函数(您通常只需与这些函数交互即可):nufft1
和 nufft2
(对于两种“类型”的变换)。如果您已经熟悉 FINUFFT 的 Python 接口,请注意,这里的函数签名是不同的!
例如,以下是进行一维类型 1 变换的示例
import numpy as np
from jax_finufft import nufft1
M = 100000
N = 200000
x = 2 * np.pi * np.random.uniform(size=M)
c = np.random.standard_normal(size=M) + 1j * np.random.standard_normal(size=M)
f = nufft1(N, c, x, eps=1e-6, iflag=1)
[!WARNING] 如 FINUFFT 文档 所述,非均匀点必须位于范围
[-3pi, 3pi]
内,但这 不会检查,因为 JAX 目前没有良好的运行时值检查接口。如果未满足此条件,可能会发生意外的崩溃。
注意 eps
和 iflag
是可选的,并且(出于良好的原因,我保证!)位置参数的顺序与 finufft
Python 包相反。
二维或三维变换的语法是
f = nufft1((Nx, Ny), c, x, y) # 2D
f = nufft1((Nx, Ny, Nz), c, x, y, z) # 3D
类型 2 变换的语法是(也允许可选的 iflag
和 eps
参数)
c = nufft2(f, x) # 1D
c = nufft2(f, x, y) # 2D
c = nufft2(f, x, y, z) # 3D
所有这些函数都支持使用 vmap
进行批处理,以及正向和反向模式微分。
高级使用
可以使用 nufft1
和 nufft2
的 opts
参数设置库的调整参数。例如,要显式设置 FINUFFT 应使用的 CPU 上采样因子,您可以将上面的示例更新如下
from jax_finufft import options
opts = options.Opts(upsampfac=2.0)
nufft1(N, c, x, opts=opts)
GPU 的对应选项是 gpu_upsampfac
。实际上,所有 GPU 选项都是以 gpu_
为前缀。
这里的一个复杂问题是,NUFFT 的向量-雅可比积需要评估不同类型的 NUFFT。这意味着您可能想要分别调整正向和反向传递的选项。这可以通过使用 options.NestedOpts
接口来实现。例如,要为正向和反向传递使用不同的上采样因子,上面的代码变为
import jax
opts = options.NestedOpts(
forward=options.Opts(upsampfac=2.0),
backward=options.Opts(upsampfac=1.25),
)
jax.grad(lambda args: nufft1(N, *args, opts=opts).real.sum())((c, x))
或者,在这种情况下等价于
opts = options.NestedOpts(
type1=options.Opts(upsampfac=2.0),
type2=options.Opts(upsampfac=1.25),
)
有关所有CPU调优参数的描述,请参阅FINUFFT文档。相应的GPU参数目前仅在cufinufft_opts.h
的源代码形式中列出。
类似库
- finufft:FINUFFT的“官方”Python绑定。如果您尚未使用JAX且不需要对变换进行微分,这是一个不错的选择。
- mrphys/tensorflow-nufft:FINUFFT和cuFINUFFT的TensorFlow绑定。
许可与归属
本包由Dan Foreman-Mackey开发,根据Apache License,Version 2.0许可,以下为版权信息
版权所有 2021,2022,2023,Simons基金会,Inc.
如果您使用本软件,请引用FINUFFT文档中列出的主要参考文献。
项目详情
jax_finufft-0.1.0.tar.gz 的哈希
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 0c9173837fa0ae47b61074f8c05b246d9ca5b21bda6174beda8c27ea75c4f152 |
|
MD5 | 564167555b26b8c01788a52653612fbb |
|
BLAKE2b-256 | b7269aa275d78c4ae4abca4c8d095d2c1c1bf137dab8aaea07eab2b2f6e71ebb |
jax_finufft-0.1.0-cp312-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl 的哈希
算法 | 哈希摘要 | |
---|---|---|
SHA256 | fd0cebca0ac3da173d30b5bf413fe7ba8fa5b0bf8483b5de7b1d2c55e0d640ce |
|
MD5 | e98274fdcf1fcc1f4f2de664a7184ad1 |
|
BLAKE2b-256 | c957713d26c173c245d42e4fa5afc00fe1dd5feb27bff4d6d6a355e95dcfa755 |
jax_finufft-0.1.0-cp312-abi3-macosx_11_0_arm64.whl 的哈希
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 8faebf86576abddb46d282fc0e5de740a3780691e0e430547412f321c481645d |
|
MD5 | bc12923035af8032d5baf2aabd5a42cd |
|
BLAKE2b-256 | abc76bc1b5f70502bc31b3f4b91e0f350db68bdaee6e4f5e666661b57a4ea08f |
jax_finufft-0.1.0-cp312-abi3-macosx_10_14_x86_64.whl 的哈希
算法 | 哈希摘要 | |
---|---|---|
SHA256 | b1cb11ec47b264e9a25cd87970bdc8828d7886dd5ebd481f12b5f5e6a02d104b |
|
MD5 | 69f526ab3048f97bd24a73edba484d82 |
|
BLAKE2b-256 | 4a5db78b4553b31a43351b4e20fa2669209882b2e25dfe9a0c4353790ed27fe5 |
哈希值 对于 jax_finufft-0.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
算法 | 哈希摘要 | |
---|---|---|
SHA256 | b09ad5ed078fe49ccbea7cd17043ea4739941cb0be7835a695c55d5129def919 |
|
MD5 | 30cba15526b52fc5d9818ba754bf5919 |
|
BLAKE2b-256 | 5e20f14375d8a9eb4662562d273803e2378a83497a74426a3ac852df8fa35f24 |
哈希值 对于 jax_finufft-0.1.0-cp311-cp311-macosx_11_0_arm64.whl
算法 | 哈希摘要 | |
---|---|---|
SHA256 | e0f4688b6f831c47591179cf79988481bf1437da5cb4775c839317c3b5c84ee4 |
|
MD5 | 979e6c4ab1948c91bfd15aec6642d8fc |
|
BLAKE2b-256 | f52468611ac67f151f3cdc56f78ad044fd407d499ca548574475fe28688de614 |
哈希值 对于 jax_finufft-0.1.0-cp311-cp311-macosx_10_14_x86_64.whl
算法 | 哈希摘要 | |
---|---|---|
SHA256 | e5a66da56d3956c077bc06faa54f2f615bda6abad4298480ea8d8f7d70a4af7f |
|
MD5 | e3247199df1b16c2a758835bfc3b0fa4 |
|
BLAKE2b-256 | 6baf63574c18bec5039d4fd478cbb9feef2d7707de82320094ef76d9fc680620 |
哈希值 对于 jax_finufft-0.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
算法 | 哈希摘要 | |
---|---|---|
SHA256 | eb6c95c3223d6e4fec1a6559436b81e33463863a0c31e0fb7426152fe6c834a4 |
|
MD5 | 52caf206d76827acf3ca68e17e2e02a3 |
|
BLAKE2b-256 | 8a80ea22b5cba3f2e2ec57def22821ba2967d1035631b652882f7038bbb4ab51 |
哈希值 对于 jax_finufft-0.1.0-cp310-cp310-macosx_11_0_arm64.whl
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 2bbe7c67020cc8bd2fc9ab70421fa6aa3a3ca2b7cdfba18d2bdf9452101fb896 |
|
MD5 | 71bd0ca1e2e71532daad901563a24ea0 |
|
BLAKE2b-256 | 1eb98b72df683f3d5031d38f440a001dd8d3eb59a12fbaa01fa652e5e625a190 |
哈希值 对于 jax_finufft-0.1.0-cp310-cp310-macosx_10_14_x86_64.whl
算法 | 哈希摘要 | |
---|---|---|
SHA256 | e2fb906afd0d95cc8729a6b7139a64c5ded28c800b6007ade61b97cdb3b36c40 |
|
MD5 | 9c357307ba0c6bf0ec01b9269bd216cb |
|
BLAKE2b-256 | ee4c72c18cd06804cb34869411515c1dd0ae9d8a84308ddb50f222ac5053e58a |
哈希值 对于 jax_finufft-0.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 606cb916883e1ec2842a0147bb71b8e24a6618506f2536ec177df00720a738a7 |
|
MD5 | f8f25325953e15ca59d53df8af3db722 |
|
BLAKE2b-256 | 50f4159d92173959ef1621ff79b9712a8140bc18321356250c3cebcecf66f39b |
哈希值 对于 jax_finufft-0.1.0-cp39-cp39-macosx_11_0_arm64.whl
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 5f354bcb22140f2c29014a0d270c27d374afe5264391904c56f1a6c562c10139 |
|
MD5 | 41d1d8d53ed96f1807a0114eb3e39dd6 |
|
BLAKE2b-256 | 46206704e153c2523c41a7e7e93d2557a457526b9a43169d3e20740b7f9227ee |
哈希值 对于 jax_finufft-0.1.0-cp39-cp39-macosx_10_14_x86_64.whl
算法 | 哈希摘要 | |
---|---|---|
SHA256 | c48382b45866fb078f6187810fc34632496fd8338e750adf7c0a7a66536b4118 |
|
MD5 | 03eae66bf0484c9814a39c81fd0bebd7 |
|
BLAKE2b-256 | 304d0bd8f3f262612bfa95b48feadfabd3730271dfba3e4597bb0ec6d8a2f927 |