在PyTorch的多维网格上进行三次样条插值
项目描述
torch-cubic-spline-grids
在PyTorch的多维网格上进行三次样条插值。
此包的主要目标是提供1-4D空间的可学习的连续参数化。
概述
torch_cubic_spline_grids
提供了一组名为grids的PyTorch组件。
网格由以下参数定义
- 它们的维度(1d,2d,3d,4d...)
- 每个维度覆盖的点的数量(
resolution
) - 每个网格点上存储的值的数量(
n_channels
) - 我们如何在网格点上插值值
此包中的所有网格都由均匀分布在每个维度全范围上的点组成。
第一步
让我们创建一个简单的2D网格,每个网格点上有一个值。
import torch
from torch_cubic_spline_grids import CubicBSplineGrid2d
grid = CubicBSplineGrid2d(resolution=(5, 3), n_channels=1)
grid.ndim
是2
grid.resolution
是(5, 3)
(或(h, w)
)grid.n_channels
是1
grid.data.shape
是(1, 5, 3)
(或(c, h, w)
)
换句话说,网格跨越两个维度(h, w)
,在h
中有5个点,在w
中有3个点。2D网格上的每个点存储一个值。网格数据存储在形状为(c, *grid_resolution)
的张量中。
我们可以获得网格上任何连续点的值(插值)。网格坐标系沿着每个网格维度从[0, 1]
扩展。通过沿着网格的每个维度逐个应用三次样条插值来获得插值。
coords = torch.rand(size=(10, 2)) # values in [0, 1]
interpolants = grid(coords)
interpolants.shape
是(10, 1)
优化
通过最小化与网格插值相关的损失函数,可以优化每个网格点的值。这样可以将网格的连续空间更精确地建模为1-4D空间。
上图显示了1D网格上6个控制点的值正在被优化,以便通过三次B样条插值在这些点之间进行插值时,可以近似正弦波的单次振荡。
网格类型
torch_cubic_spline_grids
提供可以用于三次B样条插值或三次Catmull-Rom样条插值的网格。
样条 | 连续性 | 插值? |
---|---|---|
三次B样条 | C2 | 否 |
Catmull-Rom样条 | C1 | 是 |
如果您需要得到的曲线与网格上的数据进行交点,应使用三次Catmull-Rom样条网格
CubicCatmullRomGrid1d
CubicCatmullRomGrid2d
CubicCatmullRomGrid3d
CubicCatmullRomGrid4d
如果您需要连续的二阶导数,则三次B样条网格更为合适。
CubicBSplineGrid1d
CubicBSplineGrid2d
CubicBSplineGrid3d
CubicBSplineGrid4d
正则化
每个维度中的点数应选择得使得在网格上插值可以近似模型化的底层现象,而不至于过拟合。低分辨率网格通过平滑模型提供正则化效果。
安装
torch_cubic_spline_grids
可在PyPI上使用
pip install torch-cubic-spline-grids
相关工作
这是Warp在cryo-EM图像中建模连续变形场和局部可变光学参数的PyTorch实现方式。Warp的方法在Dimitry Tegunov的论文中进行了描述。
Warp中的许多方法基于1-到3维空间的连续参数化。这种参数化通过在粗细均匀网格上的点之间进行样条插值来实现,这具有计算效率。网格扩展到需要建模的每个维度的全部。网格分辨率由每个维度的控制点数定义,并按物理约束(例如,帧数或像素数)和可用信号进行缩放。后者为防止过拟合稀疏数据提供正则化。当从空间(和时间)中的点检索由网格描述的参数时(例如,对于粒子(帧)),在网格上的该点执行B样条插值。通常,为了拟合网格的参数,优化与网格上特定位置的插值相关联的成本函数。
我推荐观看Freya Holmer的YouTube视频Splines的连续性来了解样条。
项目详情
下载文件
下载适合您平台的应用程序。如果您不确定选择哪个,请了解更多关于安装软件包的信息。
源代码分发
构建分发
torch_cubic_spline_grids-0.0.8.tar.gz 的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | fd16315635eb7d0e2a97c8982c6f878c20edf875532b3a0f980280065b4a5353 |
|
MD5 | 12373a0a752d89452e21fe7aecd53be1 |
|
BLAKE2b-256 | b9e724dd5f302a0f80c7fd66d3744f5e3e90d06fe2dd98450efcb6cbb9a25a54 |
torch_cubic_spline_grids-0.0.8-py3-none-any.whl 的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 6d0de7aac91136120dc33e20f7e620fe4c3763b6b74f585e363f026cb1a58dc1 |
|
MD5 | c68209f8908c18d9192e10fe18ceab30 |
|
BLAKE2b-256 | ed33dbcd43ae6b682778690ca285733d561f47ba78a2534e34c98372ca2f8a7f |