A simple MNL (multinomial logit) package that uses PyTorch under the hood.
Project description
A simple PyTorch-based MNL library
Fit your multinomial logistic regression models with PyTorch.
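For reference, the standard multinomial logit (softmax regression) model for d input features and K classes is written below. This is the textbook formulation, not taken from the package source. We assume LinearMNL learns one weight vector w_k and one bias b_k per class, and that fitting minimizes the cross-entropy (negative log-likelihood) of the observed classes over N training rows.

```latex
P(y = k \mid \mathbf{x}) = \frac{\exp(\mathbf{w}_k^\top \mathbf{x} + b_k)}{\sum_{j=1}^{K} \exp(\mathbf{w}_j^\top \mathbf{x} + b_j)}, \qquad k = 1, \dots, K

\mathcal{L} = -\frac{1}{N} \sum_{i=1}^{N} \log P(y = y_i \mid \mathbf{x}_i)
```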
Install
pip install pytorch_mnl
How to use
Import the library:
import pandas as pd
from pytorch_mnl.core import *
Load the data (the Iris dataset, dropping its Id column):
data = pd.read_csv("./data/Iris.csv").drop("Id", axis=1)
Choose the feature (x) and target (y) columns:
x_cols = ['SepalLengthCm', 'SepalWidthCm', 'PetalLengthCm', 'PetalWidthCm']
target_col = 'Species'
Count the number of classes to predict:
n_targets = len(data[target_col].unique())
n_targets
3
X, y = prepare_data(data, x_cols=x_cols, target_col=target_col)
prepare_data returns PyTorch tensors, ready to use:
type(X), type(y)
(torch.Tensor, torch.Tensor)
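The internals of prepare_data aren't described here. As a rough sketch, assuming it turns the feature columns into a float tensor and encodes the class labels as integer codes (the form PyTorch's cross-entropy loss expects), an equivalent helper could look like this. prepare_data_sketch is a hypothetical name, not part of pytorch_mnl, and the toy DataFrame stands in for the Iris data.

```python
import pandas as pd
import torch


def prepare_data_sketch(df: pd.DataFrame, x_cols, target_col):
    """Hypothetical stand-in for prepare_data (an assumption, not the package code):
    features -> float32 tensor, class labels -> int64 codes 0..K-1."""
    # Feature matrix of shape (n_rows, n_features)
    X = torch.tensor(df[x_cols].to_numpy(dtype="float32"))
    # Map each distinct label to an integer code, in sorted category order
    codes = pd.Categorical(df[target_col]).codes
    y = torch.tensor(codes, dtype=torch.long)
    return X, y


df = pd.DataFrame({
    "a": [1.0, 2.0, 3.0],
    "b": [0.5, 0.1, 0.2],
    "label": ["cat", "dog", "cat"],
})
X, y = prepare_data_sketch(df, ["a", "b"], "label")
print(X.shape, y.tolist())  # torch.Size([3, 2]) [0, 1, 0]
```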
Let's split the data into training and validation sets, holding out a percentage of the rows for validation (pct=0.2, i.e. 20%), and choose a batch size for fitting the model:
dls = DataLoaders.from_Xy(X, y, pct=0.2, batch_size=8)
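DataLoaders.from_Xy is the package's own helper, and its exact splitting logic isn't shown here. The sketch below shows one plain-PyTorch way to get a similar effect, assuming pct is the fraction held out for validation: a seeded random_split followed by two torch DataLoader objects. split_and_batch and its seed parameter are hypothetical.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split


def split_and_batch(X, y, pct=0.2, batch_size=8, seed=0):
    """Hypothetical sketch: hold out a fraction `pct` of rows for validation
    and wrap both parts in DataLoaders yielding mini-batches of `batch_size`."""
    ds = TensorDataset(X, y)
    n_valid = int(round(len(ds) * pct))
    n_train = len(ds) - n_valid
    gen = torch.Generator().manual_seed(seed)  # reproducible random split
    train_ds, valid_ds = random_split(ds, [n_train, n_valid], generator=gen)
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    valid_dl = DataLoader(valid_ds, batch_size=batch_size)
    return train_dl, valid_dl


# 150 rows and 4 features, the same shape as the Iris data
X = torch.randn(150, 4)
y = torch.randint(0, 3, (150,))
train_dl, valid_dl = split_and_batch(X, y, pct=0.2, batch_size=8)
print(len(train_dl.dataset), len(valid_dl.dataset), len(train_dl))
```

With 150 rows and pct=0.2 this yields 120 training rows (15 batches of 8) and 30 validation rows.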
Our data has 4 input variables and 3 target classes, so we fit a linear MNL with 4 inputs and 3 outputs:
model = LinearMNL(len(x_cols), n_targets)
learn = Learner(dls, model)
learn.fit(25)
epoch = 0, val_loss = 2.072, accuracy = 0.53
epoch = 1, val_loss = 1.908, accuracy = 0.53
epoch = 2, val_loss = 1.770, accuracy = 0.80
epoch = 3, val_loss = 1.657, accuracy = 0.80
epoch = 4, val_loss = 1.564, accuracy = 0.80
epoch = 5, val_loss = 1.487, accuracy = 0.80
epoch = 6, val_loss = 1.422, accuracy = 0.80
epoch = 7, val_loss = 1.368, accuracy = 0.80
epoch = 8, val_loss = 1.321, accuracy = 0.80
epoch = 9, val_loss = 1.282, accuracy = 0.83
epoch = 10, val_loss = 1.247, accuracy = 0.83
epoch = 11, val_loss = 1.217, accuracy = 0.83
epoch = 12, val_loss = 1.190, accuracy = 0.83
epoch = 13, val_loss = 1.166, accuracy = 0.83
epoch = 14, val_loss = 1.144, accuracy = 0.87
epoch = 15, val_loss = 1.125, accuracy = 0.87
epoch = 16, val_loss = 1.107, accuracy = 0.90
epoch = 17, val_loss = 1.091, accuracy = 0.90
epoch = 18, val_loss = 1.076, accuracy = 0.90
epoch = 19, val_loss = 1.063, accuracy = 0.90
epoch = 20, val_loss = 1.050, accuracy = 0.90
epoch = 21, val_loss = 1.038, accuracy = 0.90
epoch = 22, val_loss = 1.027, accuracy = 0.90
epoch = 23, val_loss = 1.016, accuracy = 0.90
epoch = 24, val_loss = 1.007, accuracy = 0.90
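To make the fitting step concrete, here is a minimal plain-PyTorch loop that prints a log in the same format as learn.fit. It assumes, without confirmation from the package, that LinearMNL behaves like a single nn.Linear layer and that Learner minimizes cross-entropy. The optimizer (SGD), the learning rate, and the name fit_sketch are hypothetical choices, and the toy data stands in for the Iris tensors.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def fit_sketch(model, train_dl, valid_dl, epochs, lr=0.5):
    """Hypothetical training loop: SGD on cross-entropy, reporting validation
    loss and accuracy after each epoch (optimizer and lr are assumptions)."""
    loss_fn = nn.CrossEntropyLoss()  # log-softmax + negative log-likelihood
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    history = []
    for epoch in range(epochs):
        model.train()
        for xb, yb in train_dl:
            opt.zero_grad()
            loss = loss_fn(model(xb), yb)
            loss.backward()
            opt.step()
        # Evaluate on the held-out set without tracking gradients
        model.eval()
        total_loss, correct, n = 0.0, 0, 0
        with torch.no_grad():
            for xb, yb in valid_dl:
                logits = model(xb)
                total_loss += loss_fn(logits, yb).item() * len(yb)
                correct += (logits.argmax(dim=1) == yb).sum().item()
                n += len(yb)
        val_loss, acc = total_loss / n, correct / n
        history.append((val_loss, acc))
        print(f"epoch = {epoch}, val_loss = {val_loss:.3f}, accuracy = {acc:.2f}")
    return history


# Toy data: the class is the index of the largest of 3 features
torch.manual_seed(0)
X = torch.rand(150, 3)
y = X.argmax(dim=1)
train_dl = DataLoader(TensorDataset(X[:120], y[:120]), batch_size=8, shuffle=True)
valid_dl = DataLoader(TensorDataset(X[120:], y[120:]), batch_size=8)
model = nn.Linear(3, 3)  # linear MNL: one logit (utility) per class
history = fit_sketch(model, train_dl, valid_dl, epochs=25)
```

After fitting, torch.softmax(model(X), dim=1) turns the logits into per-class probabilities, and argmax(dim=1) gives the predicted class.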
Download files
Source Distribution
pytorch_mnl-0.0.1.tar.gz (12.2 kB)
Built Distribution
pytorch_mnl-0.0.1-py3-none-any.whl (9.6 kB)
File details
Details for the file pytorch_mnl-0.0.1.tar.gz
File metadata
- Download URL: pytorch_mnl-0.0.1.tar.gz
- Upload date:
- Size: 12.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.2 importlib_metadata/4.5.0 pkginfo/1.7.1 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.1 CPython/3.8.10
File hashes
Algorithm | Hash digest
---|---
SHA256 | 8b0edaf57e697457b44c3977118b510117dabbc5bbe16af94532d7313677f35e
MD5 | 7a9a837b2130115604fbda185822e438
BLAKE2b-256 | 203d6791358cbecefecde632b262c78e72c9aebefebeccbae1b4cd6e8ee560c0
File details
Details for the file pytorch_mnl-0.0.1-py3-none-any.whl
File metadata
- Download URL: pytorch_mnl-0.0.1-py3-none-any.whl
- Upload date:
- Size: 9.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.2 importlib_metadata/4.5.0 pkginfo/1.7.1 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.1 CPython/3.8.10
File hashes
Algorithm | Hash digest
---|---
SHA256 | 83702e582df7fa341cb93b04c035d4d6d9627c617a91e7ac646ace8a426b5642
MD5 | 345bb34ef8ca16c132808d2a5aedf737
BLAKE2b-256 | 35daf9741c7542858a4b347f3eef01ad078370460e964c63a54a603ed908b58b