Package with convenience functions to train LANs
LANfactory
Lightweight Python package to help with training LANs (likelihood approximation networks).
Please find the original documentation here.
Quick Start
The LANfactory package is a lightweight convenience package for training likelihood approximation networks (LANs) in PyTorch (or Keras), starting from supplied training data.
LANs, although more general in their potential applications, were conceived in the context of sequential sampling modeling, to account for the cognitive processes that give rise to choice and reaction-time data in n-alternative forced-choice experiments, which are common in the cognitive sciences.
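The core idea can be sketched in plain Python: a LAN is a feedforward network that takes a model's parameters together with one observed (reaction time, choice) pair and returns an approximate log-likelihood. The sketch below assumes tanh hidden layers and a single linear output unit; the weights are arbitrary placeholders rather than a trained network, and the function name `lan_forward` is hypothetical, not part of the LANfactory API.

```python
import math
from typing import List, Sequence, Tuple

Matrix = List[List[float]]
Layer = Tuple[Matrix, List[float]]  # (weights, bias) of one dense layer


def dense(x: Sequence[float], weights: Matrix, bias: Sequence[float]) -> List[float]:
    """Affine layer: one output per row of `weights`."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]


def lan_forward(theta: Sequence[float], rt: float, choice: float,
                layers: List[Layer]) -> float:
    """Hypothetical LAN forward pass returning an approximate log-likelihood.

    The input vector is the parameter vector theta followed by (rt, choice).
    Hidden layers use tanh; the final layer is linear with one output.
    """
    h = list(theta) + [rt, choice]
    for i, (weights, bias) in enumerate(layers):
        h = dense(h, weights, bias)
        if i < len(layers) - 1:  # no nonlinearity on the output layer
            h = [math.tanh(v) for v in h]
    return h[0]


# Toy network for a model with 4 parameters (v, a, z, t): input size 4 + 2 = 6.
layers: List[Layer] = [
    ([[0.1] * 6, [-0.2] * 6], [0.0, 0.1]),  # 6 inputs -> 2 hidden units
    ([[1.0, -1.0]], [-0.5]),                 # 2 hidden units -> 1 output
]
log_lik = lan_forward([1.0, 1.5, 0.5, 0.3], rt=0.8, choice=1.0, layers=layers)
print(log_lik)
```

Training a LAN amounts to fitting such weights so that this output matches log-likelihood targets derived from simulated data.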
For a basic tutorial on how to use the LANfactory package, please refer to the basic tutorial notebook.
In this quick tutorial we use the ssms package to generate training data from such a sequential sampling model (SSM). Using LANfactory does not require ssms, however; training data can come from any source.
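To illustrate what such training data look like, here is a minimal, stand-alone simulator for the basic drift diffusion model (DDM), the simplest SSM. It is not the ssms API; it is a plain Euler-Maruyama sketch, assuming drift `v`, boundary separation `a`, relative starting point `z`, non-decision time `t`, and unit noise. It returns (rt, choice) pairs with choices coded as -1 (lower boundary) and 1 (upper boundary).

```python
import math
import random
from typing import List, Tuple


def simulate_ddm(v: float, a: float, z: float, t: float, n_samples: int,
                 dt: float = 0.001, max_t: float = 20.0,
                 seed: int = 0) -> List[Tuple[float, int]]:
    """Simulate (rt, choice) pairs from a basic DDM (illustrative sketch).

    The accumulator starts at z * a and diffuses with drift v and unit noise
    until it crosses 0 (choice -1) or a (choice 1). The reported rt adds
    the non-decision time t to the decision time.
    """
    rng = random.Random(seed)
    sqrt_dt = math.sqrt(dt)
    samples = []
    for _ in range(n_samples):
        x = z * a
        elapsed = 0.0
        while 0.0 < x < a and elapsed < max_t:
            x += v * dt + sqrt_dt * rng.gauss(0.0, 1.0)
            elapsed += dt
        if x >= a:
            choice = 1
        elif x <= 0.0:
            choice = -1
        else:
            choice = 0  # timed out at max_t without reaching a boundary
        samples.append((elapsed + t, choice))
    return samples


data = simulate_ddm(v=1.0, a=1.5, z=0.5, t=0.3, n_samples=200)
print(data[:3])
```

With a positive drift, most simulated trials end at the upper boundary, and every reaction time exceeds the non-decision time `t`.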
Install
To install the ssms package, type:
pip install ssm-simulators
To install the LANfactory package, type:
pip install lanfactory
Necessary dependencies should be installed automatically in the process.
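As an optional sanity check after installing, the standard library's `importlib.metadata` can report which versions were installed. The distribution names below are taken from the pip commands; the helper `installed_versions` is a hypothetical convenience, not part of either package.

```python
from importlib import metadata
from typing import Dict, Optional


def installed_versions(*dists: str) -> Dict[str, Optional[str]]:
    """Return the installed version of each distribution, or None if missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in dists:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


print(installed_versions("ssm-simulators", "lanfactory"))
```

A `None` entry means the corresponding `pip install` did not succeed in the active environment.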
Basic Tutorial
Check the basic tutorial here.
TorchMLP to ONNX Converter
Once you have trained your model, you can convert it to the ONNX format using the transform_onnx.py script.
The transform_onnx.py script converts a TorchMLP model to the ONNX format. It takes a network configuration file (in pickle format), a state dictionary file (Torch model weights), the size of the input tensor, and the desired output ONNX file path.
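The four inputs map naturally onto positional command-line arguments. The argparse front end below is a hypothetical sketch with the same shape; it is not the actual contents of transform_onnx.py, only an illustration of how each argument is typed (the input size is parsed as an integer, the rest as paths).

```python
import argparse
from typing import List, Optional


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the script's four positional arguments."""
    parser = argparse.ArgumentParser(
        description="Convert a TorchMLP model to ONNX (illustrative sketch)."
    )
    parser.add_argument("network_config_file",
                        help="pickle file with the network configuration")
    parser.add_argument("state_dict_file",
                        help="file with the model's state dictionary (weights)")
    parser.add_argument("input_shape", type=int,
                        help="size of the model's input tensor")
    parser.add_argument("output_onnx_file",
                        help="path of the ONNX file to write")
    return parser


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse argv (or sys.argv[1:] when argv is None)."""
    return build_parser().parse_args(argv)


args = parse_args(["config.pickle", "state_dict.pt", "11", "model.onnx"])
print(args)
```

A non-integer input size is rejected by argparse with a usage error before any conversion work starts.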
Usage
python onnx/transform_onnx.py <network_config_file> <state_dict_file> <input_shape> <output_onnx_file>
Replace the placeholders with the appropriate values:
- <network_config_file>: Path to the pickle file containing the network configuration.
- <state_dict_file>: Path to the file containing the state dictionary of the model.
- <input_shape>: The size of the input tensor for the model (integer).
- <output_onnx_file>: Path to the output ONNX file.
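For a LAN, the input tensor typically holds the model parameters followed by one reaction time and one choice, so `<input_shape>` is the number of model parameters plus 2. The sketch below assumes that layout, and it assumes that the `lca_no_bias_4` model in ssms has nine parameters (four drift rates plus `a`, `z`, `g`, `b`, `t`); check your model's configuration in ssms before relying on these names. The helper `lan_input_size` is hypothetical.

```python
from typing import Sequence


def lan_input_size(param_names: Sequence[str], n_data_columns: int = 2) -> int:
    """Input size of a LAN: one entry per model parameter plus the data columns.

    Assumes the data columns are reaction time and choice (2 columns).
    """
    return len(param_names) + n_data_columns


# Assumed parameter names for ssms' "lca_no_bias_4" model (verify in ssms).
lca_no_bias_4_params = ["v0", "v1", "v2", "v3", "a", "z", "g", "b", "t"]
print(lan_input_size(lca_no_bias_4_params))  # 11
```

Under these assumptions the result, 11, matches the input size used in the example command.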
For example:
python onnx/transform_onnx.py '0d9f0e94175b11eca9e93cecef057438_lca_no_bias_4_torch__network_config.pickle' '0d9f0e94175b11eca9e93cecef057438_lca_no_bias_4_torch_state_dict.pt' 11 'lca_no_bias_4_torch.onnx'
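The file names in this example appear to follow a pattern: a run identifier, the model name, and a `torch` tag, with the configuration ending in `_torch__network_config.pickle` (note the double underscore) and the weights ending in `_torch_state_dict.pt`. Assuming that pattern holds for your own training runs, a small helper can assemble the command; `onnx_command` is a hypothetical name, and the pattern is inferred from this single example.

```python
import shlex
from typing import List


def onnx_command(run_id: str, model: str, input_shape: int,
                 script: str = "onnx/transform_onnx.py") -> List[str]:
    """Build the transform_onnx.py argument list for one trained network.

    Assumes the file-naming pattern seen in the example command.
    """
    prefix = f"{run_id}_{model}_torch"
    return [
        "python",
        script,
        f"{prefix}__network_config.pickle",  # double underscore, as in the example
        f"{prefix}_state_dict.pt",
        str(input_shape),
        f"{model}_torch.onnx",
    ]


cmd = onnx_command("0d9f0e94175b11eca9e93cecef057438", "lca_no_bias_4", 11)
print(shlex.join(cmd))
```

To execute the command rather than print it, pass `cmd` to `subprocess.run(cmd, check=True)`; a list avoids shell-quoting issues with unusual file names.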
The resulting ONNX file can be used directly with the HSSM package.
We hope this package proves helpful if you train LANs for your own research.
Source Distribution
File details
Details for the file lanfactory-0.5.1.tar.gz.
File metadata
- Download URL: lanfactory-0.5.1.tar.gz
- Upload date:
- Size: 610.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 2b11cdb218d1cd9d00c369ab7df48c9c79113fd5a6ca69cee42add445f2d7df4 |
| MD5 | 48335cb7c6e98a04c771cd226d5c093c |
| BLAKE2b-256 | 50c336d3a191c5518d9a773aa76026b1b66fc08410e1ad7936d1ed7d98e73a6d |
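To check a downloaded copy of lanfactory-0.5.1.tar.gz against the published SHA256 digest, the standard library's `hashlib` is enough (`pip hash <file>` reports the same digest). The helper names `sha256_of` and `verify_download` are hypothetical; the file is read in chunks so large archives do not need to fit in memory.

```python
import hashlib
from pathlib import Path

# Published SHA256 digest for lanfactory-0.5.1.tar.gz (from the table).
EXPECTED_SHA256 = "2b11cdb218d1cd9d00c369ab7df48c9c79113fd5a6ca69cee42add445f2d7df4"


def sha256_of(path: str, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, read in fixed-size chunks."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path: str, expected: str = EXPECTED_SHA256) -> bool:
    """True if the file's SHA256 digest matches the expected hex digest."""
    return sha256_of(path) == expected.lower()


# Example: verify_download("lanfactory-0.5.1.tar.gz")
```

A `False` result means the archive differs from the published release and should not be installed.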