
Ordinal regression models in PyTorch

Project description

spacecutter-torch

spacecutter-torch is a library for implementing ordinal regression models in PyTorch. The library provides models, loss functions, and training callbacks.

Installation

pip install spacecutter-torch

Usage

Models

Define any PyTorch model you want that generates a single, scalar prediction value. This will be our predictor model. The predictor's output can then be passed through spacecutter.models.OrdinalLogisticHead, which converts that single number into an array of ordinal class probabilities. In the example below, the head is simply appended as the last layer of an nn.Sequential. The example does this for a two-layer neural network predictor on a problem with three ordinal classes.

import torch
from torch import nn

from spacecutter.models import OrdinalLogisticHead


X = torch.tensor([[0.5, 0.1, -0.1],
                  [1.0, 0.2, 0.6],
                  [-2.0, 0.4, 0.8]]).float()

y = torch.tensor([0, 1, 2]).reshape(-1, 1).long()

num_features = X.shape[1]
num_classes = len(torch.unique(y))

model = nn.Sequential(
    nn.Linear(num_features, num_features),
    nn.ReLU(),
    nn.Linear(num_features, 1),
    OrdinalLogisticHead(num_classes),
)

y_pred = model(X)

print(y_pred)

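# Example output. Exact values depend on the random weight initialization,
# and the grad_fn name can differ between PyTorch versions (e.g. CatBackward0).
# tensor([[0.2325, 0.2191, 0.5485],
#         [0.2324, 0.2191, 0.5485],
#         [0.2607, 0.2287, 0.5106]], grad_fn=<CatBackward>)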
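To make the conversion concrete, here is a from-scratch sketch of a cumulative logistic link, the standard construction for ordinal logistic regression. We assume OrdinalLogisticHead works along these lines: K - 1 learnable cutpoints c_1 < ... < c_{K-1}, with P(y <= k) = sigmoid(c_k - f) for a scalar prediction f, and class probabilities obtained as differences of adjacent cumulative probabilities. The function name `cumulative_logistic_probs` and the cutpoint values are hypothetical and not part of spacecutter's API.

```python
import torch


def cumulative_logistic_probs(f: torch.Tensor, cutpoints: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: map scalar predictions to ordinal class probabilities.

    f: shape (N, 1), scalar outputs of the predictor.
    cutpoints: shape (K - 1,), ascending thresholds for K classes.
    Returns: shape (N, K); each row sums to 1.
    """
    # Cumulative probabilities P(y <= k) = sigmoid(c_k - f), broadcast to (N, K - 1).
    cumulative = torch.sigmoid(cutpoints - f)
    # First class: P(y <= 0). Last class: 1 - P(y <= K - 2).
    # Middle classes: differences of adjacent cumulative probabilities.
    first = cumulative[:, :1]
    middle = cumulative[:, 1:] - cumulative[:, :-1]
    last = 1 - cumulative[:, -1:]
    return torch.cat([first, middle, last], dim=1)


f = torch.tensor([[0.2], [-1.5], [3.0]])
cutpoints = torch.tensor([-1.0, 0.0])  # hypothetical values for K = 3 classes
probs = cumulative_logistic_probs(f, cutpoints)
print(probs)
print(probs.sum(dim=1))
```

Larger predictions shift probability mass toward higher classes. If the cutpoints ever fall out of ascending order, the middle differences can become negative and the rows stop being valid probability distributions, which is why the cutpoints need to be constrained during training.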

Training

The following shows how to train the model from the previous section using cumulative link loss:

import torch
from spacecutter.callbacks import AscensionCallback
from spacecutter.losses import CumulativeLinkLoss

def train(model, optimizer, X, y, num_epochs=10) -> list:
    """
    A minimal full-batch training loop. You can bring your own training loop if you want.
    """
    model.train()
    on_batch_end_callbacks = [AscensionCallback()]
    loss_fn = CumulativeLinkLoss()
    losses = []
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        y_pred = model(X)  # (N, num_classes) class probabilities
        loss = loss_fn(y_pred, y)  # y holds integer labels with shape (N, 1)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        # After each optimizer step, re-impose ascending order on the cutpoints.
        with torch.no_grad():
            for callback in on_batch_end_callbacks:
                model.apply(callback)
    return losses

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
losses = train(model, optimizer, X, y)
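For intuition about what the loss optimizes, here is a minimal sketch of a cumulative link loss written as the negative log-likelihood of each sample's observed class under the predicted class probabilities. We assume CumulativeLinkLoss computes something along these lines; the function `cumulative_link_nll` and its `eps` parameter are hypothetical, so check the library's source for its exact defaults and reduction.

```python
import torch


def cumulative_link_nll(
    y_pred: torch.Tensor, y_true: torch.Tensor, eps: float = 1e-15
) -> torch.Tensor:
    """Hypothetical sketch: mean negative log-likelihood of the true ordinal class.

    y_pred: shape (N, K), class probabilities (e.g. the output of an ordinal head).
    y_true: shape (N, 1), int64 class labels in [0, K).
    """
    # Select the predicted probability of each sample's true class -> (N, 1).
    likelihoods = torch.gather(y_pred, 1, y_true)
    # Keep probabilities away from 0 and 1 so the log stays finite.
    likelihoods = likelihoods.clamp(eps, 1 - eps)
    return -torch.log(likelihoods).mean()


y_pred = torch.tensor([[0.7, 0.2, 0.1],
                       [0.1, 0.8, 0.1],
                       [0.25, 0.25, 0.5]])
y_true = torch.tensor([0, 1, 2]).reshape(-1, 1)
loss = cumulative_link_nll(y_pred, y_true)
print(loss)
```

Because the labels are passed with shape (N, 1), `torch.gather` can index the probability matrix directly along the class dimension, which matches the `reshape(-1, 1)` applied to `y` in the model example.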

Note that we must add the AscensionCallback, which keeps the ordinal cutpoints in ascending order. Ideally this constraint would be built directly into the optimization. Instead, spacecutter currently uses an SGD-compatible workaround: after each optimizer step, a callback clips the cutpoint values.
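The clipping idea can be sketched in plain PyTorch. This is not spacecutter's source: the module `ToyOrdinalHead`, the function `ascension_clip`, and its `margin` parameter are hypothetical stand-ins, and the sketch only assumes that the (toy) head stores its cutpoints in a parameter named `cutpoints`.

```python
import torch
from torch import nn


class ToyOrdinalHead(nn.Module):
    """Hypothetical ordinal head holding K - 1 learnable cutpoints."""

    def __init__(self, num_classes: int) -> None:
        super().__init__()
        self.cutpoints = nn.Parameter(torch.zeros(num_classes - 1))

    def forward(self, f: torch.Tensor) -> torch.Tensor:
        # Cumulative probabilities P(y <= k) for each cutpoint.
        return torch.sigmoid(self.cutpoints - f)


def ascension_clip(module: nn.Module, margin: float = 0.0) -> None:
    """Hypothetical callback: cap each cutpoint at (next cutpoint - margin)."""
    cutpoints = getattr(module, "cutpoints", None)
    if not isinstance(cutpoints, nn.Parameter):
        return  # skip modules without cutpoints (e.g. nn.Linear)
    with torch.no_grad():
        # Walk right to left so each cutpoint is capped by an
        # already-clipped neighbour; one pass yields a non-decreasing vector.
        for i in range(cutpoints.shape[0] - 2, -1, -1):
            cutpoints[i] = torch.minimum(cutpoints[i], cutpoints[i + 1] - margin)


head = ToyOrdinalHead(num_classes=4)
model = nn.Sequential(nn.Linear(3, 1), head)
with torch.no_grad():
    head.cutpoints.copy_(torch.tensor([0.5, -1.0, 2.0]))  # out of order

model.apply(ascension_clip)  # calls ascension_clip on every submodule
print(head.cutpoints)
```

`nn.Module.apply` visits every submodule, so the callback has to ignore modules that carry no cutpoints. Clipping from right to left is a design choice of this sketch: a left-to-right pass can leave the vector out of order (for example `[3, 2, 1]` becomes `[2, 1, 1]`), while a right-to-left pass always ends non-decreasing.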


Download files


Source Distribution

spacecutter_torch-1.0.0.tar.gz (5.3 kB)

Uploaded Source

Built Distribution

spacecutter_torch-1.0.0-py3-none-any.whl (6.6 kB)

Uploaded Python 3

File details

Details for the file spacecutter_torch-1.0.0.tar.gz.

File metadata

  • Download URL: spacecutter_torch-1.0.0.tar.gz
  • Size: 5.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.9.17

File hashes

Hashes for spacecutter_torch-1.0.0.tar.gz
Algorithm Hash digest
SHA256 b564914c966498a1df9f8f5d852cb284cda3a325b0b0f605f31883488a0eee42
MD5 924a96c49593332924a500bd8493b4ef
BLAKE2b-256 609b134d3654fec1611c704fd3438330eb4bb1d46929accaa6731e573ef33958


File details

Details for the file spacecutter_torch-1.0.0-py3-none-any.whl.

File hashes

Hashes for spacecutter_torch-1.0.0-py3-none-any.whl
Algorithm Hash digest
SHA256 fbf9534fd1ae37e67600feca4f2c4aec16e68df73cc32ddfdcce46bfc4ed9cc0
MD5 af3b6b48213fca5919ea73237541492b
BLAKE2b-256 5f2291552cd44bc4453d494cb3c37f7682a2a8aaea108dc884394bfc05d3cc93

