A simple multinomial logit (MNL) package that uses PyTorch under the hood.

A simple PyTorch-based MNL library

Fit your multinomial logistic regression with PyTorch

Install

pip install pytorch_mnl

How to use

Import the library:

import pandas as pd
from pytorch_mnl.core import *

Load the data:

data = pd.read_csv("./data/Iris.csv").drop("Id", axis=1)

Choose the feature (x) and target (y) columns:

x_cols=['SepalLengthCm', 'SepalWidthCm', 'PetalLengthCm', 'PetalWidthCm']
target_col = 'Species'

Count the number of classes to predict:

n_targets = len(data[target_col].unique())
n_targets
3

Convert the selected columns into tensors:

X, y = prepare_data(data, x_cols=x_cols, target_col=target_col)

We get PyTorch tensors, ready to use:

type(X), type(y)
(torch.Tensor, torch.Tensor)
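
The internals of prepare_data are not shown here. As a rough sketch of the kind of conversion such a helper has to perform, the snippet below maps string labels to integer class indices and collects the feature columns as rows of floats. It uses plain Python lists rather than pandas and torch.Tensor so that it runs without dependencies; the function name encode_rows and the three sample rows are hypothetical and not part of pytorch_mnl.

```python
def encode_rows(rows, x_cols, target_col):
    """Hypothetical stand-in for a data-preparation helper (not pytorch_mnl's code).

    rows: list of dicts, one per observation.
    Returns (X, y, classes): float feature rows, integer labels, class names.
    """
    # Sort the distinct labels so the label -> index mapping is reproducible.
    classes = sorted({row[target_col] for row in rows})
    class_to_idx = {name: i for i, name in enumerate(classes)}

    X = [[float(row[col]) for col in x_cols] for row in rows]
    y = [class_to_idx[row[target_col]] for row in rows]
    return X, y, classes


x_cols = ["SepalLengthCm", "SepalWidthCm", "PetalLengthCm", "PetalWidthCm"]

# Three illustrative rows in the Iris column layout, one per species.
rows = [
    dict(zip(x_cols + ["Species"], [5.1, 3.5, 1.4, 0.2, "Iris-setosa"])),
    dict(zip(x_cols + ["Species"], [7.0, 3.2, 4.7, 1.4, "Iris-versicolor"])),
    dict(zip(x_cols + ["Species"], [6.3, 3.3, 6.0, 2.5, "Iris-virginica"])),
]

X, y, classes = encode_rows(rows, x_cols, "Species")
print(y)  # integer class indices, one per row
```

Integer class indices are the label format that cross-entropy style losses typically expect, which is presumably why y comes back as a tensor rather than as a column of strings.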

Let's split the data into train/valid sets, choosing the fraction to hold out (here 20%) and a batch size for fitting the model:

dls = DataLoaders.from_Xy(X, y, pct=0.2, batch_size=8)
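
To make the effect of pct and batch_size concrete, here is a minimal sketch of a shuffled holdout split followed by mini-batching, written in plain Python. It is an assumption about what a call like the one above does in spirit, not the code behind DataLoaders.from_Xy; the function split_and_batch, its seed parameter, and the row count of 150 (the size of the standard Iris dataset) are all introduced here for illustration.

```python
import random


def split_and_batch(n_rows, pct=0.2, batch_size=8, seed=0):
    """Hypothetical sketch: shuffled holdout split, then mini-batches.

    Returns (train_batches, valid_batches) as lists of lists of row indices.
    """
    indices = list(range(n_rows))
    random.Random(seed).shuffle(indices)  # shuffle before splitting

    n_valid = round(n_rows * pct)  # size of the holdout set
    valid_idx, train_idx = indices[:n_valid], indices[n_valid:]

    def batches(idx):
        # Consecutive chunks of batch_size; the last chunk may be shorter.
        return [idx[i:i + batch_size] for i in range(0, len(idx), batch_size)]

    return batches(train_idx), batches(valid_idx)


train_batches, valid_batches = split_and_batch(150)
n_train = sum(len(b) for b in train_batches)
n_valid = sum(len(b) for b in valid_batches)
print(n_train, n_valid, len(train_batches), len(valid_batches))  # 120 30 15 4
```

Under these assumptions, 20% of 150 rows leaves 30 rows for validation and 120 for training, which a batch size of 8 divides into 15 training batches per epoch.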

As our data has 4 input variables and 3 classes, we fit an MNL with 4 inputs and 3 targets.

model = LinearMNL(len(x_cols), n_targets)
learn = Learner(dls, model)
learn.fit(25)
epoch =   0, val_loss = 2.072, accuracy = 0.53
epoch =   1, val_loss = 1.908, accuracy = 0.53
epoch =   2, val_loss = 1.770, accuracy = 0.80
epoch =   3, val_loss = 1.657, accuracy = 0.80
epoch =   4, val_loss = 1.564, accuracy = 0.80
epoch =   5, val_loss = 1.487, accuracy = 0.80
epoch =   6, val_loss = 1.422, accuracy = 0.80
epoch =   7, val_loss = 1.368, accuracy = 0.80
epoch =   8, val_loss = 1.321, accuracy = 0.80
epoch =   9, val_loss = 1.282, accuracy = 0.83
epoch =  10, val_loss = 1.247, accuracy = 0.83
epoch =  11, val_loss = 1.217, accuracy = 0.83
epoch =  12, val_loss = 1.190, accuracy = 0.83
epoch =  13, val_loss = 1.166, accuracy = 0.83
epoch =  14, val_loss = 1.144, accuracy = 0.87
epoch =  15, val_loss = 1.125, accuracy = 0.87
epoch =  16, val_loss = 1.107, accuracy = 0.90
epoch =  17, val_loss = 1.091, accuracy = 0.90
epoch =  18, val_loss = 1.076, accuracy = 0.90
epoch =  19, val_loss = 1.063, accuracy = 0.90
epoch =  20, val_loss = 1.050, accuracy = 0.90
epoch =  21, val_loss = 1.038, accuracy = 0.90
epoch =  22, val_loss = 1.027, accuracy = 0.90
epoch =  23, val_loss = 1.016, accuracy = 0.90
epoch =  24, val_loss = 1.007, accuracy = 0.90
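
For readers who want to see what the model and the logged numbers stand for, the sketch below spells out a multinomial logit in plain Python: one linear score per class, a softmax to turn scores into probabilities, and then a loss and an accuracy over a validation set. It assumes the reported val_loss is a mean cross-entropy, which the output above does not state, and it does not claim to mirror how LinearMNL or Learner are implemented. The function names, the all-zero parameters, and the three sample rows are illustrative.

```python
import math


def mnl_probabilities(x, weights, biases):
    """Class probabilities of a multinomial logit: softmax(W x + b)."""
    # One linear score per class.
    scores = [sum(w_i * x_i for w_i, x_i in zip(w, x)) + b
              for w, b in zip(weights, biases)]
    # Subtract the max score before exponentiating for numerical stability.
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def evaluate(probs, labels):
    """Return (mean cross-entropy, accuracy) for predicted probabilities."""
    n = len(labels)
    # Cross-entropy: negative log of the probability given to the true class.
    loss = -sum(math.log(p[y]) for p, y in zip(probs, labels)) / n
    # Accuracy: share of rows whose most probable class is the true class.
    hits = sum(max(range(len(p)), key=p.__getitem__) == y
               for p, y in zip(probs, labels))
    return loss, hits / n


# Untrained model: 3 classes x 4 inputs, all parameters zero (illustrative).
weights = [[0.0] * 4 for _ in range(3)]
biases = [0.0] * 3

# Three illustrative validation rows with integer class labels.
X_valid = [[5.1, 3.5, 1.4, 0.2], [7.0, 3.2, 4.7, 1.4], [6.3, 3.3, 6.0, 2.5]]
y_valid = [0, 1, 2]

probs = [mnl_probabilities(x, weights, biases) for x in X_valid]
val_loss, accuracy = evaluate(probs, y_valid)
print(f"val_loss = {val_loss:.3f}, accuracy = {accuracy:.2f}")
# val_loss = 1.099, accuracy = 0.33
```

With all-zero parameters every class gets probability 1/3, so the cross-entropy equals ln 3. Note also that if the validation set holds 30 rows (20% of 150), accuracy can only move in steps of 1/30, which is consistent with the values 0.80, 0.83, 0.87 and 0.90 in the log.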
