跳转到主要内容

SSMS是一个收集认知科学/神经科学和近似贝叶斯计算社区中感兴趣的一组生成模型模拟器和训练数据生成器的软件包

项目描述

SSMS(顺序抽样模型模拟器)

Python软件包,收集顺序抽样模型的模拟器。

在此处查找软件包文档 此处

PyPI PyPI_dl Code style: black License: MIT

快速开始

ssms软件包有两个用途。

  1. 轻松访问顺序抽样模型的快速模拟器
  2. 支持构建各种似然/后验分配方法训练数据的基础设施

我们在此提供两个最小示例,说明如何使用这两种功能。

安装

让我们从安装ssms软件包开始。

您可以在终端中输入以下命令来完成此操作,

pip install git+https://github.com/AlexanderFengler/ssm_simulators

下面您将找到关于如何使用该软件包的基本教程。

教程

# Import necessary packages
import numpy as np
import pandas as pd
import ssms

使用模拟器

让我们从使用基本模拟器开始。您可以通过ssms.basic_simulators.simulator函数访问主要模拟器。

要了解ssms中包含的模型,请使用config模块。包含模型的元数据的中心字典位于ssms.config.model_config

# Check included models
list(ssms.config.model_config.keys())[:10]
['ddm',
 'ddm_legacy',
 'angle',
 'weibull',
 'levy',
 'levy_angle',
 'full_ddm',
 'ornstein',
 'ornstein_angle',
 'ddm_sdv']
# Take an example config for a given model
ssms.config.model_config['ddm']
{'name': 'ddm',
 'params': ['v', 'a', 'z', 't'],
 'param_bounds': [[-3.0, 0.3, 0.1, 0.0], [3.0, 2.5, 0.9, 2.0]],
 'boundary': <function ssms.basic_simulators.boundary_functions.constant(t=0)>,
 'n_params': 4,
 'default_params': [0.0, 1.0, 0.5, 0.001],
 'hddm_include': ['z'],
 'nchoices': 2}

注意:这些模型的通常结构包括,

  • 参数名称(《params》)
  • 参数的界限(《param_bounds》)
  • 定义相应模型边界的函数(《boundary》)
  • 参数的数量(《n_params》)
  • 参数的默认值(《default_params》)
  • 过程可以生成的选择数量(《nchoices》)

《hddm_include》键涉及与hddm Python包集成的有用信息,该包简化了顺序采样模型的分层贝叶斯推理。这对于本教程不重要。

from ssms.basic_simulators.simulator import simulator
sim_out = simulator(model = 'ddm', 
                    theta = {'v': 0, 
                             'a': 1,
                             'z': 0.5,
                             't': 0.5,
                    },
                    n_samples = 1000)

模拟器的输出是一个包含三个元素的《dictionary》。

  1. 《rts》(数组)
  2. 《choices》(数组)
  3. 《metadata》(字典)

《metadata》包括命名参数、模拟器设置等。

使用训练数据生成器

训练数据生成器位于模拟器函数之上,将原始模拟转换为可用于训练机器学习算法的可用训练数据,这些算法旨在进行后验或似然率装甲。

我们将使用来自 ssms.dataset_generatorsdata_generator 类。初始化 data_generator 简化为提供两个配置字典。

  1. 《generator_config》涉及关于要生成何种类型训练数据的选项。
  2. 《model_config》涉及关于底层生成 顺序采样模型 的选项。

我们将考虑一个基本示例,涉及为训练 LANs 而生成数据。

让我们先看看一个示例 generator_config

ssms.config.data_generator_config['lan']['mlp']
{'output_folder': 'data/lan_mlp/',
 'dgp_list': 'ddm',
 'nbins': 0,
 'n_samples': 100000,
 'n_parameter_sets': 10000,
 'n_parameter_sets_rejected': 100,
 'n_training_samples_by_parameter_set': 1000,
 'max_t': 20.0,
 'delta_t': 0.001,
 'pickleprotocol': 4,
 'n_cpus': 'all',
 'kde_data_mixture_probabilities': [0.8, 0.1, 0.1],
 'simulation_filters': {'mode': 20,
  'choice_cnt': 0,
  'mean_rt': 17,
  'std': 0,
  'mode_cnt_rel': 0.9},
 'negative_rt_cutoff': -66.77497,
 'n_subruns': 10,
 'bin_pointwise': False,
 'separate_response_channels': False}

通常只需对基本配置字典进行少数几个更改。以下是一个示例。

from copy import deepcopy
# Initialize the generator config (for MLP LANs)
generator_config = deepcopy(ssms.config.data_generator_config['lan']['mlp'])
# Specify generative model (one from the list of included models mentioned above)
generator_config['dgp_list'] = 'angle' 
# Specify number of parameter sets to simulate
generator_config['n_parameter_sets'] = 100 
# Specify how many samples a simulation run should entail
generator_config['n_samples'] = 1000

