跳转到主要内容

在PyTorch的多维网格上进行三次样条插值

项目描述

torch-cubic-spline-grids

License PyPI Python Version CI codecov

在PyTorch的多维网格上进行三次样条插值。

此包的主要目标是提供1-4D空间的可学习的连续参数化。


概述

torch_cubic_spline_grids提供了一组名为grids的PyTorch组件。

网格由以下参数定义

  • 它们的维度(1d,2d,3d,4d...)
  • 每个维度覆盖的点的数量(resolution
  • 每个网格点上存储的值的数量(n_channels
  • 我们如何在网格点上插值值

此包中的所有网格都由均匀分布在每个维度全范围上的点组成。

第一步

让我们创建一个简单的2D网格,每个网格点上有一个值。

import torch
from torch_cubic_spline_grids import CubicBSplineGrid2d

grid = CubicBSplineGrid2d(resolution=(5, 3), n_channels=1)
  • grid.ndim2
  • grid.resolution(5, 3)(或(h, w)
  • grid.n_channels1
  • grid.data.shape(1, 5, 3)(或(c, h, w)

换句话说,网格跨越两个维度(h, w),在h中有5个点,在w中有3个点。2D网格上的每个点存储一个值。网格数据存储在形状为(c, *grid_resolution)的张量中。

我们可以获得网格上任何连续点的值(插值)。网格坐标系沿着每个网格维度从[0, 1]扩展。通过沿着网格的每个维度逐个应用三次样条插值来获得插值。

coords = torch.rand(size=(10, 2))  # values in [0, 1]
interpolants = grid(coords)
  • interpolants.shape(10, 1)

优化

通过最小化与网格插值相关的损失函数,可以优化每个网格点的值。这样可以将网格的连续空间更精确地建模为1-4D空间。

上图显示了1D网格上6个控制点的值正在被优化,以便通过三次B样条插值在这些点之间进行插值时,可以近似正弦波的单次振荡。

1D示例和类似的2D示例都有可用的笔记本。

网格类型

torch_cubic_spline_grids提供可以用于三次B样条插值或三次Catmull-Rom样条插值的网格。

样条 连续性 插值?
三次B样条 C2
Catmull-Rom样条 C1

如果您需要得到的曲线与网格上的数据进行交点,应使用三次Catmull-Rom样条网格

  • CubicCatmullRomGrid1d
  • CubicCatmullRomGrid2d
  • CubicCatmullRomGrid3d
  • CubicCatmullRomGrid4d

如果您需要连续的二阶导数,则三次B样条网格更为合适。

  • CubicBSplineGrid1d
  • CubicBSplineGrid2d
  • CubicBSplineGrid3d
  • CubicBSplineGrid4d

正则化

每个维度中的点数应选择得使得在网格上插值可以近似模型化的底层现象,而不至于过拟合。低分辨率网格通过平滑模型提供正则化效果。

安装

torch_cubic_spline_grids可在PyPI上使用

pip install torch-cubic-spline-grids

相关工作

这是Warp在cryo-EM图像中建模连续变形场和局部可变光学参数的PyTorch实现方式。Warp的方法在Dimitry Tegunov的论文中进行了描述。

Warp中的许多方法基于1-到3维空间的连续参数化。这种参数化通过在粗细均匀网格上的点之间进行样条插值来实现,这具有计算效率。网格扩展到需要建模的每个维度的全部。网格分辨率由每个维度的控制点数定义,并按物理约束(例如,帧数或像素数)和可用信号进行缩放。后者为防止过拟合稀疏数据提供正则化。当从空间(和时间)中的点检索由网格描述的参数时(例如,对于粒子(帧)),在网格上的该点执行B样条插值。通常,为了拟合网格的参数,优化与网格上特定位置的插值相关联的成本函数。


我推荐观看Freya Holmer的YouTube视频Splines的连续性来了解样条。

样条的连续性 - YouTube

项目详情


下载文件

下载适合您平台的应用程序。如果您不确定选择哪个,请了解更多关于安装软件包的信息。

源代码分发

torch_cubic_spline_grids-0.0.8.tar.gz (148.8 kB 查看哈希值)

上传时间 源代码

构建分发

torch_cubic_spline_grids-0.0.8-py3-none-any.whl (14.4 kB 查看哈希值)

上传时间 Python 3

由以下支持

AWSAWS 云计算和安全赞助商 DatadogDatadog 监控 FastlyFastly CDN GoogleGoogle 下载分析 MicrosoftMicrosoft PSF赞助商 PingdomPingdom 监控 SentrySentry 错误日志 StatusPageStatusPage 状态页面