PyTorch-Lightning callback for spacecutter
Spacecutter Lightning
A PyTorch Lightning Callback for Spacecutter.
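For context, spacecutter fits ordinal regression with a cumulative link model. The network produces one score f(x) per sample, and K - 1 learned cutpoints c_0 <= c_1 <= ... <= c_{K-2} split the real line into K classes. With the logistic link, P(y <= k | x) = sigmoid(c_k - f(x)), and each class probability is the difference of consecutive cumulative probabilities. The plain-Python sketch below computes those class probabilities for a single score. `cumulative_logistic_probs` is a hypothetical helper, not part of either package. It also shows why the cutpoints must stay in ascending order: if two of them swap places, one class gets a negative "probability".

```python
import math
from typing import List


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def cumulative_logistic_probs(score: float, cutpoints: List[float]) -> List[float]:
    """Class probabilities of a logistic cumulative link model (hypothetical helper).

    P(y <= k) = sigmoid(cutpoints[k] - score). Each class probability is the
    difference between consecutive cumulative probabilities, so K - 1
    cutpoints yield K probabilities that always sum to 1.
    """
    cumulative = [sigmoid(c - score) for c in cutpoints]
    probs = [cumulative[0]]                        # first class: P(y <= 0)
    probs += [hi - lo for lo, hi in zip(cumulative[:-1], cumulative[1:])]
    probs.append(1.0 - cumulative[-1])             # last class: P(y > K - 2)
    return probs


# Ascending cutpoints: every class probability is non-negative.
print(cumulative_logistic_probs(0.0, [-1.0, 0.0, 1.0]))
# Out-of-order cutpoints: one class "probability" becomes negative.
print(cumulative_logistic_probs(0.0, [-1.0, 1.0, 0.0]))
```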
Installation
pip install spacecutter-lightning
Usage
```python
import torch
import pytorch_lightning as pl
from spacecutter_lightning import ClipCutpoints
from spacecutter import OrdinalLogisticHead, CumulativeLinkLoss

num_classes = 10
num_features = 5
hidden_size = 10
size = 200

# Random toy data: features and integer class labels of shape (size, 1)
x = torch.randn(size, num_features)
y = torch.randint(0, num_classes, (size, 1))
train_dataloader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(x, y),
)

# The network emits one score per sample; the ordinal head turns it into class probabilities
model = torch.nn.Sequential(
    torch.nn.Linear(num_features, hidden_size),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden_size, 1),
    OrdinalLogisticHead(num_classes),
)
loss_fn = CumulativeLinkLoss()


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = model
        self.loss_fn = loss_fn

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


# ClipCutpoints is registered with the Trainer like any other Lightning callback
trainer = pl.Trainer(
    callbacks=[ClipCutpoints()],
    max_epochs=10,
)
trainer.fit(LitModel(), train_dataloader)
```
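The README does not spell out what ClipCutpoints does beyond its name. Assuming it keeps the ordinal head's learned cutpoints in ascending order during training (out-of-order cutpoints would give some class a negative probability), the core clipping step can be sketched in plain Python. `clip_cutpoints` and its `margin` parameter are hypothetical, and the package's actual rule and hook may differ.

```python
from typing import List


def clip_cutpoints(cutpoints: List[float], margin: float = 0.0) -> List[float]:
    """Hypothetical sketch: return a copy where cutpoints[i] <= cutpoints[i + 1] - margin.

    Walking right to left means each cutpoint is compared against a neighbour
    that is already final, so a single pass restores ascending order.
    """
    clipped = list(cutpoints)
    for i in range(len(clipped) - 2, -1, -1):
        clipped[i] = min(clipped[i], clipped[i + 1] - margin)
    return clipped


print(clip_cutpoints([-1.0, 1.0, 0.0]))  # [-1.0, 0.0, 0.0]
print(clip_cutpoints([2.0, 1.0, 0.0], margin=0.1))
```

A left-to-right pass would only partly repair a descending run such as [2.0, 1.0, 0.0], which is why the sketch walks from the last cutpoint backwards. In a Lightning callback the same step would plausibly run on the head's cutpoint tensor under `torch.no_grad()` from a hook such as `on_train_batch_end`, but that wiring is an assumption.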
File details
Details for the file spacecutter_lightning-0.3.2.tar.gz
File metadata
- Download URL: spacecutter_lightning-0.3.2.tar.gz
- Upload date:
- Size: 2.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.9.16
File hashes
Algorithm | Hash digest
---|---
SHA256 | 7e2093cb794ffece13518914363dea980f317f54d198c6015531b32c9cfaad67
MD5 | 390abd857740301588c60de3982813a4
BLAKE2b-256 | 64d0d4ac992ec2ac409a369fd4ed105be7c7cd5156e64dc383de4328aaa427c4
File details
Details for the file spacecutter_lightning-0.3.2-py3-none-any.whl
File metadata
- Download URL: spacecutter_lightning-0.3.2-py3-none-any.whl
- Upload date:
- Size: 2.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.9.16
File hashes
Algorithm | Hash digest
---|---
SHA256 | 5e1b0d6fc5ea44d3dca7c15b4d81905257cfda51fc160d30fb2f417db11c7407
MD5 | c0befafeec25639f576559726b7087c6
BLAKE2b-256 | e0f750b91503ddc55f7fa4d74e99e6dc6628697288800a99e510f003dfea9dfd