
PyTorch-Lightning callback for spacecutter


Spacecutter Lightning

A PyTorch Lightning callback for spacecutter, a library of ordinal regression models for PyTorch.
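
Ordinal regression models in spacecutter turn a single score into class probabilities using a vector of learned cutpoints, and those cutpoints only give valid probabilities while they stay in ascending order. Assuming ClipCutpoints serves that purpose (keeping cutpoints ordered after optimizer updates), the following minimal pure-Python sketch shows what one clipping step could look like. `clip_cutpoints` and its `margin` parameter are hypothetical names for illustration, not the API or implementation of ClipCutpoints.

```python
from typing import List, Sequence


def clip_cutpoints(cutpoints: Sequence[float], margin: float = 0.0) -> List[float]:
    """Hypothetical helper: return a copy of `cutpoints` in which every value is
    at most its right neighbour minus `margin`, so the sequence is ascending."""
    clipped = list(cutpoints)
    # Walk from right to left so each cutpoint is compared with a neighbour
    # that has already been fixed; a single pass is then enough.
    for i in range(len(clipped) - 2, -1, -1):
        clipped[i] = min(clipped[i], clipped[i + 1] - margin)
    return clipped


if __name__ == "__main__":
    # The middle pair is out of order, e.g. after a simulated optimizer step.
    print(clip_cutpoints([-1.0, 0.5, 0.2, 2.0]))  # [-1.0, 0.2, 0.2, 2.0]
```

This sketch keeps the largest cutpoint fixed and only lowers earlier ones; with `margin > 0` the result is strictly increasing. Other repair strategies (such as sorting) are equally plausible.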

Installation

pip install spacecutter-lightning

Usage

import torch
import pytorch_lightning as pl
from spacecutter_lightning import ClipCutpoints
from spacecutter import OrdinalLogisticHead, CumulativeLinkLoss

num_classes = 10
num_features = 5
hidden_size = 10
size = 200

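# Toy data: random features and integer class labels of shape (size, 1)
x = torch.randn(size, num_features)
y = torch.randint(0, num_classes, (size, 1))

train_dataloader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(x, y),
    batch_size=32,
    shuffle=True,
)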

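# The last Linear layer emits one score per sample; OrdinalLogisticHead turns
# that score into probabilities over num_classes ordered classes using
# learned cutpoints.
model = torch.nn.Sequential(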
    torch.nn.Linear(num_features, hidden_size),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden_size, 1),
    OrdinalLogisticHead(num_classes),
)

# Cumulative link loss: negative log-likelihood of each sample's true class
loss_fn = CumulativeLinkLoss()


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = model
        self.loss_fn = loss_fn

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


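# Register ClipCutpoints like any other Lightning callback so the ordinal
# head's cutpoints are clipped during training.
trainer = pl.Trainer(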
    callbacks=[ClipCutpoints()],
    max_epochs=10,
)
trainer.fit(LitModel(), train_dataloader)
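
For readers new to cumulative link models, here is a plain-Python sketch of what the head and loss in the example compute. It assumes OrdinalLogisticHead follows the standard cumulative logistic link: with score s and ascending cutpoints c_0 < ... < c_{K-2}, P(y <= k) = sigmoid(c_k - s), and each class probability is the difference of consecutive cumulative probabilities. `ordinal_probabilities` and `cumulative_link_nll` are hypothetical helpers; the library itself works on batched tensors and may differ in details such as clamping.

```python
import math
from typing import List, Sequence


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def ordinal_probabilities(score: float, cutpoints: Sequence[float]) -> List[float]:
    """Hypothetical helper: class probabilities for one sample under a
    cumulative logistic link, P(y <= k) = sigmoid(cutpoints[k] - score)."""
    cumulative = [sigmoid(c - score) for c in cutpoints]
    # First class: P(y <= 0); middle classes: P(y <= k) - P(y <= k - 1);
    # last class: 1 - P(y <= K - 2).
    probs = [cumulative[0]]
    probs += [cumulative[k] - cumulative[k - 1] for k in range(1, len(cumulative))]
    probs.append(1.0 - cumulative[-1])
    return probs


def cumulative_link_nll(
    batch_probs: Sequence[Sequence[float]], targets: Sequence[int], eps: float = 1e-15
) -> float:
    """Hypothetical helper: mean negative log-likelihood of each sample's true
    class, with probabilities clamped away from 0 and 1 for stability."""
    total = 0.0
    for probs, y in zip(batch_probs, targets):
        p = min(max(probs[y], eps), 1.0 - eps)
        total -= math.log(p)
    return total / len(targets)


if __name__ == "__main__":
    cutpoints = [-1.0, 0.0, 1.0]  # 3 cutpoints -> 4 ordered classes
    probs = ordinal_probabilities(0.0, cutpoints)
    print([round(p, 3) for p in probs])  # [0.269, 0.231, 0.231, 0.269]
    print(cumulative_link_nll([probs], [1]))
```

If the cutpoints fall out of order, a middle "probability" becomes negative (for example, `ordinal_probabilities(0.0, [1.0, -1.0])`), which is why keeping them ascending during training matters.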


Download files


Source Distribution

spacecutter_lightning-0.3.2.tar.gz (2.1 kB)

Uploaded Source

Built Distribution

spacecutter_lightning-0.3.2-py3-none-any.whl (2.8 kB)

Uploaded Python 3

File details

Details for the file spacecutter_lightning-0.3.2.tar.gz.

File metadata

  • Download URL: spacecutter_lightning-0.3.2.tar.gz
  • Upload date:
  • Size: 2.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.9.16

File hashes

Hashes for spacecutter_lightning-0.3.2.tar.gz
  • SHA256: 7e2093cb794ffece13518914363dea980f317f54d198c6015531b32c9cfaad67
  • MD5: 390abd857740301588c60de3982813a4
  • BLAKE2b-256: 64d0d4ac992ec2ac409a369fd4ed105be7c7cd5156e64dc383de4328aaa427c4


File details

Details for the file spacecutter_lightning-0.3.2-py3-none-any.whl.

File metadata

File hashes

Hashes for spacecutter_lightning-0.3.2-py3-none-any.whl
  • SHA256: 5e1b0d6fc5ea44d3dca7c15b4d81905257cfda51fc160d30fb2f417db11c7407
  • MD5: c0befafeec25639f576559726b7087c6
  • BLAKE2b-256: e0f750b91503ddc55f7fa4d74e99e6dc6628697288800a99e510f003dfea9dfd

