Ordinal regression models in PyTorch

Project description

spacecutter-torch

spacecutter-torch is a library for implementing ordinal regression models in PyTorch. The library consists of models and loss functions.

Installation

pip install spacecutter-torch

Usage

Models

Define any PyTorch model you want that generates a single, scalar prediction value. This will be our predictor model. The predictor can then be extended with spacecutter.models.OrdinalLogisticHead, which converts the predictor's scalar output into an array of ordinal class probabilities. The following example shows how to do this for a two-layer neural network predictor on a problem with three ordinal classes.

import torch
from torch import nn

from spacecutter.models import OrdinalLogisticHead


X = torch.tensor([[0.5, 0.1, -0.1],
                  [1.0, 0.2, 0.6],
                  [-2.0, 0.4, 0.8]]).float()

y = torch.tensor([0, 1, 2]).reshape(-1, 1).long()

num_features = X.shape[1]
num_classes = len(torch.unique(y))

model = nn.Sequential(
    nn.Linear(num_features, num_features),
    nn.ReLU(),
    nn.Linear(num_features, 1),
    OrdinalLogisticHead(num_classes),
)

y_pred = model(X)

print(y_pred)

# tensor([[0.2325, 0.2191, 0.5485],
#         [0.2324, 0.2191, 0.5485],
#         [0.2607, 0.2287, 0.5106]], grad_fn=<CatBackward>)
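Each row of the output is a probability distribution over the three ordinal classes. If you need hard class labels rather than probabilities, a plain argmax over the class dimension works; the tensor below simply reuses the printed values above for illustration:

```python
import torch

# Probabilities as printed by the model above (copied here so the
# snippet is self-contained).
y_pred = torch.tensor([[0.2325, 0.2191, 0.5485],
                       [0.2324, 0.2191, 0.5485],
                       [0.2607, 0.2287, 0.5106]])

# Pick the most probable ordinal class per row.
labels = y_pred.argmax(dim=1)
print(labels)  # tensor([2, 2, 2])
```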

Training

The following shows how to train the model from the previous section using cumulative link loss:

import torch
from spacecutter.callbacks import AscensionCallback
from spacecutter.losses import CumulativeLinkLoss

def train(model, optimizer, X, y, num_epochs: int = 10) -> list:
    """
    You can bring your own training loop if you want; this is a very simple one
    that trains on the full batch each epoch.
    """
    model.train()
    on_batch_end_callbacks = [AscensionCallback()]
    loss_fn = CumulativeLinkLoss()
    losses = []
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        y_pred = model(X)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        with torch.no_grad():
            for callback in on_batch_end_callbacks:
                model.apply(callback)
    return losses

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
losses = train(model, optimizer, X, y)

Note that we must add the AscensionCallback. This ensures that the ordinal cutpoints stay in ascending order. While ideally this constraint would be built directly into the model optimization, spacecutter currently takes an SGD-compatible shortcut: a callback applied after each backwards pass clips the cutpoint values back into ascending order.
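To make the clipping idea concrete, here is a minimal sketch of what such a post-step callback can look like. The module attribute name `cutpoints` and the `margin` argument are illustrative assumptions for this sketch, not necessarily spacecutter's exact API:

```python
import torch
from torch import nn

def clip_cutpoints(module: nn.Module, margin: float = 0.0) -> None:
    """Sketch of an ascension-style callback: force a module's 1-D
    `cutpoints` parameter (hypothetical name) into non-decreasing order."""
    cutpoints = getattr(module, "cutpoints", None)
    if isinstance(cutpoints, nn.Parameter):
        with torch.no_grad():
            # Raise each cutpoint to at least `margin` above its predecessor.
            for i in range(1, cutpoints.shape[0]):
                cutpoints[i].clamp_(min=cutpoints[i - 1].item() + margin)
```

Applying this via `model.apply(clip_cutpoints)` after each optimizer step mirrors the callback loop in the training function above.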
