跳转到主要内容

核Stein差异下降

项目描述

GHActions PyPI

通过优化核Stein差异进行采样

论文可在arxiv.org/abs/2105.09994找到。

代码使用Pytorch,svgd可用numpy后端。

ksd_picture

安装

代码可在pip上找到

$ pip install ksddescent

文档

文档在pierreablin.github.io/ksddescent/

示例

主函数是ksdd_lbfgs,它使用快速L-BFGS算法快速收敛。它以粒子初始位置和得分函数为输入。例如,要从高斯(得分是恒等函数)中采样,可以使用以下简单代码行

>>> import torch
>>> from ksddescent import ksdd_lbfgs
>>> n, p = 50, 2
>>> x0 = torch.rand(n, p)  # start from uniform distribution
>>> score = lambda x: x  # simple score function
>>> x = ksdd_lbfgs(x0, score)  # run the algorithm

参考文献

如果您在项目中使用此代码,请引用

Anna Korba, Pierre-Cyril Aubin-Frankowski, Simon Majewski, Pierre Ablin
Kernel Stein Discrepancy Descent
International Conference on Machine Learning, 2021

错误报告

使用github issue tracker报告错误。

项目详情


下载文件

下载适用于您平台文件。如果您不确定选择哪个,请了解更多关于安装包的信息。

源代码分发

ksddescent-0.3.tar.gz (13.2 kB 查看哈希值)

上传时间: 源代码

构建版本

ksddescent-0.3-py3-none-any.whl (8.4 kB 查看哈希值)

上传时间 Python 3

支持者