核Stein差异下降
项目描述
通过优化核Stein差异进行采样
论文可在arxiv.org/abs/2105.09994找到。
代码使用Pytorch,svgd可用numpy后端。
安装
代码可在pip上找到
$ pip install ksddescent
文档
示例
主函数是ksdd_lbfgs,它使用快速L-BFGS算法快速收敛。它以粒子初始位置和得分函数为输入。例如,要从高斯(得分是恒等函数)中采样,可以使用以下简单代码行
>>> import torch
>>> from ksddescent import ksdd_lbfgs
>>> n, p = 50, 2
>>> x0 = torch.rand(n, p) # start from uniform distribution
>>> score = lambda x: x # simple score function
>>> x = ksdd_lbfgs(x0, score) # run the algorithm
参考文献
如果您在项目中使用此代码,请引用
Anna Korba, Pierre-Cyril Aubin-Frankowski, Simon Majewski, Pierre Ablin Kernel Stein Discrepancy Descent International Conference on Machine Learning, 2021
错误报告
使用github issue tracker报告错误。
项目详情
下载文件
下载适用于您平台文件。如果您不确定选择哪个,请了解更多关于安装包的信息。
源代码分发
ksddescent-0.3.tar.gz (13.2 kB 查看哈希值)
构建版本
ksddescent-0.3-py3-none-any.whl (8.4 kB 查看哈希值)
关闭
ksddescent-0.3.tar.gz 的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 8104173b61049244aa6649e6f2970398c8d5aabc3b70a4d1ce8f4de9955a771f |
|
MD5 | 85eba9189349f185550c55b5043f7d6a |
|
BLAKE2b-256 | 76720a9bf5e7ceae33c77e09c51f72d6bb9be9bb8ef6cdc381f08da82ac576ab |
关闭
ksddescent-0.3-py3-none-any.whl 的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 093e32135cb48f2b711d4d4c407738d47bf93b050204054a126e7176b7404c31 |
|
MD5 | 11831d593dd1bec8fdca46da28790205 |
|
BLAKE2b-256 | 8d6fd96f19bbaae7e7a89c7a66e4192bce38d43f783570a4b76ee537ef4381fa |