PyTorch-Lightning callback for spacecutter

Project description

Spacecutter Lightning

A PyTorch Lightning callback, ClipCutpoints, for training ordinal regression models built with spacecutter.
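In a cumulative link model the K - 1 cutpoints must stay in ascending order, otherwise some class probabilities come out negative. Judging by its name, ClipCutpoints keeps the cutpoints ordered during training. The following is a minimal sketch of that general idea on a plain Python list, not the package's implementation; the function name `clip_cutpoints` and its `margin` parameter are hypothetical.

```python
from typing import List


def clip_cutpoints(cutpoints: List[float], margin: float = 0.0) -> List[float]:
    """Hypothetical helper: return a copy of `cutpoints` that is ascending.

    Each cutpoint is clipped so it sits at least `margin` below the
    (already clipped) cutpoint to its right.
    """
    clipped = list(cutpoints)
    # Walk right to left so every comparison uses an already-final neighbour.
    for i in range(len(clipped) - 2, -1, -1):
        clipped[i] = min(clipped[i], clipped[i + 1] - margin)
    return clipped


print(clip_cutpoints([3.0, 5.0, 1.0]))  # an out-of-order set gets flattened
```

Walking from the last cutpoint down to the first yields an ordered result in a single pass; a left-to-right pass can leave an earlier cutpoint above a later one that was just lowered.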

Installation

pip install spacecutter-lightning

Usage

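import torch
import pytorch_lightning as pl
from spacecutter_lightning import ClipCutpoints
from spacecutter import OrdinalLogisticHead, CumulativeLinkLoss

num_classes = 10
num_features = 5
hidden_size = 10
size = 200

# Toy data: random features and integer class labels in [0, num_classes),
# shaped (size, 1).
x = torch.randn(size, num_features)
y = torch.randint(0, num_classes, (size, 1))

# DataLoader defaults to batch_size=1; a larger batch trains faster.
train_dataloader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(x, y),
    batch_size=32,
    shuffle=True,
)

# A small MLP maps the features to a single score; OrdinalLogisticHead
# turns that score into per-class probabilities using learned cutpoints.
model = torch.nn.Sequential(
    torch.nn.Linear(num_features, hidden_size),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden_size, 1),
    OrdinalLogisticHead(num_classes),
)

loss_fn = CumulativeLinkLoss()


class LitModel(pl.LightningModule):
    # Take the network and loss as arguments instead of reading globals.
    def __init__(self, model, loss_fn):
        super().__init__()
        self.model = model
        self.loss_fn = loss_fn

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)  # predicted class probabilities
        loss = self.loss_fn(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


# ClipCutpoints is passed to the Trainer like any other Lightning callback.
trainer = pl.Trainer(
    callbacks=[ClipCutpoints()],
    max_epochs=10,
)
trainer.fit(LitModel(model, loss_fn), train_dataloader)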
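For reference, a cumulative logistic link turns a single score s into class probabilities through cutpoints c_0 <= c_1 <= ... <= c_{K-2}: P(y <= k) = sigmoid(c_k - s), and each class probability is the difference of consecutive cumulative values. The sketch assumes OrdinalLogisticHead and CumulativeLinkLoss follow this standard formulation (probabilities out, negative log-likelihood of the true class as the loss). It is written from scratch in plain Python; the helper names are hypothetical and not part of spacecutter.

```python
import math
from typing import List


def sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def cumulative_logistic_probs(score: float, cutpoints: List[float]) -> List[float]:
    """Per-class probabilities for one sample; K = len(cutpoints) + 1 classes."""
    # Cumulative probabilities P(y <= k), padded with 0 and 1 at the ends.
    cdf = [0.0] + [sigmoid(c - score) for c in cutpoints] + [1.0]
    # P(y = k) = P(y <= k) - P(y <= k - 1).
    return [cdf[k + 1] - cdf[k] for k in range(len(cdf) - 1)]


def nll(probs: List[float], label: int) -> float:
    # Negative log-likelihood of the observed class.
    return -math.log(probs[label])


probs = cumulative_logistic_probs(0.0, [-1.0, 0.0, 1.0])
print(probs, nll(probs, 2))
```

Passing cutpoints out of order, such as `[1.0, -1.0]`, makes the middle class probability negative, which is why cutpoint ordering has to be maintained during training.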


Download files


Source Distribution

spacecutter_lightning-0.3.0.tar.gz (2.0 kB)

Built Distribution

spacecutter_lightning-0.3.0-py3-none-any.whl (2.7 kB)

File details

Details for the file spacecutter_lightning-0.3.0.tar.gz.

File metadata

  • Download URL: spacecutter_lightning-0.3.0.tar.gz
  • Upload date:
  • Size: 2.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.9.16

File hashes

Hashes for spacecutter_lightning-0.3.0.tar.gz
Algorithm Hash digest
SHA256 9b57d24f0e2fc91e0f10e45d78aa58b1160c467b64d90dedb2b8fd93be5179ca
MD5 d1430d9cd5bdfa4312f7f14aebe8760d
BLAKE2b-256 63215a7accd43b9eccbe2f5967f02df9b070ae66e8c822d281957e4909545671
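To check a downloaded archive against the SHA256 digest listed above, a short script using Python's standard `hashlib` module is enough. It assumes the file was saved under its original name in the current directory; the helper name `sha256_of` is our own.

```python
import hashlib
from pathlib import Path

EXPECTED_SHA256 = "9b57d24f0e2fc91e0f10e45d78aa58b1160c467b64d90dedb2b8fd93be5179ca"


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with path.open("rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


if __name__ == "__main__":
    # Assumes the sdist sits in the current working directory.
    archive = Path("spacecutter_lightning-0.3.0.tar.gz")
    if not archive.exists():
        print(f"{archive} not found")
    else:
        actual = sha256_of(archive)
        print("OK" if actual == EXPECTED_SHA256 else f"MISMATCH: {actual}")
```

pip's hash-checking mode (`--hash=sha256:...` entries in a requirements file) performs the same check at install time.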


File details

Details for the file spacecutter_lightning-0.3.0-py3-none-any.whl.

File hashes

Hashes for spacecutter_lightning-0.3.0-py3-none-any.whl
Algorithm Hash digest
SHA256 d5100c937b1499b57c811b8b56aa7d333512fe53975f4a2b2ee3b623d784eb71
MD5 63d6a86450aadc9a9b9f7ef38c29dd17
BLAKE2b-256 3311104f12fd862ddc17f849c58c27ca02495959340ea36cc42b6b5f85850615

