跳转到主要内容

JAX绑定Flatiron研究所非均匀快速傅里叶变换库

项目描述

JAX到FINUFFT的绑定

GitHub Tests Jenkins Tests

此包提供了一个JAX接口,用于(子集)Flatiron研究所非均匀快速傅里叶变换(FINUFFT)库。请参阅FINUFFT文档以获取所有必要的定义、约定以及有关算法及其实现的更多信息。此包使用低级接口直接将FINUFFT库暴露给JAX的XLA后端,并实现了变换的微分规则。

包含的功能

此库包括CPU和GPU(CUDA)支持。GPU支持通过FINUFFT库的cuFINUFFT接口实现。

支持1维、2维和3维中的类型1和2变换。所有这些函数都支持正向、反向和更高阶微分,以及使用vmap进行批处理。

安装

目前仅支持源构建。

对于构建,您只需要一个较新的Python版本(>3.6)和FFTW。启用GPU的构建还需要一个工作的CUDA编译器(即CUDA Toolkit)、CUDA >= 11.8以及兼容的cuDNN(较旧的CUDA版本可能工作,但未经过测试)。在运行时,您需要numpyjax

首先,克隆仓库并进入仓库根目录(别忘了使用 --recursive 标志,因为 FINUFFT 被包含为子模块)

git clone --recursive https://github.com/flatironinstitute/jax-finufft
cd jax-finufft

然后,您可以使用 conda 来设置构建环境(但您当然可以使用任何适合您的流程!)。例如,对于 CPU 构建,您可以使用

conda create -n jax-finufft -c conda-forge python=3.10 numpy scipy fftw cxx-compiler
conda activate jax-finufft
export CPATH=$CONDA_PREFIX/include:$CPATH
python -m pip install "jax[cpu]"
python -m pip install .

需要导出 CPATH 以确保构建可以找到通过 conda 安装的库(如 FFTW)的头文件。

对于 GPU 构建,尽管 CUDA 库和编译器名义上可以通过 conda 获取,但我们的经验表明,直接从 NVIDIA 获取 CUDA Toolkit 的“传统”方式可能效果最佳(参见 Horovod 的相关建议)。安装 CUDA Toolkit 后,可以使用以下命令设置其余的依赖项

conda create -n gpu-jax-finufft -c conda-forge python=3.10 numpy scipy fftw 'gxx<12'
conda activate gpu-jax-finufft
export CPATH=$CONDA_PREFIX/include:$CPATH
export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=70 -DJAX_FINUFFT_USE_CUDA=ON"
python -m pip install "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
python -m pip install .

JAX 网站上提供了安装 JAX 的其他方法;对于 jax-finufft,建议使用 “本地 CUDA” 安装方法,因为这确保了 CUDA 扩展与 CUDA 运行时使用相同的 Toolkit 版本编译。

在上面的 CMAKE_ARGS 行中,您需要选择您希望编译的 CUDA 架构。要查询您的 GPU 的 CUDA 架构(计算能力),您可以运行

$ nvidia-smi --query-gpu=compute_cap --format=csv,noheader
7.0

这对应于 CMAKE_CUDA_ARCHITECTURES=70,即

export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=70 -DJAX_FINUFFT_USE_CUDA=ON"

请注意,pip 安装正在运行 CMake,因此必须在运行之前设置 CMAKE_ARGS,但在运行时不需要。

在运行时,您可能还需要

export LD_LIBRARY_PATH="$CUDA_PATH/extras/CUPTI/lib64:$LD_LIBRARY_PATH"

如果未设置 CUDA_PATH,您需要在上面的行中将它替换为您 CUDA 安装路径,通常像 /usr/local/cuda 这样的路径。

对于 Flatiron 用户,可以使用以下环境设置脚本代替 conda

环境脚本
ml modules/2.2
ml gcc
ml python/3.11
ml fftw
ml cuda/11
ml cudnn
ml nccl

export LD_LIBRARY_PATH=$CUDA_HOME/extras/CUPTI/lib64:$LD_LIBRARY_PATH
export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=60;70;80;90 -DJAX_FINUFFT_USE_CUDA=ON"

用法

此库提供两个高级函数(您通常只需与这些函数交互即可):nufft1nufft2(对于两种“类型”的变换)。如果您已经熟悉 FINUFFT 的 Python 接口,请注意,这里的函数签名是不同的!

例如,以下是进行一维类型 1 变换的示例

import numpy as np
from jax_finufft import nufft1

M = 100000
N = 200000

x = 2 * np.pi * np.random.uniform(size=M)
c = np.random.standard_normal(size=M) + 1j * np.random.standard_normal(size=M)
f = nufft1(N, c, x, eps=1e-6, iflag=1)

[!WARNING] 如 FINUFFT 文档 所述,非均匀点必须位于范围 [-3pi, 3pi] 内,但这 不会检查,因为 JAX 目前没有良好的运行时值检查接口。如果未满足此条件,可能会发生意外的崩溃。

注意 epsiflag 是可选的,并且(出于良好的原因,我保证!)位置参数的顺序与 finufft Python 包相反。