现在让我们定义我们的相应 model_config

model_config = ssms.config.model_config['angle']
print(model_config)
{'name': 'angle', 'params': ['v', 'a', 'z', 't', 'theta'], 
'param_bounds': [[-3.0, 0.3, 0.1, 0.001, -0.1], [3.0, 3.0, 0.9, 2.0, 1.3]], 
'boundary': <function angle at 0x11b2a7c10>, 
'n_params': 5, 
'default_params': [0.0, 1.0, 0.5, 0.001, 0.0], 
'hddm_include': ['z', 'theta'], 'nchoices': 2}

我们现在可以初始化一个 data_generator,之后我们可以使用 generate_data_training_uniform 函数生成训练数据,该函数将使用由我们的 model_config 中定义的参数界限定义的超立方体来均匀生成参数集和相应的模拟数据集。

my_dataset_generator = ssms.dataset_generators.data_generator(generator_config = generator_config,
                                                              model_config = model_config)
n_cpus used:  6
checking:  data/lan_mlp/
training_data = my_dataset_generator.generate_data_training_uniform(save = False)
simulation round: 1  of 10
simulation round: 2  of 10
simulation round: 3  of 10
simulation round: 4  of 10
simulation round: 5  of 10
simulation round: 6  of 10
simulation round: 7  of 10
simulation round: 8  of 10
simulation round: 9  of 10
simulation round: 10  of 10

《training_data》是一个包含四个键的字典

  1. 《data》是 LANs 的特征,包含 模型参数 的向量,以及 rtschoices
  2. 《labels》包含近似似然值
  3. 《generator_config》,如上所述
  4. 《model_config》,如上所述

你现在可以为此目的使用这些训练数据。如果你想自己训练 LANs,你可能觉得 LANfactory 包很有帮助。

你也可以简单地发现 ssms 包中提供的模拟器很有用,而不想将输出用于训练数据的减损目的。

结束

项目详情


下载文件

下载适用于您平台的文件。如果您不确定选择哪一个,请了解更多关于 安装包 的信息。

源分布

ssm_simulators-0.7.5.tar.gz (1.1 MB 查看哈希值

上传

构建版本

ssm_simulators-0.7.5-cp311-cp311-win_amd64.whl (361.4 kB 查看哈希值)

上传时间 CPython 3.11 Windows x86-64

ssm_simulators-0.7.5-cp311-cp311-musllinux_1_2_x86_64.whl (2.5 MB 查看哈希值)

上传时间 CPython 3.11 musllinux: musl 1.2+ x86-64

ssm_simulators-0.7.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.5 MB 查看哈希值)

上传时间 CPython 3.11 manylinux: glibc 2.17+ x86-64

ssm_simulators-0.7.5-cp311-cp311-macosx_11_0_arm64.whl (371.9 kB 查看哈希值)

上传时间 CPython 3.11 macOS 11.0+ ARM64

ssm_simulators-0.7.5-cp311-cp311-macosx_10_9_x86_64.whl (428.2 kB 查看哈希值)

上传时间 CPython 3.11 macOS 10.9+ x86-64

ssm_simulators-0.7.5-cp311-cp311-macosx_10_9_universal2.whl (766.9 kB 查看哈希值)

上传时间 CPython 3.11 macOS 10.9+ universal2 (ARM64, x86-64)

ssm_simulators-0.7.5-cp310-cp310-win_amd64.whl (360.3 kB 查看哈希值)

上传时间 CPython 3.10 Windows x86-64

ssm_simulators-0.7.5-cp310-cp310-musllinux_1_2_x86_64.whl (2.3 MB 查看哈希值)

上传时间 CPython 3.10 musllinux: musl 1.2+ x86-64

ssm_simulators-0.7.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.3 MB 查看哈希值)

上传时间 CPython 3.10 manylinux: glibc 2.17+ x86-64

ssm_simulators-0.7.5-cp310-cp310-macosx_11_0_arm64.whl (371.1 kB 查看哈希值)

上传时间 CPython 3.10 macOS 11.0+ ARM64

ssm_simulators-0.7.5-cp310-cp310-macosx_10_9_x86_64.whl (427.8 kB 查看哈希)

上传于 CPython 3.10 macOS 10.9+ x86-64

ssm_simulators-0.7.5-cp310-cp310-macosx_10_9_universal2.whl (765.4 kB 查看哈希)

上传于 CPython 3.10 macOS 10.9+ universal2 (ARM64, x86-64)

由以下组织支持