
Package with convenience functions to train LANs


LANfactory

License: MIT. Code style: black.

Lightweight Python package to help with training LANs (likelihood approximation networks).

Please find the original documentation here.

Quick Start

The LANfactory package is a lightweight convenience package for training likelihood approximation networks (LANs) in PyTorch (or Keras), starting from supplied training data.

LANs, although potentially applicable more broadly, were conceived in the context of sequential sampling modeling, to account for the cognitive processes that give rise to choice and reaction time data in n-alternative forced-choice experiments, which are common in the cognitive sciences.
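As a rough illustration of what a LAN computes, the sketch below treats it as a small multilayer perceptron that maps a parameter vector plus a single (rt, choice) observation to one log-likelihood value. The layer sizes, tanh activations, feature layout, and the helper names `init_mlp` and `lan_forward` are hypothetical choices for illustration, not LANfactory's API. The weights are random, so the output carries no meaning until the network is trained.

```python
import math
import random
from typing import List, Sequence


def init_mlp(layer_sizes: Sequence[int], seed: int = 0) -> List[dict]:
    """Randomly initialise a fully connected network with the given layer sizes."""
    rng = random.Random(seed)
    layers = []
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        # Gaussian weights scaled by 1/sqrt(fan-in); biases start at zero.
        weights = [
            [rng.gauss(0.0, 1.0 / math.sqrt(n_in)) for _ in range(n_in)]
            for _ in range(n_out)
        ]
        layers.append({"W": weights, "b": [0.0] * n_out})
    return layers


def lan_forward(layers: List[dict], features: Sequence[float]) -> float:
    """Map one [theta..., rt, choice] vector to a single log-likelihood value."""
    if len(features) != len(layers[0]["W"][0]):
        raise ValueError("feature vector length does not match the input layer")
    h = list(features)
    for i, layer in enumerate(layers):
        pre = [sum(w * x for w, x in zip(row, h)) + b for row, b in zip(layer["W"], layer["b"])]
        # tanh on hidden layers; the output layer stays linear because a
        # log-likelihood is not bounded to a fixed range.
        h = pre if i == len(layers) - 1 else [math.tanh(u) for u in pre]
    return h[0]


theta = [1.0, 1.5, 0.5, 0.3]  # hypothetical DDM parameters: v, a, z, t
x = theta + [0.8, 1.0]        # append one observation: rt = 0.8 s, choice = 1
net = init_mlp([len(x), 32, 32, 1])
print(lan_forward(net, x))    # untrained, so the value itself is meaningless
```

Training would then adjust the weights so that `lan_forward` matches target log-likelihoods computed from simulated data.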

For a basic tutorial on how to use the LANfactory package, please refer to the basic tutorial notebook.

In this quick tutorial, we use the ssms package to generate training data from such a sequential sampling model (SSM). LANfactory itself is in no way tied to ssms, however; training data can come from any source.
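To make "training data from an SSM" concrete without relying on the ssms API, here is a plain-Python sketch. It simulates a basic drift diffusion model (DDM) with Euler-Maruyama steps and turns the simulations into (features, log-likelihood) pairs using a simple Gaussian kernel density estimate. The function names, the parameter layout [v, a, z, t], the bandwidth rule, and the choice of evaluation points are all assumptions for illustration; ssms ships its own simulators and training-data generators, which are likely more refined.

```python
import math
import random
from typing import List, Tuple


def simulate_ddm(v, a, z, t, n_trials=1000, dt=0.001, max_t=10.0, seed=0):
    """Simulate (rt, choice) pairs from a basic drift diffusion model.

    Euler-Maruyama with unit noise: the evidence starts at z * a and moves
    until it crosses the upper bound a (choice 1) or 0 (choice -1).
    Trials with no crossing before max_t are dropped.
    """
    rng = random.Random(seed)
    sqrt_dt = math.sqrt(dt)
    data: List[Tuple[float, int]] = []
    for _ in range(n_trials):
        x, elapsed = z * a, 0.0
        while 0.0 < x < a and elapsed < max_t:
            x += v * dt + sqrt_dt * rng.gauss(0.0, 1.0)
            elapsed += dt
        if x >= a:
            data.append((elapsed + t, 1))
        elif x <= 0.0:
            data.append((elapsed + t, -1))
    return data


def kde_log_likelihood(data, rt, choice, floor=1e-29):
    """log[P(choice) * f(rt | choice)] from a Gaussian KDE over simulated rts."""
    rts = [r for r, c in data if c == choice]
    n = len(rts)
    if n < 2:
        return math.log(floor)  # too few samples for this choice branch
    mean = sum(rts) / n
    sd = math.sqrt(sum((r - mean) ** 2 for r in rts) / (n - 1))
    h = max(1.06 * sd * n ** -0.2, 1e-3)  # normal-reference bandwidth
    norm = n * h * math.sqrt(2.0 * math.pi)
    density = sum(math.exp(-0.5 * ((rt - r) / h) ** 2) for r in rts) / norm
    return math.log(max((n / len(data)) * density, floor))


def make_training_rows(theta, n_eval=10, n_trials=1000, seed=0):
    """Return ([v, a, z, t, rt, choice], log-likelihood) pairs for one theta."""
    data = simulate_ddm(*theta, n_trials=n_trials, seed=seed)
    rng = random.Random(seed + 1)
    points = rng.sample(data, min(n_eval, len(data)))
    return [
        (list(theta) + [rt, choice], kde_log_likelihood(data, rt, choice))
        for rt, choice in points
    ]


for features, label in make_training_rows((1.0, 1.5, 0.5, 0.3), n_eval=3):
    print(features, round(label, 3))
```

Repeating this over many sampled parameter vectors yields a table of feature rows and labels of the kind a LAN is trained on.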

Install

To install the ssms package, type:

pip install ssm-simulators

To install the LANfactory package, type:

pip install lanfactory

Necessary dependencies should be installed automatically in the process.
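If you want to confirm that both installs succeeded, a small standard-library check like the one below can report the installed versions. It looks up the pip distribution names used above (`lanfactory`, `ssm-simulators`) rather than import names; the helper `installed_versions` is a hypothetical convenience, not part of either package.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional


def installed_versions(dists: Iterable[str]) -> Dict[str, Optional[str]]:
    """Return the installed version of each distribution, or None if it is missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in dists:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Distribution names as passed to pip, not the import names.
print(installed_versions(["lanfactory", "ssm-simulators"]))
```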

Basic Tutorial

Check the basic tutorial here.

TorchMLP to ONNX Converter

Once you have trained your model, you can convert it to the ONNX format using the transform_onnx.py script.

The transform_onnx.py script converts a TorchMLP model to the ONNX format. It takes a network configuration file (in pickle format), a state dictionary file (Torch model weights), the size of the input tensor, and the desired output ONNX file path.

Usage

python onnx/transform_onnx.py <network_config_file> <state_dict_file> <input_shape> <output_onnx_file>

Replace the placeholders with the appropriate values:

  • <network_config_file>: Path to the pickle file containing the network configuration.
  • <state_dict_file>: Path to the file containing the state dictionary of the model.
  • <input_shape>: The size of the input tensor for the model (integer).
  • <output_onnx_file>: Path to the output ONNX file.
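The following is not the contents of transform_onnx.py; it is a hypothetical sketch of how a script with this four-argument interface could parse its inputs and read the pickled network configuration. The helper names and the `layer_sizes` key in the demo config are made up for illustration. The remaining steps (rebuilding the TorchMLP from the config, loading the state dictionary, and presumably exporting with PyTorch's `torch.onnx.export`) are omitted because they depend on LANfactory internals.

```python
import argparse
import os
import pickle
import tempfile
from typing import Any, List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the four positional arguments from the usage line."""
    parser = argparse.ArgumentParser(description="TorchMLP to ONNX (sketch)")
    parser.add_argument("network_config_file", help="pickle file with the network configuration")
    parser.add_argument("state_dict_file", help="file with the model's state dictionary")
    parser.add_argument("input_shape", type=int, help="size of the model's input tensor")
    parser.add_argument("output_onnx_file", help="path of the ONNX file to write")
    args = parser.parse_args(argv)
    if args.input_shape <= 0:
        parser.error("input_shape must be a positive integer")  # exits with status 2
    return args


def load_network_config(path: str) -> Any:
    """Unpickle the network configuration. Only unpickle files you trust."""
    with open(path, "rb") as f:
        return pickle.load(f)


# Round trip with a hypothetical config dict to exercise both helpers.
with tempfile.TemporaryDirectory() as tmp:
    cfg_path = os.path.join(tmp, "network_config.pickle")
    with open(cfg_path, "wb") as f:
        pickle.dump({"layer_sizes": [100, 100, 1]}, f)
    args = parse_args([cfg_path, "state_dict.pt", "11", "model.onnx"])
    config = load_network_config(args.network_config_file)
    print(args.input_shape, config)
```

Declaring `input_shape` with `type=int` makes argparse reject non-integer values before any model code runs.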

For example:

python onnx/transform_onnx.py '0d9f0e94175b11eca9e93cecef057438_lca_no_bias_4_torch__network_config.pickle' '0d9f0e94175b11eca9e93cecef057438_lca_no_bias_4_torch_state_dict.pt' 11 'lca_no_bias_4_torch.onnx'

The resulting ONNX file can be used directly with the HSSM package.

We hope this package will be helpful if you train LANs for your own research.




Download files


Source Distribution

lanfactory-0.4.7.tar.gz (631.9 kB)

File details

Details for the file lanfactory-0.4.7.tar.gz.

File metadata

  • Download URL: lanfactory-0.4.7.tar.gz
  • Size: 631.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for lanfactory-0.4.7.tar.gz
  • SHA256: dadfa3fe0c83996d7d2248cfd90c5100529d13f2361be3146c5d3f0d3c0fea7e
  • MD5: d7d1daaefa2f36c93d77715d988d80b1
  • BLAKE2b-256: 81c9e0d3d79572bdeaf4cb8eeae734e12e7bd0eb7c95e38a7bdd123736cb7fbc