二维或三维变换的语法是

f = nufft1((Nx, Ny), c, x, y)  # 2D
f = nufft1((Nx, Ny, Nz), c, x, y, z)  # 3D

类型 2 变换的语法是(也允许可选的 iflageps 参数)

c = nufft2(f, x)  # 1D
c = nufft2(f, x, y)  # 2D
c = nufft2(f, x, y, z)  # 3D

所有这些函数都支持使用 vmap 进行批处理,以及正向和反向模式微分。

高级使用

可以使用 nufft1nufft2opts 参数设置库的调整参数。例如,要显式设置 FINUFFT 应使用的 CPU 上采样因子,您可以将上面的示例更新如下

from jax_finufft import options

opts = options.Opts(upsampfac=2.0)
nufft1(N, c, x, opts=opts)

GPU 的对应选项是 gpu_upsampfac。实际上,所有 GPU 选项都是以 gpu_ 为前缀。

这里的一个复杂问题是,NUFFT 的向量-雅可比积需要评估不同类型的 NUFFT。这意味着您可能想要分别调整正向和反向传递的选项。这可以通过使用 options.NestedOpts 接口来实现。例如,要为正向和反向传递使用不同的上采样因子,上面的代码变为

import jax

opts = options.NestedOpts(
  forward=options.Opts(upsampfac=2.0),
  backward=options.Opts(upsampfac=1.25),
)
jax.grad(lambda args: nufft1(N, *args, opts=opts).real.sum())((c, x))

或者,在这种情况下等价于

opts = options.NestedOpts(
  type1=options.Opts(upsampfac=2.0),
  type2=options.Opts(upsampfac=1.25),
)

有关所有CPU调优参数的描述,请参阅FINUFFT文档。相应的GPU参数目前仅在cufinufft_opts.h的源代码形式中列出。

类似库

  • finufft:FINUFFT的“官方”Python绑定。如果您尚未使用JAX且不需要对变换进行微分,这是一个不错的选择。
  • mrphys/tensorflow-nufft:FINUFFT和cuFINUFFT的TensorFlow绑定。

许可与归属

本包由Dan Foreman-Mackey开发,根据Apache License,Version 2.0许可,以下为版权信息

版权所有 2021,2022,2023,Simons基金会,Inc.

如果您使用本软件,请引用FINUFFT文档中列出的主要参考文献。

项目详情


下载文件

下载适用于您的平台的文件。如果您不确定选择哪个,请了解更多关于安装软件包的信息。

源代码分发

jax_finufft-0.1.0.tar.gz (2.6 MB 查看哈希值)

上传时间 源代码

构建分发

jax_finufft-0.1.0-cp312-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB 查看哈希值)

上传时间 CPython 3.12+ manylinux: glibc 2.17+ x86-64

jax_finufft-0.1.0-cp312-abi3-macosx_11_0_arm64.whl (1.3 MB 查看哈希值)

上传时间 CPython 3.12+ macOS 11.0+ ARM64

jax_finufft-0.1.0-cp312-abi3-macosx_10_14_x86_64.whl (2.9 MB 查看哈希值)

上传时间 CPython 3.12+ macOS 10.14+ x86-64

jax_finufft-0.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB 查看哈希值)

上传时间 CPython 3.11 manylinux: glibc 2.17+ x86-64

jax_finufft-0.1.0-cp311-cp311-macosx_11_0_arm64.whl (1.3 MB 查看哈希值)

上传时间 CPython 3.11 macOS 11.0+ ARM64

jax_finufft-0.1.0-cp311-cp311-macosx_10_14_x86_64.whl (2.9 MB 查看哈希值)

上传于 CPython 3.11 macOS 10.14+ x86-64

jax_finufft-0.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB 查看哈希)

上传于 CPython 3.10 manylinux: glibc 2.17+ x86-64

jax_finufft-0.1.0-cp310-cp310-macosx_11_0_arm64.whl (1.3 MB 查看哈希)

上传于 CPython 3.10 macOS 11.0+ ARM64

jax_finufft-0.1.0-cp310-cp310-macosx_10_14_x86_64.whl (2.9 MB 查看哈希)

上传于 CPython 3.10 macOS 10.14+ x86-64

jax_finufft-0.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB 查看哈希)

上传于 CPython 3.9 manylinux: glibc 2.17+ x86-64

jax_finufft-0.1.0-cp39-cp39-macosx_11_0_arm64.whl (1.3 MB 查看哈希)

上传于 CPython 3.9 macOS 11.0+ ARM64

jax_finufft-0.1.0-cp39-cp39-macosx_10_14_x86_64.whl (2.9 MB 查看哈希)

上传于 CPython 3.9 macOS 10.14+ x86-64

由以下机构支持

AWS AWS 云计算和安全赞助商 Datadog Datadog 监控 Fastly Fastly CDN Google Google 下载分析 Microsoft Microsoft PSF 赞助商 Pingdom Pingdom 监控 Sentry Sentry 错误记录 StatusPage StatusPage 状态页面