A PyTorch Backend Library for Choice Modelling
torch-choice
Authors: Tianyu Du and Ayush Kanodia; PI: Susan Athey; Contact: tianyudu@stanford.edu
torch-choice is a flexible, fast choice modeling library built on PyTorch. It provides logit and nested logit models, designed for both estimation and prediction. See the complete documentation for more details.
Unique features:
- GPU support via torch for speed
- Specify customized models
- Specify availability sets
- Report standard errors
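To make the availability-set feature concrete, here is a minimal plain-Python sketch of logit choice probabilities restricted to the items available in a session. The function name `choice_probabilities` and the example utilities are hypothetical and only illustrate the arithmetic; this is not the package's API or implementation.

```python
import math


def choice_probabilities(utilities, available):
    """Softmax over utilities, restricted to the available items.

    utilities: list of floats, one per item.
    available: list of bools, True if the item can be chosen.
    Unavailable items receive probability 0.
    """
    if not any(available):
        raise ValueError("at least one item must be available")
    # Subtract the largest available utility for numerical stability.
    u_max = max(u for u, a in zip(utilities, available) if a)
    weights = [math.exp(u - u_max) if a else 0.0
               for u, a in zip(utilities, available)]
    total = sum(weights)
    return [w / total for w in weights]


# Hypothetical utilities for four travel modes; the third one is unavailable.
probs = choice_probabilities([1.0, 0.5, 2.0, 0.0], [True, True, False, True])
print(probs)
```

The unavailable item drops out of the denominator, so the remaining probabilities still sum to one.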
Installation
- Clone the repository to your local machine or server.
- Install the required dependencies using
  pip3 install -r requirements.txt
- Run
  pip3 install torch-choice
- Check the installation by running
  python3 -c 'import torch_choice; print(torch_choice.__version__)'
Mode Canada Example
In this demonstration, we guide you through a minimal example of fitting a conditional logit model using our package. We also show the equivalent R code to ease the transition for R users.
More information about the dataset is available under ModeCanada: Mode Choice for the Montreal-Toronto Corridor.
Mode Canada with Torch-Choice
# load packages.
import pandas as pd
import torch_choice
# load data.
df = pd.read_csv('https://raw.githubusercontent.com/gsbDBI/torch-choice/main/tutorials/public_datasets/ModeCanada.csv?token=GHSAT0AAAAAABRGHCCSNNQARRMU63W7P7F4YWYP5HA').query('noalt == 4').reset_index(drop=True)
# format data.
data = torch_choice.utils.easy_data_wrapper.EasyDatasetWrapper(
    main_data=df,
    purchase_record_column='case',
    choice_column='choice',
    item_name_column='alt',
    user_index_column='case',
    session_index_column='case',
    session_observable_columns=['income'],
    price_observable_columns=['cost', 'freq', 'ovt', 'ivt'])
# define the conditional logit model.
model = torch_choice.model.ConditionalLogitModel(
    coef_variation_dict={'price_cost': 'constant',
                         'price_freq': 'constant',
                         'price_ovt': 'constant',
                         'session_income': 'item',
                         'price_ivt': 'item-full',
                         'intercept': 'item'},
    num_items=4)
# fit the conditional logit model.
torch_choice.utils.run_helper.run(model, data.choice_dataset, num_epochs=5000, learning_rate=0.01, batch_size=-1)
Mode Canada with R
We include the R code for the ModeCanada example as well.
# load packages.
library("mlogit")
library("dplyr")  # select() below is assumed to come from dplyr
# load data.
ModeCanada <- read.csv('https://raw.githubusercontent.com/gsbDBI/torch-choice/main/tutorials/public_datasets/ModeCanada.csv?token=GHSAT0AAAAAABRGHCCSNNQARRMU63W7P7F4YWYP5HA')
ModeCanada <- select(ModeCanada, -X)
ModeCanada$alt <- as.factor(ModeCanada$alt)
# format data.
MC <- dfidx(ModeCanada, subset = noalt == 4)
# fit the data.
ml.MC1 <- mlogit(choice ~ cost + freq + ovt | income | ivt, MC, reflevel='air')
summary(ml.MC1)
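Both specifications above appear to describe the same utility function. As a sketch, assume the usual reading of the three-part mlogit formula (generic coefficients | alternative-specific coefficients on case-level variables | alternative-specific coefficients on alternative-level variables), and assume that 'constant', 'item', and 'item-full' in coef_variation_dict play the corresponding roles. The symbols below are introduced here for illustration only. The utility of alternative $j$ for case $n$, and the resulting choice probability, would then be:

```latex
U_{nj} = V_{nj} + \varepsilon_{nj},
\qquad
V_{nj} = \alpha_j
       + \beta_{\text{cost}}\,\text{cost}_{nj}
       + \beta_{\text{freq}}\,\text{freq}_{nj}
       + \beta_{\text{ovt}}\,\text{ovt}_{nj}
       + \gamma_j\,\text{income}_{n}
       + \delta_j\,\text{ivt}_{nj},
\qquad
P_{nj} = \frac{\exp(V_{nj})}{\sum_{k=1}^{4} \exp(V_{nk})}
```

Under this reading, $\alpha_j$ and $\gamma_j$ are fixed to zero for one reference alternative (air in the R code, via reflevel='air'), while $\delta_j$ is estimated for all four alternatives.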
What's in the package?
- The package includes a data management tool called ChoiceDataset, built on PyTorch's dataset. It lets users easily move data between CPU and GPU. Unlike traditional long or wide formats, ChoiceDataset offers a memory-efficient way to manage observables.
- The package provides (1) a conditional logit model and (2) a nested logit model for consumer choice modeling.
- The package leverages GPU acceleration through PyTorch and scales to large datasets with millions of choice records. All models are trained using state-of-the-art optimizers in PyTorch, which modern machine learning practitioners have tested at scale. The package also runs when no GPU is available.
- For those without much experience in PyTorch model development, setting up optimizers and training loops can be frustrating. We provide easy-to-use PyTorch Lightning wrappers of the models to free researchers from the hassle of setting up PyTorch optimizers and training loops.
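The memory argument in the first bullet can be illustrated with a small plain-Python sketch: a session-level observable such as income is stored once per session and looked up through an index, instead of being repeated on every row of a long-format table. The names below (`session_income`, `session_index`, `lookup_income`) and the numbers are hypothetical and do not reproduce the ChoiceDataset internals.

```python
# Long format repeats the session-level value on every (session, item) row.
num_items = 4
session_income = [45.0, 70.0, 30.0]  # one value per session (made-up data)
long_format_income = [inc for inc in session_income for _ in range(num_items)]

# Index-based storage keeps one value per session plus one integer per record.
session_index = [0, 1, 2, 1]  # session that each choice record belongs to


def lookup_income(record_id):
    """Fetch the session-level observable for a given choice record."""
    return session_income[session_index[record_id]]


print(len(long_format_income))  # 12 stored values in long format
print(len(session_income))      # 3 stored values with index-based storage
print(lookup_income(3))         # 70.0
```

The saving grows with the number of items and with the number of records that share a session.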