PyTorch-Lightning callback for spacecutter

Spacecutter Lightning

A PyTorch Lightning Callback for Spacecutter.

Installation

pip install spacecutter-lightning

Usage

import torch
import pytorch_lightning as pl
from spacecutter_lightning import ClipCutpoints
from spacecutter import OrdinalLogisticHead, CumulativeLinkLoss

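# Synthetic data for the example: `size` samples with `num_features` inputs each.
num_classes = 10
num_features = 5
hidden_size = 10
size = 200

x = torch.randn(size, num_features)
# Labels are integer class indices in [0, num_classes) with shape (size, 1).
y = torch.randint(0, num_classes, (size, 1))

# No batch_size is passed, so the DataLoader falls back to its default of 1.
train_dataloader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(x, y),
)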

model = torch.nn.Sequential(
    torch.nn.Linear(num_features, hidden_size),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden_size, 1),
    OrdinalLogisticHead(num_classes),
)

loss_fn = CumulativeLinkLoss()


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = model
        self.loss_fn = loss_fn

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


trainer = pl.Trainer(
    callbacks=[ClipCutpoints()],
    max_epochs=10,
)
trainer.fit(LitModel(), train_dataloader)
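
Why clip cutpoints

This page does not describe what ClipCutpoints does internally, so the following is only a sketch of the general idea, not the package's implementation. In a cumulative link model the class probabilities are differences of sigmoids evaluated at consecutive cutpoints, which is only valid when the cutpoints are in ascending order. The helper names `class_probabilities` and `clip_cutpoints` and the `margin` parameter are hypothetical and chosen for illustration. The example uses plain Python rather than torch so it runs without any dependencies.

```python
import math
from typing import List


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def class_probabilities(score: float, cutpoints: List[float]) -> List[float]:
    """Cumulative link class probabilities for one scalar predictor output."""
    # P(y <= k) = sigmoid(cutpoint_k - score)
    cumulative = [sigmoid(c - score) for c in cutpoints]
    bounds = [0.0] + cumulative + [1.0]
    # P(y = k) is the difference between consecutive cumulative probabilities.
    return [hi - lo for lo, hi in zip(bounds, bounds[1:])]


def clip_cutpoints(cutpoints: List[float], margin: float = 0.0) -> List[float]:
    """Return a copy in which each cutpoint is at most its right neighbour minus `margin`.

    Hypothetical helper: walks from right to left so the result is
    guaranteed to be non-decreasing.
    """
    clipped = list(cutpoints)
    for i in range(len(clipped) - 2, -1, -1):
        clipped[i] = min(clipped[i], clipped[i + 1] - margin)
    return clipped


unordered = [0.5, -1.0, 2.0]
print(class_probabilities(0.0, unordered))  # one entry is negative

ordered = clip_cutpoints(unordered)
print(ordered)  # → [-1.0, -1.0, 2.0]
print(class_probabilities(0.0, ordered))  # all entries are non-negative
```

With out-of-order cutpoints one "probability" goes negative. After clipping, every entry is non-negative and the entries sum to 1. A training callback would presumably apply a step like this to the model's cutpoint parameter after optimizer updates, but that is an assumption about the design rather than something this page states.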
