跳转到主要内容

PyTorch中的可微分拉普拉斯重建

项目描述

PyTorch实现的可微分拉普拉斯重建

Documentation Status Tests arXiv License: MIT

此库提供了在PyTorch中实现的逆拉普拉斯变换(ILT)算法。支持使用黎曼单叶双曲函数投影在拉普拉斯域中通过微分方程(DE)解的反传播,以更好地表示复数拉普拉斯域的全局性。有关在深度学习应用中拉普拉斯域的DE表示的用法,请参阅参考文献 [1]

安装

安装最新稳定版本

pip install torchlaplace

从GitHub安装最新版本

pip install git+https://github.com/samholt/NeuralLaplace.git

教程

  1. 教程:拉普拉斯重建 Test In Colab
  2. 教程:逆拉普拉斯变换算法 Test In Colab

示例

示例放置在 examples 目录中。

鼓励有兴趣使用这个库的人查看 examples/simple_demo.py,以了解如何使用 torchlaplace 来拟合DE系统。

Lotka Volterra DDE Demo

基本用法

此库提供了一个主要接口 laplace_reconstruct,它使用选定的逆拉普拉斯变换算法从提供的参数化拉普拉斯表示功能 $\mathbf{F}(\mathbf{p},\mathbf{s})$ 中重建轨迹,

$$\mathbf{x}(t) = \text{inverse laplace transform}(\mathbf{F}(\mathbf{p},\mathbf{s}), t)$$

其中 $\mathbf{p}$ 是一个Tensor,表示初始系统状态作为潜在变量,$t$ 是要重建轨迹的时间点。

这可以通过以下方式使用:

from torchlaplace import laplace_reconstruct

laplace_reconstruct(laplace_rep_func, p, t)

其中 laplace_rep_func 是任何实现参数化拉普拉斯表示功能的可调用对象 $\mathbf{F}(\mathbf{p},\mathbf{s})$,p 是一个形状为 $(\text{MiniBatchSize},\text{K})$ 的Tensor,表示初始状态。其中 $\text{K}$ 是一个超参数,可以由用户设置。最后,t 是一个形状为 $(\text{MiniBatchSize},\text{SeqLen})$ 或 $(\text{SeqLen})$ 的Tensor,包含要重建轨迹的时间点。

请注意,这并非对所有ILT方法都数值稳定,但默认的 fourier(傅里叶级数逆变换)ILT算法可能没问题。

参数化拉普拉斯表示功能 laplace_rep_func,$\mathbf{F}(\mathbf{p},\mathbf{s})$ 还接受一个复数值 $\mathbf{s}$ 作为输入。这个 $\mathbf{s}$ 在使用选定的逆拉普拉斯变换算法 ilt_algorithm 重建指定时间点时内部使用。

最大的 问题 是当使用 laplace_rep_func 函数时,laplace_rep_func 必须是一个 nn.Module。这是由于需要内部收集参数化拉普拉斯表示的参数。

要复制[1]中的实验,请查看 experiments 目录。

laplace_rep_func 的关键字参数

关键字参数

  • recon_dim (int):给定时间点的轨迹维度。对应于dim $d_{\text{obs}}$。如果没有明确指定,将使用与 p 相同的最后维度,即 $\text{K}$。
  • ilt_algorithm (str):要使用的逆拉普拉斯变换算法。默认:fourier。可用的有 {fourierdehoogcmefixed_tablotstehfest}。有关ILT的详细信息,请参阅API文档。
  • use_sphere_projection (bool):这使用 laplace_rep_func 在黎曼球体的球面投影中。默认 True
  • ilt_reconstruction_terms (int):ILT重建项的数量,即重建单个时间点时在 laplace_rep_func 中的复数 $s$ 点数。

ILT算法列表

实现的ILT算法

  • fourier 傅里叶级数逆变换 [默认]
  • dehoog DeHoog(傅里叶的加速版本) - 与比较慢的推理。
  • cme 集中矩阵指数。
  • fixed_tablot 固定Tablot。
  • stehfest Gaver-Stehfest。

对于大多数问题,好的选择是默认的 fourier。然而,在其他情况下,例如使用更高的ILT重建项时,如 cme 算法,其他ILT算法可能更合适。一些允许在速度和精度之间进行权衡,例如,如果表示已知或准确,dehoog 非常准确,但速度慢,在学习正确的表示时可能会不稳定。

详细文档

有关详细文档,请参阅 官方文档

常见问题解答

请参阅我们的 FAQ 以获取常见问题。

参考文献

有关拉普拉斯域中DE表示的使用、利用球面投影和其他应用,请参阅

[1] Samuel Holt, Zhaozhi Qian, 和 Mihaela van der Schaar. "拉普拉斯域中的神经网络:学习不同类别的微分方程." 机器学习国际会议. 2022. [arxiv]


如果您在研究中发现这个库很有用,请考虑引用。

@inproceedings{holt2022neural,
  title={Neural Laplace: Learning diverse classes of differential equations in the Laplace domain},
  author={Holt, Samuel I and Qian, Zhaozhi and van der Schaar, Mihaela},
  booktitle={International Conference on Machine Learning},
  pages={8811--8832},
  year={2022},
  organization={PMLR}
}

项目详情


下载文件

下载适用于您平台的文件。如果您不确定选择哪一个,请了解更多关于安装包的信息。

源代码分布

本发布没有可用的源代码分布文件。请参阅生成分布存档的教程。

构建的分布

torchlaplace-0.0.4-py3-none-macosx_10_14_x86_64.whl (3.1 MB 查看哈希值)

上传时间 Python 3 macOS 10.14+ x86_64

torchlaplace-0.0.4-py3-none-any.whl (3.1 MB 查看哈希值)

上传时间 Python 3

由支持