Package with convenience functions to train LANs
Project description
LANfactory
A lightweight Python package to help with training LANs (likelihood approximation networks).
Please find the original documentation here.
Quick Start
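The LANfactory package is a lightweight convenience package for training likelihood approximation networks (LANs) in torch (or keras), starting from provided training data.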
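While LANs are more general in their potential scope of application, they were conceived in the context of sequential sampling modeling, to account for the cognitive processes that give rise to the choice and reaction time data produced in n-alternative forced choice experiments, which are common in the cognitive sciences.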
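Conceptually, a LAN takes a vector of model parameters together with one observed (rt, choice) pair and returns an approximate log-likelihood. Assuming independent trials, the log-likelihood of a whole dataset is then the sum over trials. The sketch below illustrates only this idea: `toy_lan` is a hypothetical stand-in for a trained network and `dataset_log_likelihood` is an illustrative helper, neither is part of LANfactory. The parameter order (v, a, z, t, theta) follows the DataFrame used later in this tutorial.

```python
import math
from typing import Callable, Sequence

# A LAN maps [*theta, rt, choice] -> approximate log p(rt, choice | theta).
LogLikFn = Callable[[Sequence[float]], float]


def toy_lan(x: Sequence[float]) -> float:
    """Hypothetical stand-in for a trained network (NOT a real LAN).

    Ignores the parameters: each choice has probability 0.5 and rt is
    exponentially distributed with rate 1.
    """
    rt = x[-2]
    return math.log(0.5) - rt


def dataset_log_likelihood(lan: LogLikFn, theta: Sequence[float],
                           rts: Sequence[float], choices: Sequence[int]) -> float:
    """Sum per-trial log-likelihoods, assuming independent trials."""
    return sum(lan([*theta, rt, c]) for rt, c in zip(rts, choices))


theta = [0.5, 0.75, 0.5, 0.2, 0.1]  # v, a, z, t, theta of the 'angle' model
print(dataset_log_likelihood(toy_lan, theta, [0.4, 1.1], [1, -1]))
```

With a real LAN, the per-trial calls would be batched into a single network evaluation rather than looped over.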
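In this quick tutorial we will use the ssms package to generate our training data using such a sequential sampling model (SSM). Using LANfactory is, however, not tied to the ssms package.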
Installation
To install the ssms package, type:
pip install git+https://github.com/AlexanderFengler/ssm_simulators
To install the LANfactory package, type:
pip install git+https://github.com/AlexanderFengler/LANfactory
Necessary dependencies should be installed automatically in the process.
Basic Tutorial
# Load necessary packages
import ssms
import lanfactory
import os
import numpy as np
from copy import deepcopy
import torch
Generate Training Data
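First we need to generate some training data. As mentioned above, we will do so using the ssms Python package, without delving into a detailed explanation of that package. Please refer to the [basic ssms tutorial](https://github.com/AlexanderFengler/ssm_simulators) if you want to learn more.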
# MAKE CONFIGS
# Initialize the generator config (for MLP LANs)
generator_config = deepcopy(ssms.config.data_generator_config['lan']['mlp'])
# Specify generative model (one of the models included in ssms)
generator_config['model'] = 'angle'
# Specify number of parameter sets to simulate
generator_config['n_parameter_sets'] = 100
# Specify how many samples a simulation run should entail
generator_config['n_samples'] = 1000
# Specify folder in which to save generated data
generator_config['output_folder'] = 'data/lan_mlp/'
# Make model config dict
model_config = ssms.config.model_config['angle']
# MAKE DATA
my_dataset_generator = ssms.dataset_generators.data_generator(generator_config = generator_config,
                                                              model_config = model_config)
training_data = my_dataset_generator.generate_data_training_uniform(save = True)
n_cpus used: 6
checking: data/lan_mlp/
simulation round: 1 of 10
simulation round: 2 of 10
simulation round: 3 of 10
simulation round: 4 of 10
simulation round: 5 of 10
simulation round: 6 of 10
simulation round: 7 of 10
simulation round: 8 of 10
simulation round: 9 of 10
simulation round: 10 of 10
Writing to file: data/lan_mlp/training_data_0_nbins_0_n_1000/angle/training_data_angle_ef5b9e0eb76c11eca684acde48001122.pickle
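The method name `generate_data_training_uniform` suggests that parameter sets are drawn uniformly within per-parameter bounds before each one is simulated `n_samples` times. The following plain-Python sketch shows that sampling step under this assumption. The bounds in `BOUNDS` are hypothetical placeholders (the real ones live in the ssms model config and may differ), and `sample_parameter_sets` is an illustrative helper, not an ssms function.

```python
import random
from typing import Dict, List, Tuple

# Hypothetical bounds for the 'angle' parameters, for illustration only.
BOUNDS: Dict[str, Tuple[float, float]] = {
    "v": (-3.0, 3.0),
    "a": (0.3, 3.0),
    "z": (0.1, 0.9),
    "t": (0.001, 2.0),
    "theta": (-0.1, 1.3),
}


def sample_parameter_sets(bounds: Dict[str, Tuple[float, float]],
                          n_parameter_sets: int,
                          seed: int = 0) -> List[Dict[str, float]]:
    """Draw each parameter independently and uniformly within its bounds."""
    rng = random.Random(seed)  # seeded for reproducibility
    return [
        {name: rng.uniform(lo, hi) for name, (lo, hi) in bounds.items()}
        for _ in range(n_parameter_sets)
    ]


# Matches generator_config['n_parameter_sets'] = 100 above.
parameter_sets = sample_parameter_sets(BOUNDS, n_parameter_sets=100)
print(len(parameter_sets), parameter_sets[0])
```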
Prepare for Training
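Next we set up dataloaders for training with pytorch. LANfactory uses custom dataloaders that take into account the particularities of the expected training data. Specifically, we expect to receive a number of training data files (the present example generates only one), each of which holds a large number of training examples. We therefore want a dataloader that produces batches from the data within a specific training data file and keeps track of when to load a new file. Here this is implemented via the DatasetTorch class in lanfactory.trainers, which inherits from torch.utils.data.Dataset and prespecifies a batch_size. Finally, we supply it to a DataLoader, for which we set the batch_size argument to None, since DatasetTorch already returns complete batches.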
The DatasetTorch class is then called as an iterator via the DataLoader and takes care of batching and file loading internally.
You may choose your own way of defining the DataLoader classes; downstream you simply need to supply one.
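The file-wise batching idea described above can be sketched in plain Python. The class below is a hypothetical illustration, not LANfactory's DatasetTorch: files are represented by an in-memory dict and a `load_file` callable (in practice something like reading a pickle), each index maps to one batch inside one file, and a new file is loaded only when an index falls into a different file than the one currently held. For simplicity it assumes every file holds the same number of examples and drops an incomplete final batch per file.

```python
from typing import Callable, Dict, List, Sequence


class FileBatchDataset:
    """Index -> batch, loading files lazily (hypothetical sketch)."""

    def __init__(self, file_ids: Sequence[str], batch_size: int,
                 load_file: Callable[[str], List[int]], examples_per_file: int):
        self.file_ids = list(file_ids)
        self.batch_size = batch_size
        self.load_file = load_file
        self.batches_per_file = examples_per_file // batch_size
        self._current_id = None
        self._current_data: List[int] = []
        self.n_loads = 0  # how often a file was (re)loaded

    def __len__(self) -> int:
        return len(self.file_ids) * self.batches_per_file

    def __getitem__(self, index: int) -> List[int]:
        if not 0 <= index < len(self):
            raise IndexError(index)
        file_idx, batch_idx = divmod(index, self.batches_per_file)
        file_id = self.file_ids[file_idx]
        if file_id != self._current_id:  # only load when switching files
            self._current_data = self.load_file(file_id)
            self._current_id = file_id
            self.n_loads += 1
        start = batch_idx * self.batch_size
        return self._current_data[start:start + self.batch_size]


# Two fake "files" with 10 examples each; batches of 4 -> 2 batches per file.
fake_files: Dict[str, List[int]] = {"f0": list(range(10)), "f1": list(range(10, 20))}
ds = FileBatchDataset(["f0", "f1"], batch_size=4,
                      load_file=fake_files.__getitem__, examples_per_file=10)
print(len(ds), [ds[i] for i in range(len(ds))], ds.n_loads)
```

Because this sketch reloads whenever consecutive indices come from different files, randomly shuffled indices across many files would trigger frequent reloads; with a single file, as in this tutorial, that cost does not arise.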
# MAKE DATALOADERS
# List of datafiles (here only one)
folder_ = 'data/lan_mlp/training_data_0_nbins_0_n_1000/angle/'
file_list_ = [folder_ + file_ for file_ in os.listdir(folder_)]
# Training dataset
torch_training_dataset = lanfactory.trainers.DatasetTorch(file_IDs = file_list_,
                                                          batch_size = 128)
torch_training_dataloader = torch.utils.data.DataLoader(torch_training_dataset,
                                                        shuffle = True,
                                                        batch_size = None,
                                                        num_workers = 1,
                                                        pin_memory = True)
# Validation dataset
torch_validation_dataset = lanfactory.trainers.DatasetTorch(file_IDs = file_list_,
                                                            batch_size = 128)
torch_validation_dataloader = torch.utils.data.DataLoader(torch_validation_dataset,
                                                          shuffle = True,
                                                          batch_size = None,
                                                          num_workers = 1,
                                                          pin_memory = True)
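Now we define two configuration dictionaries:
- The network_config dictionary defines the architecture and properties of the network.
- The train_config dictionary defines properties concerning the training hyperparameters.
Two examples, which are provided with the package but can be adjusted to your needs, are shown below.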
# SPECIFY NETWORK CONFIGS AND TRAINING CONFIGS
network_config = lanfactory.config.network_configs.network_config_mlp
print('Network config: ')
print(network_config)
train_config = lanfactory.config.network_configs.train_config_mlp
print('Train config: ')
print(train_config)
Network config:
{'layer_types': ['dense', 'dense', 'dense'], 'layer_sizes': [100, 100, 1], 'activations': ['tanh', 'tanh', 'linear'], 'loss': ['huber'], 'callbacks': ['checkpoint', 'earlystopping', 'reducelr']}
Train config:
{'batch_size': 128, 'n_epochs': 10, 'optimizer': 'adam', 'learning_rate': 0.002, 'loss': 'huber', 'save_history': True, 'metrics': [<keras.losses.MeanSquaredError object at 0x12c403d30>, <keras.losses.Huber object at 0x12c1c78e0>], 'callbacks': ['checkpoint', 'earlystopping', 'reducelr']}
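As a worked example of what the printed network_config implies, the helper below counts the trainable parameters of the described MLP. It assumes standard dense layers, each with a weight matrix and a bias vector, and an input dimension of 7 (the five 'angle' parameters plus rt and choice, as in the inference example later on). `count_mlp_parameters` is a hypothetical helper for this arithmetic, not a LANfactory function: (7·100 + 100) + (100·100 + 100) + (100·1 + 1) = 800 + 10100 + 101 = 11001.

```python
from typing import Dict, List


def count_mlp_parameters(network_config: Dict[str, List], input_dim: int) -> int:
    """Weights + biases of a stack of dense layers (hypothetical helper)."""
    sizes = network_config["layer_sizes"]
    if not (len(sizes) == len(network_config["layer_types"])
            == len(network_config["activations"])):
        raise ValueError("layer_types, layer_sizes and activations must align")
    total, fan_in = 0, input_dim
    for fan_out in sizes:
        total += fan_in * fan_out + fan_out  # weight matrix + bias vector
        fan_in = fan_out
    return total


# Architecture fields as printed above for network_config_mlp.
config = {"layer_types": ["dense", "dense", "dense"],
          "layer_sizes": [100, 100, 1],
          "activations": ["tanh", "tanh", "linear"]}
print(count_mlp_parameters(config, input_dim=7))  # -> 11001
```

The single linear output unit is what lets the network emit an unbounded scalar, i.e. a log-likelihood.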
Now we can load the network and save the configuration files for convenience.
# LOAD NETWORK
net = lanfactory.trainers.TorchMLP(network_config = deepcopy(network_config),
                                   input_shape = torch_training_dataset.input_dim,
                                   save_folder = 'data/torch_models/',
                                   generative_model_id = 'angle')
# SAVE CONFIGS
lanfactory.utils.save_configs(model_id = net.model_id + '_torch_',
                              save_folder = 'data/torch_models/angle/',
                              network_config = network_config,
                              train_config = train_config,
                              allow_abs_path_folder_generation = True)
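The ONNX section below expects the network config as a pickle file, and its example file name (`..._torch__network_config.pickle`) suggests the pattern `<model_id>_torch_` + `_network_config.pickle`. The sketch below shows saving and reloading a config dict with the standard `pickle` module under that assumed naming; `save_config` and `load_config` are hypothetical helpers, not lanfactory.utils functions, and `'abc123'` stands in for `net.model_id`.

```python
import pickle
import tempfile
from pathlib import Path


def save_config(config: dict, folder: Path, model_id: str, name: str) -> Path:
    """Pickle a config dict to <folder>/<model_id>_<name>.pickle (hypothetical)."""
    folder.mkdir(parents=True, exist_ok=True)  # create missing folders
    path = folder / f"{model_id}_{name}.pickle"
    with path.open("wb") as f:
        pickle.dump(config, f)
    return path


def load_config(path: Path) -> dict:
    """Read back a pickled config dict."""
    with path.open("rb") as f:
        return pickle.load(f)


network_config = {"layer_types": ["dense", "dense", "dense"],
                  "layer_sizes": [100, 100, 1],
                  "activations": ["tanh", "tanh", "linear"]}
with tempfile.TemporaryDirectory() as tmp:
    # 'abc123' stands in for net.model_id; '_torch_' mirrors the tutorial.
    path = save_config(network_config, Path(tmp) / "angle", "abc123_torch_", "network_config")
    print(path.name)  # -> abc123_torch__network_config.pickle
    assert load_config(path) == network_config
```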
Finally, to train the network we supply our network, the dataloaders and the training config to the ModelTrainerTorchMLP class from lanfactory.trainers.
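The train_config above sets `'loss': 'huber'`. As a reminder of what that loss computes, here is a plain-Python sketch of the Huber loss with threshold `delta`: quadratic for small residuals and linear for large ones, so outlying targets pull less hard on the network than under squared error. A `delta` of 1.0 is assumed here (it is the default of `torch.nn.HuberLoss`); the value LANfactory actually uses is not stated in this tutorial.

```python
from typing import Sequence


def huber(residual: float, delta: float = 1.0) -> float:
    """Huber loss for a single residual (prediction - target)."""
    r = abs(residual)
    if r <= delta:
        return 0.5 * r * r            # quadratic region
    return delta * (r - 0.5 * delta)  # linear region


def mean_huber(preds: Sequence[float], targets: Sequence[float], delta: float = 1.0) -> float:
    """Average Huber loss over a batch."""
    losses = [huber(p - t, delta) for p, t in zip(preds, targets)]
    return sum(losses) / len(losses)


print(huber(0.5), huber(3.0))              # -> 0.125 2.5
print(mean_huber([0.0, 1.0], [0.5, 4.0]))  # -> 1.3125
```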
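# TRAIN MODEL
# (keyword names of ModelTrainerTorchMLP are assumed; verify them against your lanfactory version)
model_trainer = lanfactory.trainers.ModelTrainerTorchMLP(train_config = deepcopy(train_config), model = net,
                                                         train_dl = torch_training_dataloader,
                                                         valid_dl = torch_validation_dataloader)
model_trainer.train_model(save_history = True,
                          save_model = True,
                          verbose = 0)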
Epoch took 0 / 10, took 11.54538607597351 seconds
epoch 0 / 10, validation_loss: 0.3431
Epoch took 1 / 10, took 13.032279014587402 seconds
epoch 1 / 10, validation_loss: 0.2732
Epoch took 2 / 10, took 12.421074867248535 seconds
epoch 2 / 10, validation_loss: 0.1941
Epoch took 3 / 10, took 12.097641229629517 seconds
epoch 3 / 10, validation_loss: 0.2028
Epoch took 4 / 10, took 12.030233144760132 seconds
epoch 4 / 10, validation_loss: 0.184
Epoch took 5 / 10, took 12.695374011993408 seconds
epoch 5 / 10, validation_loss: 0.1433
Epoch took 6 / 10, took 12.177874326705933 seconds
epoch 6 / 10, validation_loss: 0.1115
Epoch took 7 / 10, took 11.908828258514404 seconds
epoch 7 / 10, validation_loss: 0.1084
Epoch took 8 / 10, took 12.066670179367065 seconds
epoch 8 / 10, validation_loss: 0.0864
Epoch took 9 / 10, took 12.37562108039856 seconds
epoch 9 / 10, validation_loss: 0.07484
Saving training history
Saving model state dict
Training finished successfully...
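The training config lists an `'earlystopping'` callback, and the log above shows the validation loss decreasing almost monotonically. The sketch below applies a common patience-based early-stopping rule to the logged losses. The patience values are hypothetical, and this is not necessarily the rule LANfactory implements; it only illustrates why no early stop occurred here: the only non-improving epoch is epoch 3, so any patience of 2 or more never fires within these 10 epochs.

```python
from typing import Optional, Sequence, Tuple

# Validation losses copied from the training log above (epochs 0-9).
VAL_LOSSES = [0.3431, 0.2732, 0.1941, 0.2028, 0.184,
              0.1433, 0.1115, 0.1084, 0.0864, 0.07484]


def early_stopping(losses: Sequence[float], patience: int,
                   min_delta: float = 0.0) -> Tuple[int, Optional[int]]:
    """Return (best_epoch, stop_epoch); stop_epoch is None if never triggered.

    Stops once `patience` consecutive epochs fail to improve on the best
    loss by more than `min_delta`.
    """
    best, best_epoch, wait = float("inf"), -1, 0
    for epoch, loss in enumerate(losses):
        if loss < best - min_delta:
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return best_epoch, epoch
    return best_epoch, None


print(early_stopping(VAL_LOSSES, patience=3))  # -> (9, None)
print(early_stopping(VAL_LOSSES, patience=1))  # -> (2, 3)
```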
Loading the Model for Inference and Calling It
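LANfactory provides some convenience functions for using the networks for inference after training. We can load a model with the LoadTorchMLPInfer class and then call it either directly, which expects a torch.tensor as input, or via the predict_on_batch method, which expects a numpy.array of dtype np.float32.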
network_path_list = os.listdir('data/torch_models/angle')
network_file_path = ['data/torch_models/angle/' + file_ for file_ in network_path_list if 'state_dict' in file_][0]
network = lanfactory.trainers.LoadTorchMLPInfer(model_file_path = network_file_path,
                                                network_config = network_config,
                                                input_dim = torch_training_dataset.input_dim)
# Two ways to call the network
# Direct call --> need tensor input
direct_out = network(torch.from_numpy(np.array([1, 1.5, 0.5, 1.0, 0.1, 0.65, 1], dtype = np.float32)))
print('direct call out: ', direct_out)
# predict_on_batch method
predict_on_batch_out = network.predict_on_batch(np.array([1, 1.5, 0.5, 1.0, 0.1, 0.65, 1], dtype = np.float32))
print('predict_on_batch out: ', predict_on_batch_out)
direct call out: tensor([-16.4997])
predict_on_batch out: [-16.499687]
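The seven inputs in the call above follow the same order as the DataFrame columns used in the next section (v, a, z, t, theta, rt, choice), and the output is a log-likelihood. The sketch below shows one way to assemble an input row by name and convert the printed output to a likelihood. `make_input_row` is a hypothetical helper, and the column order is inferred from this tutorial rather than from the LANfactory API. Note that in this example rt = 0.65 is below the non-decision time t = 1.0, which plausibly explains the very low value (a likelihood of roughly 7e-8).

```python
import math
from typing import List

# Column order used in this tutorial's DataFrame: parameters, then rt, choice.
ANGLE_COLUMNS = ["v", "a", "z", "t", "theta", "rt", "choice"]


def make_input_row(values: dict) -> List[float]:
    """Order a dict of named inputs into the assumed network input layout."""
    missing = [c for c in ANGLE_COLUMNS if c not in values]
    if missing:
        raise KeyError(f"missing inputs: {missing}")
    return [float(values[c]) for c in ANGLE_COLUMNS]


row = make_input_row({"v": 1, "a": 1.5, "z": 0.5, "t": 1.0,
                      "theta": 0.1, "rt": 0.65, "choice": 1})
print(row)  # -> [1.0, 1.5, 0.5, 1.0, 0.1, 0.65, 1.0]

# The network returns a log-likelihood; exponentiate to get a likelihood.
log_lik = -16.499687  # value printed by predict_on_batch above
print(math.exp(log_lik))
```

In the tutorial itself you would pass `np.array(row, dtype = np.float32)` to `network.predict_on_batch`.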
Inspecting the First-Passage Distribution Computed by the Network
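We can compare the likelihood function learned by our network with simulated data from the underlying generative model. For this purpose we recruit the ssms package again.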
import pandas as pd
import matplotlib.pyplot as plt
# Evaluation grid: fixed 'angle' parameters, rt from 0 to 5 for each choice
data = pd.DataFrame(np.zeros((2000, 7), dtype = np.float32), columns = ['v', 'a', 'z', 't', 'theta', 'rt', 'choice'])
data['v'] = 0.5
data['a'] = 0.75
data['z'] = 0.5
data['t'] = 0.2
data['theta'] = 0.1
# Use .loc rather than chained indexing (data['rt'].iloc[...] = ...), which may not write back to data
data.loc[:999, 'rt'] = np.linspace(5, 0, 1000, dtype = np.float32)
data.loc[1000:, 'rt'] = np.linspace(0, 5, 1000, dtype = np.float32)
data.loc[:999, 'choice'] = -1
data.loc[1000:, 'choice'] = 1
# Network predictions (log-likelihoods)
predict_on_batch_out = network.predict_on_batch(data.values.astype(np.float32))
# Simulations from the same parameters (v, a, z, t, theta)
from ssms.basic_simulators import simulator
sim_out = simulator(model = 'angle',
                    theta = data.values[0, :-2],
                    n_samples = 2000)
# Plot network predictions (exponentiated to likelihoods) against choice-signed rt
plt.plot(data['rt'] * data['choice'], np.exp(predict_on_batch_out), color = 'black', label = 'network')
# Plot simulations
plt.hist(sim_out['rts'] * sim_out['choices'], bins = 30, histtype = 'step', label = 'simulations', color = 'blue', density = True)
plt.legend()
plt.title('SSM likelihood')
plt.xlabel('rt')
plt.ylabel('likelihood')
plt.show()
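Beyond the visual comparison, a quick numerical check is possible: if the learned function is a proper density over (rt, choice), summing over both choices and integrating over rt should give approximately 1, provided the rt grid (0 to 5, as in the DataFrame above) covers most of the probability mass. The sketch below performs this check with the trapezoidal rule on a hypothetical stand-in likelihood `toy_lik`; with the trained network you would instead evaluate `np.exp(predict_on_batch_out)` on an increasing rt grid per choice. All helper names are illustrative.

```python
import math
from typing import Callable, List


def linspace(start: float, stop: float, num: int) -> List[float]:
    """Evenly spaced points, endpoints included (like numpy.linspace)."""
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def trapezoid(ys: List[float], xs: List[float]) -> float:
    """Trapezoidal rule for samples ys at increasing positions xs."""
    return sum((xs[i + 1] - xs[i]) * (ys[i] + ys[i + 1]) / 2 for i in range(len(xs) - 1))


def total_mass(lik: Callable[[float, int], float], rt_max: float = 5.0, n: int = 1000) -> float:
    """Sum over choices of the integral of lik(rt, choice) over rt in [0, rt_max]."""
    rts = linspace(0.0, rt_max, n)
    return sum(trapezoid([lik(rt, choice) for rt in rts], rts) for choice in (-1, 1))


def toy_lik(rt: float, choice: int) -> float:
    """Hypothetical stand-in: each choice w.p. 0.5, rt ~ Exponential(rate=2)."""
    return 0.5 * 2.0 * math.exp(-2.0 * rt)


print(round(total_mass(toy_lik), 3))  # -> 1.0
```

A total well below 1 would suggest either that the rt range is too narrow or that the network misallocates mass.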
TorchMLP to ONNX Converter
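The transform_onnx.py script converts a TorchMLP model to the ONNX format. It takes a network configuration file (in pickle format), a state dictionary file (the Torch model weights), the size of the input tensor, and the desired output ONNX file path.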
Usage
python onnx/transform_onnx.py <network_config_file> <state_dict_file> <input_shape> <output_onnx_file>
Replace the placeholders with the appropriate values:
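- <network_config_file>: Path to the pickle file containing the network configuration.
- <state_dict_file>: Path to the file containing the model's state dictionary.
- <input_shape>: Size of the model's input tensor (an integer).
- <output_onnx_file>: Path for the output ONNX file.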
For example:
python onnx/transform_onnx.py '0d9f0e94175b11eca9e93cecef057438_lca_no_bias_4_torch__network_config.pickle' '0d9f0e94175b11eca9e93cecef057438_lca_no_bias_4_torch_state_dict.pt' 11 'lca_no_bias_4_torch.onnx'
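If the network config and the state dict were saved into the same folder, as in the tutorial above (data/torch_models/angle/), the converter's arguments can be assembled programmatically. The sketch below is a hypothetical helper, not part of LANfactory: it assumes exactly one file matching `*_network_config.pickle` and one containing `state_dict` in the folder, mirroring the file names in the example above. For the 'angle' model the input size is 7 (five parameters plus rt and choice), matching the inference example.

```python
import shlex
import tempfile
from pathlib import Path
from typing import List


def build_onnx_command(model_folder: Path, input_shape: int, output_onnx: str) -> List[str]:
    """Assemble the transform_onnx.py call from files in model_folder (hypothetical)."""
    configs = sorted(model_folder.glob("*_network_config.pickle"))
    state_dicts = sorted(model_folder.glob("*state_dict*"))
    if len(configs) != 1 or len(state_dicts) != 1:
        raise FileNotFoundError("expected exactly one network config and one state dict")
    return ["python", "onnx/transform_onnx.py", str(configs[0]),
            str(state_dicts[0]), str(input_shape), output_onnx]


with tempfile.TemporaryDirectory() as tmp:
    folder = Path(tmp)
    # Empty placeholder files named after the pattern in the example above.
    (folder / "abc123_angle_torch__network_config.pickle").touch()
    (folder / "abc123_angle_torch_state_dict.pt").touch()
    # input_shape 7: five 'angle' parameters plus rt and choice.
    cmd = build_onnx_command(folder, input_shape=7, output_onnx="angle_torch.onnx")
    print(shlex.join(cmd))
```

The resulting list can be passed to `subprocess.run(cmd, check=True)` from the repository root.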
This onnx file can be used directly with the HSSM package.
We hope this package may be helpful in case you attempt to train LANs for your own research.
END
Project details
Hashes for lanfactory-0.4.6.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | f46bd0d76308d62643a10d2d3b9bb82d70c2de4b020b54404f866450ec52c180 |
| MD5 | b01812f5c98e594197928c8d305ca615 |
| BLAKE2b-256 | deaec9747ec019f4015137279c462235a3556ceafd79f19755673dcd70528298 |