
PyTorch Lightning callback for spacecutter

Project description

Spacecutter Lightning

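A PyTorch Lightning callback for spacecutter, a library of ordinal regression models for PyTorch. The package provides ClipCutpoints, which clips the learned cutpoints of a spacecutter ordinal model during training.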
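Cumulative link models only produce valid class probabilities when their cutpoints are in non-decreasing order, and gradient updates can break that order. As a rough illustration of what "clipping cutpoints" can mean, the pure-Python sketch below caps each cutpoint at its successor minus an optional margin. This is an assumption about the idea, not this package's implementation: `clip_cutpoints` and `margin` are hypothetical names, and the real callback operates on the model's tensors inside the Lightning training loop.

```python
from typing import List, Sequence


def clip_cutpoints(cutpoints: Sequence[float], margin: float = 0.0) -> List[float]:
    """Return a copy of `cutpoints` clamped into non-decreasing order.

    Hypothetical helper for illustration only; not spacecutter-lightning's API.
    Each cutpoint is capped at (next cutpoint - margin). Walking right to left
    means every comparison uses the already-clipped successor, so the result
    is ordered even when several cutpoints are out of place.
    """
    clipped = list(cutpoints)
    for i in range(len(clipped) - 2, -1, -1):
        clipped[i] = min(clipped[i], clipped[i + 1] - margin)
    return clipped


# The second cutpoint overtook the third, so it is pulled back down to 0.2.
print(clip_cutpoints([-1.0, 0.5, 0.2, 2.0]))
```

A right-to-left pass is used because a single left-to-right pass that compares against the unclipped successor can leave a fully reversed sequence such as `[5, 4, 3]` still out of order.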

Installation

pip install spacecutter-lightning

Usage

import torch
import pytorch_lightning as pl
from spacecutter_lightning import ClipCutpoints
from spacecutter import OrdinalLogisticHead, CumulativeLinkLoss

num_classes = 10
num_features = 5
hidden_size = 10
size = 200

# Toy data: random features and integer class labels in [0, num_classes),
# with the labels shaped (size, 1).
x = torch.randn(size, num_features)
y = torch.randint(0, num_classes, (size, 1))

# Batch and shuffle the toy dataset for training.
train_dataloader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(x, y),
    batch_size=32,
    shuffle=True,
)

# A small MLP that reduces each sample to a single score, followed by
# spacecutter's ordinal head, which maps that score to per-class outputs
# using learned cutpoints.
model = torch.nn.Sequential(
    torch.nn.Linear(num_features, hidden_size),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden_size, 1),
    OrdinalLogisticHead(num_classes),
)

# Cumulative link (ordinal) loss from spacecutter.
loss_fn = CumulativeLinkLoss()


class LitModel(pl.LightningModule):
    def __init__(self, model, loss_fn):
        super().__init__()
        self.model = model
        self.loss_fn = loss_fn

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


# ClipCutpoints clips the ordinal head's cutpoints during training.
trainer = pl.Trainer(
    callbacks=[ClipCutpoints()],
    max_epochs=10,
)
trainer.fit(LitModel(model, loss_fn), train_dataloader)
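To see why the ClipCutpoints callback in the example matters, the pure-Python sketch below computes class probabilities under a textbook cumulative logistic link. It assumes, without confirming it from this package, that an ordinal logistic head follows this standard form: for a scalar score f and cutpoints c_0, ..., c_{K-2}, P(y <= k) = sigmoid(c_k - f), and each class probability is the difference of consecutive cumulative probabilities. The helper names `sigmoid` and `ordinal_logistic_probs` are hypothetical.

```python
import math
from typing import List, Sequence


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def ordinal_logistic_probs(score: float, cutpoints: Sequence[float]) -> List[float]:
    """Class probabilities under a cumulative logistic link.

    Hypothetical helper for illustration. With K - 1 cutpoints there are K
    classes, and P(y <= k) = sigmoid(cutpoints[k] - score).
    """
    cumulative = [sigmoid(c - score) for c in cutpoints]
    # Pad with P(y <= -1) = 0 and P(y <= K - 1) = 1, then take differences.
    padded = [0.0] + cumulative + [1.0]
    return [padded[k + 1] - padded[k] for k in range(len(padded) - 1)]


ordered = ordinal_logistic_probs(0.0, [-1.0, 0.0, 1.0])
unordered = ordinal_logistic_probs(0.0, [1.0, 0.0, -1.0])
print([round(p, 3) for p in ordered])
# Out-of-order cutpoints make some differences negative, which is not a
# valid probability; keeping cutpoints ordered prevents this.
print(min(unordered) < 0)
```

The probabilities always telescope to a sum of 1, but they are all non-negative only when the cutpoints are non-decreasing, which is the property a cutpoint-clipping callback protects.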


Download files

Source Distribution

spacecutter_lightning-0.3.1.tar.gz (2.1 kB)

Built Distribution

spacecutter_lightning-0.3.1-py3-none-any.whl (2.7 kB)

File details

Details for the file spacecutter_lightning-0.3.1.tar.gz.

File metadata

  • Download URL: spacecutter_lightning-0.3.1.tar.gz
  • Size: 2.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.9.16

File hashes

Hashes for spacecutter_lightning-0.3.1.tar.gz
Algorithm Hash digest
SHA256 0474674db5f525c0e1f422f4b3d02047fd71a74cfc1643e4097c05a0ec4d2ad7
MD5 4b96b004b2c4671e01ea2d329523380a
BLAKE2b-256 8212e519989637b0216a5c5886034e053df0d3872499d0fa310bea4e511c092b

File details

Details for the file spacecutter_lightning-0.3.1-py3-none-any.whl.

File hashes

Hashes for spacecutter_lightning-0.3.1-py3-none-any.whl
Algorithm Hash digest
SHA256 cb590c7fe8388b051a3dadb6f359e46ff153de60992b27325353eed13af118b7
MD5 90c663761529d97ad261bfede1c625e9
BLAKE2b-256 b813137ad867acfa9ed1621c6490c376f5b063cc718109730478cf30bbdf2f22
