跳转到主要内容

包含方便函数以训练局域网的包

项目描述

LANfactory

PyPI PyPI_dl Code style: black License: MIT

一个轻量级的Python包,用于帮助训练LANs(似然近似网络)。

请在此处找到原始文档

快速入门

LANfactory包是一个轻量级的便捷包,用于在torch(或keras)中从提供的训练数据开始训练似然近似网络(LANs)。

虽然LANs在潜在应用范围上更为通用,但它们是在序列抽样建模的背景下构思的,用于解释在认知科学中常见的n-选择强制选择实验中产生的选择反应时间数据中的认知过程。

在本快速教程中,我们将使用ssms包来生成我们的训练数据,使用这种序列抽样模型(SSM)。使用并不局限于使用ssms包。

安装

要安装ssms包,请输入:

pip install git+https://github.com/AlexanderFengler/ssm_simulators

要安装LANfactory包,请输入:

pip install git+https://github.com/AlexanderFengler/LANfactory

必要的依赖项将在安装过程中自动安装。

基本教程

# Load necessary packages
import ssms
import lanfactory 
import os
import numpy as np
from copy import deepcopy
import torch

生成训练数据

首先,我们需要生成一些训练数据。如上所述,我们将使用ssms Python 包来完成此操作,但不会深入解释该包。如果您想了解更多信息,请参考[基本 ssms 教程] (https://github.com/AlexanderFengler/ssm_simulators)。

# MAKE CONFIGS

# Initialize the generator config (for MLP LANs)
generator_config = deepcopy(ssms.config.data_generator_config['lan']['mlp'])
# Specify generative model (one from the list of included models mentioned above)
generator_config['model'] = 'angle' 
# Specify number of parameter sets to simulate
generator_config['n_parameter_sets'] = 100 
# Specify how many samples a simulation run should entail
generator_config['n_samples'] = 1000
# Specify folder in which to save generated data
generator_config['output_folder'] = 'data/lan_mlp/'

# Make model config dict
model_config = ssms.config.model_config['angle']
# MAKE DATA

my_dataset_generator = ssms.dataset_generators.data_generator(generator_config = generator_config,
                                                              model_config = model_config)

training_data = my_dataset_generator.generate_data_training_uniform(save = True)
n_cpus used:  6
checking:  data/lan_mlp/
simulation round: 1  of 10
simulation round: 2  of 10
simulation round: 3  of 10
simulation round: 4  of 10
simulation round: 5  of 10
simulation round: 6  of 10
simulation round: 7  of 10
simulation round: 8  of 10
simulation round: 9  of 10
simulation round: 10  of 10
Writing to file:  data/lan_mlp/training_data_0_nbins_0_n_1000/angle/training_data_angle_ef5b9e0eb76c11eca684acde48001122.pickle

准备训练

接下来,我们使用 pytorch 设置训练数据加载器。LANfactory 使用自定义数据加载器,考虑到预期训练数据的特性。具体来说,我们预计会收到一些训练数据文件(本例中仅生成一个),每个文件包含大量训练示例。因此,我们希望定义一个数据加载器,该加载器从特定的训练数据文件中生成批次,并检查何时加载新文件。这里的实现方式是通过 lanfactory.trainers 中的 DatasetTorch 类,它继承自 torch.utils.data.Dataset 并预先指定了 batch_size。最后,我们将它提供给一个 DataLoader,其中我们保持 batch_size 参数为 0。

然后,DatasetTorch 类通过 DataLoader 作为迭代器调用,并负责内部进行批处理和文件加载。

您可以选择自己的方式定义 DataLoader 类,之后只需提供一个。

# MAKE DATALOADERS

# List of datafiles (here only one)
folder_ = 'data/lan_mlp/training_data_0_nbins_0_n_1000/angle/'
file_list_ = [folder_ + file_ for file_ in os.listdir(folder_)]

# Training dataset
torch_training_dataset = lanfactory.trainers.DatasetTorch(file_IDs = file_list_,
                                                          batch_size = 128)

torch_training_dataloader = torch.utils.data.DataLoader(torch_training_dataset,
                                                         shuffle = True,
                                                         batch_size = None,
                                                         num_workers = 1,
                                                         pin_memory = True)

# Validation dataset
torch_validation_dataset = lanfactory.trainers.DatasetTorch(file_IDs = file_list_,
                                                          batch_size = 128)

torch_validation_dataloader = torch.utils.data.DataLoader(torch_validation_dataset,
                                                          shuffle = True,
                                                          batch_size = None,
                                                          num_workers = 1,
                                                          pin_memory = True)

现在我们定义两个配置字典,

  1. network_config 字典定义了网络的架构和属性
  2. train_config 字典定义了关于训练超参数的属性

以下提供了两个示例(我们将其作为包提供的示例,但您可以根据自己的需求进行调整)。

# SPECIFY NETWORK CONFIGS AND TRAINING CONFIGS

network_config = lanfactory.config.network_configs.network_config_mlp

print('Network config: ')
print(network_config)

train_config = lanfactory.config.network_configs.train_config_mlp

print('Train config: ')
print(train_config)
Network config: 
{'layer_types': ['dense', 'dense', 'dense'], 'layer_sizes': [100, 100, 1], 'activations': ['tanh', 'tanh', 'linear'], 'loss': ['huber'], 'callbacks': ['checkpoint', 'earlystopping', 'reducelr']}
Train config: 
{'batch_size': 128, 'n_epochs': 10, 'optimizer': 'adam', 'learning_rate': 0.002, 'loss': 'huber', 'save_history': True, 'metrics': [<keras.losses.MeanSquaredError object at 0x12c403d30>, <keras.losses.Huber object at 0x12c1c78e0>], 'callbacks': ['checkpoint', 'earlystopping', 'reducelr']}

现在我们可以加载网络,并保存配置文件以便于使用。

# LOAD NETWORK
net = lanfactory.trainers.TorchMLP(network_config = deepcopy(network_config),
                                   input_shape = torch_training_dataset.input_dim,
                                   save_folder = '/data/torch_models/',
                                   generative_model_id = 'angle')

# SAVE CONFIGS
lanfactory.utils.save_configs(model_id = net.model_id + '_torch_',
                              save_folder = 'data/torch_models/angle/', 
                              network_config = network_config, 
                              train_config = train_config, 
                              allow_abs_path_folder_generation = True)

最后,为了训练网络,我们将我们的网络、数据加载器和训练配置提供给来自 lanfactory.trainersModelTrainerTorchMLP 类。

# TRAIN MODEL
model_trainer.train_model(save_history = True,
                          save_model = True,
                          verbose = 0)
Epoch took 0 / 10,  took 11.54538607597351 seconds
epoch 0 / 10, validation_loss: 0.3431
Epoch took 1 / 10,  took 13.032279014587402 seconds
epoch 1 / 10, validation_loss: 0.2732
Epoch took 2 / 10,  took 12.421074867248535 seconds
epoch 2 / 10, validation_loss: 0.1941
Epoch took 3 / 10,  took 12.097641229629517 seconds
epoch 3 / 10, validation_loss: 0.2028
Epoch took 4 / 10,  took 12.030233144760132 seconds
epoch 4 / 10, validation_loss: 0.184
Epoch took 5 / 10,  took 12.695374011993408 seconds
epoch 5 / 10, validation_loss: 0.1433
Epoch took 6 / 10,  took 12.177874326705933 seconds
epoch 6 / 10, validation_loss: 0.1115
Epoch took 7 / 10,  took 11.908828258514404 seconds
epoch 7 / 10, validation_loss: 0.1084
Epoch took 8 / 10,  took 12.066670179367065 seconds
epoch 8 / 10, validation_loss: 0.0864
Epoch took 9 / 10,  took 12.37562108039856 seconds
epoch 9 / 10, validation_loss: 0.07484
Saving training history
Saving model state dict
Training finished successfully...

加载模型进行推理和调用

LANfactory 提供了一些方便的函数,用于在训练后使用网络进行推理。我们可以使用 LoadTorchMLPInfer 类加载模型,然后可以通过直接调用,该调用期望一个 torch.tensor 作为输入,或者 predict_on_batch 方法,该方法期望一个 numpy.arraydtypenp.float32

network_path_list = os.listdir('data/torch_models/angle')
network_file_path = ['data/torch_models/angle/' + file_ for file_ in network_path_list if 'state_dict' in file_][0]

network = lanfactory.trainers.LoadTorchMLPInfer(model_file_path = network_file_path,
                                                network_config = network_config,
                                                input_dim = torch_training_dataset.input_dim)
# Two ways to call the network

# Direct call --> need tensor input
direct_out = network(torch.from_numpy(np.array([1, 1.5, 0.5, 1.0, 0.1, 0.65, 1], dtype  = np.float32)))
print('direct call out: ', direct_out)

# predict_on_batch method
predict_on_batch_out = network.predict_on_batch(np.array([1, 1.5, 0.5, 1.0, 0.1, 0.65, 1], dtype  = np.float32))
print('predict_on_batch out: ', predict_on_batch_out)
direct call out:  tensor([-16.4997])
predict_on_batch out:  [-16.499687]

查看网络计算的第一段分布

我们可以将我们 network 中学习到的似然函数与底层生成模型的模拟数据进行比较。为此,我们再次调用 ssms 包。

import pandas as pd
import matplotlib.pyplot as plt

data = pd.DataFrame(np.zeros((2000, 7), dtype = np.float32), columns = ['v', 'a', 'z', 't', 'theta', 'rt', 'choice'])
data['v'] = 0.5
data['a'] = 0.75
data['z'] = 0.5
data['t'] = 0.2
data['theta'] = 0.1
data['rt'].iloc[:1000] = np.linspace(5, 0, 1000)
data['rt'].iloc[1000:] = np.linspace(0, 5, 1000)
data['choice'].iloc[:1000] = -1
data['choice'].iloc[1000:] = 1

# Network predictions
predict_on_batch_out = network.predict_on_batch(data.values.astype(np.float32))

# Simulations
from ssms.basic_simulators import simulator
sim_out = simulator(model = 'angle', 
                    theta = data.values[0, :-2],
                    n_samples = 2000)
# Plot network predictions
plt.plot(data['rt'] * data['choice'], np.exp(predict_on_batch_out), color = 'black', label = 'network')

# Plot simulations
plt.hist(sim_out['rts'] * sim_out['choices'], bins = 30, histtype = 'step', label = 'simulations', color = 'blue', density  = True)
plt.legend()
plt.title('SSM likelihood')
plt.xlabel('rt')
plt.ylabel('likelihod')
Text(0, 0.5, 'likelihod')

png

TorchMLP 到 ONNX 转换器

transform_onnx.py 脚本将 TorchMLP 模型转换为 ONNX 格式。它接受网络配置文件(pickle 格式)、状态字典文件(Torch 模型权重)、输入张量的大小以及所需的输出 ONNX 文件路径。

用法

python onnx/transform_onnx.py <network_config_file> <state_dict_file> <input_shape> <output_onnx_file>

用适当的值替换占位符

  • <network_config_file>: 包含网络配置的 pickle 文件的路径。
  • <state_dict_file>: 包含模型状态字典的文件的路径。
  • <input_shape>: 模型输入张量的大小(整数)。
  • <output_onnx_file>: 输出 ONNX 文件的路径。

例如

python onnx/transform_onnx.py '0d9f0e94175b11eca9e93cecef057438_lca_no_bias_4_torch__network_config.pickle' '0d9f0e94175b11eca9e93cecef057438_lca_no_bias_4_torch_state_dict.pt' 11 'lca_no_bias_4_torch.onnx'

此 onnx 文件可以直接用于 HSSM 包。

我们希望这个包在您尝试为自己的研究训练 LANs 时可能有所帮助。

END

项目详情


下载文件

下载适合您平台的文件。如果您不确定选择哪个,请了解更多关于安装包的信息。

源分布

lanfactory-0.4.6.tar.gz (638.1 kB 查看哈希值)

上传时间

由以下支持