Package with convenience functions to train LANs

LANfactory

Code style: black. License: MIT.

Lightweight python package to help with training LANs (Likelihood approximation networks).

Please find the original documentation here.

Quick Start

The LANfactory package is a lightweight convenience package for training likelihood approximation networks (LANs) in torch (or keras), starting from supplied training data.

LANs, although more general in their potential scope of application, were conceived in the context of sequential sampling modeling, to account for the cognitive processes that give rise to choice and reaction time data in n-alternative forced choice experiments commonly encountered in the cognitive sciences.

In this quick tutorial we will use the ssms package to generate our training data using such a sequential sampling model (SSM). Using LANfactory does not require the ssms package; you can supply training data from any other source.

Install

To install the ssms package type,

pip install git+https://github.com/AlexanderFengler/ssm_simulators

To install the LANfactory package type,

pip install git+https://github.com/AlexanderFengler/LANfactory

Necessary dependencies should be installed automatically in the process.
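
A quick way to confirm that both installs succeeded is to check that the modules can be located. The sketch below uses only the standard library. The helper name check_installed is hypothetical, and the module names ssms and lanfactory are taken from the import statements used in this tutorial.

```python
import importlib.util


def check_installed(module_names):
    """Return {module_name: bool}, True when the top-level module can be located."""
    return {name: importlib.util.find_spec(name) is not None for name in module_names}


if __name__ == "__main__":
    for name, found in check_installed(["ssms", "lanfactory"]).items():
        print(f"{name}: {'found' if found else 'missing'}")
```

If either module is reported as missing, rerun the corresponding pip command in the same Python environment.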

Basic Tutorial

# Load necessary packages
import ssms
import lanfactory 
import os
import numpy as np
from copy import deepcopy
import torch

Generate Training Data

First we need to generate some training data. As mentioned above, we will do so using the ssms python package, without delving into a detailed explanation of it. Please refer to the basic ssms tutorial (https://github.com/AlexanderFengler/ssm_simulators) in case you want to learn more.

# MAKE CONFIGS

# Initialize the generator config (for MLP LANs)
generator_config = deepcopy(ssms.config.data_generator_config['lan']['mlp'])
# Specify generative model (one of the models included in ssms)
generator_config['model'] = 'angle' 
# Specify number of parameter sets to simulate
generator_config['n_parameter_sets'] = 100 
# Specify how many samples a simulation run should entail
generator_config['n_samples'] = 1000
# Specify folder in which to save generated data
generator_config['output_folder'] = 'data/lan_mlp/'

# Make model config dict
model_config = ssms.config.model_config['angle']
# MAKE DATA

my_dataset_generator = ssms.dataset_generators.data_generator(generator_config = generator_config,
                                                              model_config = model_config)

training_data = my_dataset_generator.generate_data_training_uniform(save = True)
n_cpus used:  6
checking:  data/lan_mlp/
simulation round: 1  of 10
simulation round: 2  of 10
simulation round: 3  of 10
simulation round: 4  of 10
simulation round: 5  of 10
simulation round: 6  of 10
simulation round: 7  of 10
simulation round: 8  of 10
simulation round: 9  of 10
simulation round: 10  of 10
Writing to file:  data/lan_mlp/training_data_0_nbins_0_n_1000/angle/training_data_angle_ef5b9e0eb76c11eca684acde48001122.pickle
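
Before training, it can help to look at what the generator wrote to disk. The internal layout of the pickle file is not described in this tutorial, so the sketch below assumes nothing about key names: it only reports the type and, where available, the shape of each top-level entry. The helper names describe and summarize_pickle are hypothetical.

```python
import pickle


def describe(value):
    """Return (type name, shape); shape is None for objects without a .shape attribute."""
    return type(value).__name__, getattr(value, "shape", None)


def summarize_pickle(path):
    """Load a pickle file and describe its top-level entries without assuming a layout."""
    with open(path, "rb") as handle:
        loaded = pickle.load(handle)
    if isinstance(loaded, dict):
        return {key: describe(value) for key, value in loaded.items()}
    return {"<root>": describe(loaded)}
```

Calling summarize_pickle on the path printed after "Writing to file:" lists the stored entries. Only unpickle files that you generated yourself or otherwise trust.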

Prepare for Training

Next we set up dataloaders for training with pytorch. LANfactory uses custom dataloaders that take into account particularities of the expected training data. Specifically, we expect to receive a number of training data files (the present example generates only one), where each file hosts a large number of training examples. So we want a dataloader that serves batches from one training data file at a time and keeps checking when to load the next file. This is implemented via the DatasetTorch class in lanfactory.trainers, which inherits from torch.utils.data.Dataset and prespecifies a batch_size. Finally this is supplied to a DataLoader, for which we set the batch_size argument to None, so that the DataLoader does not batch a second time.

The DatasetTorch class is then called as an iterator via the DataLoader and takes care of batching as well as file loading internally.

You may define the DataLoader in any other way; downstream you are simply expected to supply one.
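
To make the idea concrete, here is a minimal pure-Python sketch of a dataset whose items are whole batches and which reads a new file only when a requested batch lives in a different file. It is not the DatasetTorch implementation: the class name FileBatchDataset, the toy file contents and the assumption that all files hold the same number of rows are made up for illustration. A torch DataLoader created with batch_size = None disables automatic batching and returns each dataset item unchanged, which is why a dataset of this shape can do the batching itself.

```python
import pickle
import tempfile
from pathlib import Path


class FileBatchDataset:
    """Map-style dataset whose items are whole batches read from pickle files."""

    def __init__(self, file_paths, batch_size):
        self.file_paths = list(file_paths)
        self.batch_size = batch_size
        self._loaded_index = None
        self._rows = None
        self._load(0)
        # Assumption: every file holds the same number of rows as the first one.
        self.batches_per_file = len(self._rows) // batch_size

    def _load(self, file_index):
        """Read a file into memory only if it is not the one currently loaded."""
        if file_index != self._loaded_index:
            with open(self.file_paths[file_index], "rb") as handle:
                self._rows = pickle.load(handle)
            self._loaded_index = file_index

    def __len__(self):
        # Length is counted in batches, not in single training examples.
        return self.batches_per_file * len(self.file_paths)

    def __getitem__(self, batch_index):
        if not 0 <= batch_index < len(self):
            raise IndexError(batch_index)
        file_index, local_batch = divmod(batch_index, self.batches_per_file)
        self._load(file_index)
        start = local_batch * self.batch_size
        return self._rows[start:start + self.batch_size]


# Toy data: two files with ten rows each, so batch_size = 4 gives two full batches per file.
tmp_dir = Path(tempfile.mkdtemp())
paths = []
for file_number in range(2):
    path = tmp_dir / f"toy_{file_number}.pickle"
    with open(path, "wb") as handle:
        pickle.dump([(file_number, row) for row in range(10)], handle)
    paths.append(path)

dataset = FileBatchDataset(paths, batch_size=4)
print(len(dataset))  # number of batches across both files
print(dataset[2])    # first batch of the second file
```

In this sketch, rows that do not fill a complete batch at the end of a file are dropped, and randomly ordered batch indices can trigger frequent file reloads. A production dataset would need to handle both points more carefully.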

# MAKE DATALOADERS

# List of datafiles (here only one)
folder_ = 'data/lan_mlp/training_data_0_nbins_0_n_1000/angle/'
file_list_ = [folder_ + file_ for file_ in os.listdir(folder_)]

# Training dataset
torch_training_dataset = lanfactory.trainers.DatasetTorch(file_IDs = file_list_,
                                                          batch_size = 128)

torch_training_dataloader = torch.utils.data.DataLoader(torch_training_dataset,
                                                         shuffle = True,
                                                         batch_size = None,
                                                         num_workers = 1,
                                                         pin_memory = True)

# Validation dataset
torch_validation_dataset = lanfactory.trainers.DatasetTorch(file_IDs = file_list_,
                                                          batch_size = 128)

torch_validation_dataloader = torch.utils.data.DataLoader(torch_validation_dataset,
                                                          shuffle = True,
                                                          batch_size = None,
                                                          num_workers = 1,
                                                          pin_memory = True)

Now we define two configuration dictionaries:

  1. The network_config dictionary defines the architecture and properties of the network
  2. The train_config dictionary defines properties concerning training hyperparameters

Two examples (which we take as provided by the package, but which you can adjust according to your needs) are provided below.

# SPECIFY NETWORK CONFIGS AND TRAINING CONFIGS

network_config = lanfactory.config.network_configs.network_config_mlp

print('Network config: ')
print(network_config)

train_config = lanfactory.config.network_configs.train_config_mlp

print('Train config: ')
print(train_config)
Network config: 
{'layer_types': ['dense', 'dense', 'dense'], 'layer_sizes': [100, 100, 1], 'activations': ['tanh', 'tanh', 'linear'], 'loss': ['huber'], 'callbacks': ['checkpoint', 'earlystopping', 'reducelr']}
Train config: 
{'batch_size': 128, 'n_epochs': 10, 'optimizer': 'adam', 'learning_rate': 0.002, 'loss': 'huber', 'save_history': True, 'metrics': [<keras.losses.MeanSquaredError object at 0x12c403d30>, <keras.losses.Huber object at 0x12c1c78e0>], 'callbacks': ['checkpoint', 'earlystopping', 'reducelr']}
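
If you want a different architecture, one option is to copy the default dictionary and change the per-layer lists together. The sketch below starts from the network config printed above and builds a wider, deeper variant. The helper check_layer_spec is hypothetical, and the requirement that the three per-layer lists have equal length is an assumption based on the structure of the default config, not a documented rule.

```python
from copy import deepcopy

# Default MLP network config, copied from the printout above.
network_config = {
    "layer_types": ["dense", "dense", "dense"],
    "layer_sizes": [100, 100, 1],
    "activations": ["tanh", "tanh", "linear"],
    "loss": ["huber"],
    "callbacks": ["checkpoint", "earlystopping", "reducelr"],
}


def check_layer_spec(config):
    """Raise if the per-layer lists disagree in length (assumed requirement)."""
    lengths = {len(config[key]) for key in ("layer_types", "layer_sizes", "activations")}
    if len(lengths) != 1:
        raise ValueError("layer_types, layer_sizes and activations must have equal length")


# Work on a deep copy so the package default stays untouched.
custom_config = deepcopy(network_config)
custom_config["layer_types"] = ["dense"] * 4
custom_config["layer_sizes"] = [200, 200, 200, 1]
custom_config["activations"] = ["tanh", "tanh", "tanh", "linear"]
check_layer_spec(custom_config)
```

The final layer keeps size 1 and a linear activation, matching the default, since the network outputs a single value per input row.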

We can now load a network, and save the configuration files for convenience.

# LOAD NETWORK
net = lanfactory.trainers.TorchMLP(network_config = deepcopy(network_config),
                                   input_shape = torch_training_dataset.input_dim,
                                   save_folder = '/data/torch_models/',
                                   generative_model_id = 'angle')

# SAVE CONFIGS
lanfactory.utils.save_configs(model_id = net.model_id + '_torch_',
                              save_folder = 'data/torch_models/angle/', 
                              network_config = network_config, 
                              train_config = train_config, 
                              allow_abs_path_folder_generation = True)

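To finally train the network we supply our network, the dataloaders and training config to the ModelTrainerTorchMLP class, from lanfactory.trainers.

# TRAIN MODEL
# Assumption: the keyword names below are not confirmed by this tutorial;
# check them against your installed lanfactory version.
model_trainer = lanfactory.trainers.ModelTrainerTorchMLP(
    train_config = train_config, model = net, output_folder = 'data/torch_models/angle/',
    data_loader_train = torch_training_dataloader, data_loader_valid = torch_validation_dataloader)
model_trainer.train_model(save_history = True, save_model = True, verbose = 0)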
Epoch took 0 / 10,  took 11.54538607597351 seconds
epoch 0 / 10, validation_loss: 0.3431
Epoch took 1 / 10,  took 13.032279014587402 seconds
epoch 1 / 10, validation_loss: 0.2732
Epoch took 2 / 10,  took 12.421074867248535 seconds
epoch 2 / 10, validation_loss: 0.1941
Epoch took 3 / 10,  took 12.097641229629517 seconds
epoch 3 / 10, validation_loss: 0.2028
Epoch took 4 / 10,  took 12.030233144760132 seconds
epoch 4 / 10, validation_loss: 0.184
Epoch took 5 / 10,  took 12.695374011993408 seconds
epoch 5 / 10, validation_loss: 0.1433
Epoch took 6 / 10,  took 12.177874326705933 seconds
epoch 6 / 10, validation_loss: 0.1115
Epoch took 7 / 10,  took 11.908828258514404 seconds
epoch 7 / 10, validation_loss: 0.1084
Epoch took 8 / 10,  took 12.066670179367065 seconds
epoch 8 / 10, validation_loss: 0.0864
Epoch took 9 / 10,  took 12.37562108039856 seconds
epoch 9 / 10, validation_loss: 0.07484
Saving training history
Saving model state dict
Training finished successfully...

Load Model for Inference and Call

LANfactory provides some convenience functions to use networks for inference after training. We can load a model using the LoadTorchMLPInfer class, which then allows us to run fast inference in one of two ways: a direct call, which expects a torch.tensor as input, or the predict_on_batch method, which expects a numpy.array of dtype np.float32.

network_path_list = os.listdir('data/torch_models/angle')
network_file_path = ['data/torch_models/angle/' + file_ for file_ in network_path_list if 'state_dict' in file_][0]

network = lanfactory.trainers.LoadTorchMLPInfer(model_file_path = network_file_path,
                                                network_config = network_config,
                                                input_dim = torch_training_dataset.input_dim)
# Two ways to call the network

# Direct call --> need tensor input
direct_out = network(torch.from_numpy(np.array([1, 1.5, 0.5, 1.0, 0.1, 0.65, 1], dtype  = np.float32)))
print('direct call out: ', direct_out)

# predict_on_batch method
predict_on_batch_out = network.predict_on_batch(np.array([1, 1.5, 0.5, 1.0, 0.1, 0.65, 1], dtype  = np.float32))
print('predict_on_batch out: ', predict_on_batch_out)
direct call out:  tensor([-16.4997])
predict_on_batch out:  [-16.499687]
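
The seven numbers passed to the network are easier to get right when they are built from named values. The sketch below assumes the column order v, a, z, t, theta, rt, choice, which is the order used for the DataFrame in the plotting example of this tutorial. The helper build_lan_rows is hypothetical and uses only plain Python lists.

```python
# Column order assumed from the DataFrame used in the plotting example of this tutorial.
ANGLE_COLUMNS = ["v", "a", "z", "t", "theta", "rt", "choice"]


def build_lan_rows(parameters, rts, choices):
    """Return one [v, a, z, t, theta, rt, choice] row per trial."""
    if len(rts) != len(choices):
        raise ValueError("rts and choices must have the same length")
    parameter_part = [float(parameters[name]) for name in ANGLE_COLUMNS[:5]]
    return [parameter_part + [float(rt), float(choice)] for rt, choice in zip(rts, choices)]


rows = build_lan_rows(
    {"v": 1.0, "a": 1.5, "z": 0.5, "t": 1.0, "theta": 0.1},
    rts=[0.65, 1.2],
    choices=[1, -1],
)
print(rows[0])
```

Converting the result with np.array(rows, dtype = np.float32) gives a two-dimensional array of the kind passed to predict_on_batch elsewhere in this tutorial. The first row reproduces the input used in the two calls above.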

A peek into the first passage distribution computed by the network

We can compare the learned likelihood function in our network with simulation data from the underlying generative model. For this purpose we recruit the ssms package again.

import pandas as pd
import matplotlib.pyplot as plt

data = pd.DataFrame(np.zeros((2000, 7), dtype = np.float32), columns = ['v', 'a', 'z', 't', 'theta', 'rt', 'choice'])
data['v'] = 0.5
data['a'] = 0.75
data['z'] = 0.5
data['t'] = 0.2
data['theta'] = 0.1
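# Assign via .loc to avoid chained indexing; cast to float32 to match the column dtype
data.loc[:999, 'rt'] = np.linspace(5, 0, 1000, dtype = np.float32)
data.loc[1000:, 'rt'] = np.linspace(0, 5, 1000, dtype = np.float32)
data.loc[:999, 'choice'] = -1
data.loc[1000:, 'choice'] = 1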

# Network predictions
predict_on_batch_out = network.predict_on_batch(data.values.astype(np.float32))

# Simulations
from ssms.basic_simulators import simulator
sim_out = simulator(model = 'angle', 
                    theta = data.values[0, :-2],
                    n_samples = 2000)
# Plot network predictions
plt.plot(data['rt'] * data['choice'], np.exp(predict_on_batch_out), color = 'black', label = 'network')

# Plot simulations
plt.hist(sim_out['rts'] * sim_out['choices'], bins = 30, histtype = 'step', label = 'simulations', color = 'blue', density  = True)
plt.legend()
plt.title('SSM likelihood')
plt.xlabel('rt')
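plt.ylabel('likelihood')
Text(0, 0.5, 'likelihood')

[Figure: network likelihood (black line) and histogram of simulated data (blue), plotted over signed reaction time]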
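
A simple numerical check complements the visual comparison: a likelihood over signed reaction times should integrate to roughly one, provided the grid covers most of the probability mass. The sketch below shows the check with the trapezoidal rule. Because the trained network is not available here, stand_in_log_density is a placeholder (the log-density of a standard normal) and trapezoid_area is a hypothetical helper. With a real network you would exponentiate the predict_on_batch output instead.

```python
import math


def trapezoid_area(xs, ys):
    """Integrate ys over xs with the trapezoidal rule (xs sorted in ascending order)."""
    return sum(
        0.5 * (ys[i] + ys[i + 1]) * (xs[i + 1] - xs[i])
        for i in range(len(xs) - 1)
    )


def stand_in_log_density(x):
    """Placeholder for the network output: log-density of a standard normal."""
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)


# Signed reaction-time grid on [-5, 5], mirroring the range used in the plot.
grid = [-5.0 + 10.0 * i / 1999 for i in range(2000)]
density = [math.exp(stand_in_log_density(x)) for x in grid]
area = trapezoid_area(grid, density)
print(area)
```

For a trained LAN the result is only approximately one, since the network is an approximation and some probability mass may lie beyond the largest reaction time on the grid.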

TorchMLP to ONNX Converter

The transform_onnx.py script converts a TorchMLP model to the ONNX format. It takes a network configuration file (in pickle format), a state dictionary file (Torch model weights), the size of the input tensor, and the desired output ONNX file path.

Usage

python onnx/transform_onnx.py <network_config_file> <state_dict_file> <input_shape> <output_onnx_file>

Replace the placeholders with the appropriate values:

  • <network_config_file>: Path to the pickle file containing the network configuration.
  • <state_dict_file>: Path to the file containing the state dictionary of the model.
  • <input_shape>: The size of the input tensor for the model (integer).
  • <output_onnx_file>: Path to the output ONNX file.

For example:

python onnx/transform_onnx.py '0d9f0e94175b11eca9e93cecef057438_lca_no_bias_4_torch__network_config.pickle' '0d9f0e94175b11eca9e93cecef057438_lca_no_bias_4_torch_state_dict.pt' 11 'lca_no_bias_4_torch.onnx'

The resulting ONNX file can be used directly with the HSSM package.
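
When converting several models, it can be convenient to assemble the command from Python. The sketch below builds the argument list in the order documented above and checks that the input size is a positive integer. The helper build_transform_command is hypothetical; the script path and file names are the ones from the example.

```python
import sys


def build_transform_command(network_config_file, state_dict_file, input_shape, output_onnx_file):
    """Return the argument list for onnx/transform_onnx.py in the documented order."""
    if isinstance(input_shape, bool) or not isinstance(input_shape, int) or input_shape <= 0:
        raise ValueError("input_shape must be a positive integer")
    return [
        sys.executable,
        "onnx/transform_onnx.py",
        network_config_file,
        state_dict_file,
        str(input_shape),
        output_onnx_file,
    ]


command = build_transform_command(
    "0d9f0e94175b11eca9e93cecef057438_lca_no_bias_4_torch__network_config.pickle",
    "0d9f0e94175b11eca9e93cecef057438_lca_no_bias_4_torch_state_dict.pt",
    11,
    "lca_no_bias_4_torch.onnx",
)
print(" ".join(command[1:]))
# To run the conversion from the repository root: subprocess.run(command, check=True)
```

Passing the arguments as a list avoids shell quoting issues with long file names.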

We hope this package may be helpful in case you attempt to train LANs for your own research.
