
nshtrainer

A configuration-driven wrapper around PyTorch Lightning that simplifies deep learning experiment setup. Built on nshconfig (Pydantic-based) for type-safe, serializable configuration of every training aspect.

Key Features

  • Type-safe configuration — Every component (callbacks, loggers, optimizers, schedulers) has a paired Config class with full IDE autocompletion and validation
  • Automatic checkpointing with metadata — Best/last/on-exception checkpoints with JSON metadata files containing metrics, environment info, git state, and SHA256 checksums
  • Environment capture — Automatically records hardware info, installed packages, git state, and cluster details (SLURM/LSF) with every run
  • Registry-based extensibility — Add custom callbacks, optimizers, schedulers, and loggers by subclassing and registering
  • HPC support — Automatic node detection on SLURM/LSF clusters, signal handling, and auto-requeue on preemption
  • Builder-style API — Fluent configuration with with_*() (returns copy) and *_() (in-place) methods
  • HuggingFace Hub integration — Optionally push checkpoints to HuggingFace Hub
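The first bullet — type-safe, serializable configs — follows the Pydantic model pattern. As a rough illustration of the idea only (nshconfig is Pydantic-based; this stdlib sketch is not its API, and `OptimizerConfig` here is a made-up example class):

```python
import json
from dataclasses import asdict, dataclass

# Illustration only: a typed, validated, JSON-round-trippable config object,
# mimicking what nshconfig's Pydantic-based Config classes provide.
@dataclass(frozen=True)
class OptimizerConfig:
    name: str = "adamw"
    lr: float = 1e-3

    def __post_init__(self) -> None:
        # Validation at construction time, like a Pydantic validator.
        if self.lr <= 0:
            raise ValueError("lr must be positive")

cfg = OptimizerConfig(lr=3e-4)
blob = json.dumps(asdict(cfg))                  # serialize for reproducibility
restored = OptimizerConfig(**json.loads(blob))  # round-trip back to a typed object
assert restored == cfg
```

Because every component carries such a config, a whole run can be captured as one JSON document and replayed later.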

Installation

pip install nshtrainer

# With all optional dependencies (wandb, tensorboard, etc.)
pip install nshtrainer[extra]

Quick Start

import nshconfig as C
import torch
from torch.utils.data import DataLoader, TensorDataset
from typing_extensions import override

import nshtrainer as nt

# 1. Define your hyperparameters as a config class
class MyModelConfig(C.Config):
    hidden_size: int = 64
    lr: float = 1e-3

# 2. Subclass LightningModuleBase with your config
class MyModel(nt.LightningModuleBase[MyModelConfig]):
    @override
    @classmethod
    def hparams_cls(cls):
        return MyModelConfig

    def __init__(self, hparams: MyModelConfig):
        super().__init__(hparams)
        self.net = torch.nn.Linear(10, hparams.hidden_size)
        self.head = torch.nn.Linear(hparams.hidden_size, 1)

    @override
    def forward(self, x: torch.Tensor):
        return self.head(torch.relu(self.net(x)))

    @override
    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self(x), y)
        self.log("train_loss", loss)
        return loss

    @override
    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)

# 3. Configure the trainer
trainer_config = nt.TrainerConfig(
    max_epochs=10,
    accelerator="cpu",
    primary_metric=nt.MetricConfig(name="train_loss", mode="min"),
).with_project_root("./outputs")

# 4. Train
trainer = nt.Trainer(trainer_config)
model = MyModel(MyModelConfig())

dataset = TensorDataset(torch.randn(100, 10), torch.randn(100, 1))
trainer.fit(model, train_dataloaders=DataLoader(dataset, batch_size=16))
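The `primary_metric` with `mode="min"` drives "best" checkpoint selection. Conceptually, the bookkeeping amounts to the sketch below (illustrative names, not nshtrainer's actual implementation):

```python
# Illustrative sketch of "best checkpoint" tracking by a monitored metric.
# mode="min" keeps the epoch with the lowest value (e.g. a loss);
# mode="max" would keep the highest (e.g. an accuracy).
def best_epoch(metric_history: list[float], mode: str = "min") -> int:
    """Index of the epoch with the best monitored-metric value so far."""
    pick = min if mode == "min" else max
    return pick(range(len(metric_history)), key=metric_history.__getitem__)

# With a decreasing-then-rising loss curve, epoch 2 is "best":
assert best_epoch([0.9, 0.5, 0.3, 0.4], mode="min") == 2
```

The real trainer additionally writes a JSON metadata file alongside each checkpoint (metrics, environment, git state, checksums), so the "best" link can be audited after the fact.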

License

See LICENSE for details.

Release history

This version: 1.5.4

Download files

Source Distribution

nshtrainer-1.5.4.tar.gz (115.0 kB)

Built Distribution

nshtrainer-1.5.4-py3-none-any.whl (194.7 kB)

File details

Details for the file nshtrainer-1.5.4.tar.gz.

File metadata

  • Download URL: nshtrainer-1.5.4.tar.gz
  • Size: 115.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.18

File hashes

Hashes for nshtrainer-1.5.4.tar.gz:

  • SHA256: 4281f96670991f7aea6904bdfec1cc88b86f976b6dfcb991a091182290ffb22d
  • MD5: 3e8a716cee775e7a8efde735e0f21d4c
  • BLAKE2b-256: a95beaefc6b2c8092359c285abe8ae7e74e11255a428f14fd936be0081714a35

File details

Details for the file nshtrainer-1.5.4-py3-none-any.whl.

File metadata

  • Download URL: nshtrainer-1.5.4-py3-none-any.whl
  • Size: 194.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.18

File hashes

Hashes for nshtrainer-1.5.4-py3-none-any.whl:

  • SHA256: 70cb800d63ebd38f56e29b39ef0d663e5133bc61d4520f56dc8e87304f3ab8cd
  • MD5: 9d1c1ddd2bb21cce8bb33d2c2824ada8
  • BLAKE2b-256: 350dd6b4266b0281591b44da26f2351cb0aacc3e0cdb78e8c1decd076c091c08
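The published digests can be reproduced locally before installing. A minimal stdlib sketch (the file path is illustrative):

```python
import hashlib

def file_digests(path: str, chunk_size: int = 1 << 20) -> dict[str, str]:
    """Compute the digests PyPI publishes: SHA256, MD5, and BLAKE2b-256."""
    hashes = {
        "sha256": hashlib.sha256(),
        "md5": hashlib.md5(),
        "blake2b_256": hashlib.blake2b(digest_size=32),  # 32 bytes = BLAKE2b-256
    }
    with open(path, "rb") as f:
        # Stream in chunks so large archives don't load fully into memory.
        while chunk := f.read(chunk_size):
            for h in hashes.values():
                h.update(chunk)
    return {name: h.hexdigest() for name, h in hashes.items()}

# Compare against the digest listed above before installing, e.g.:
# file_digests("nshtrainer-1.5.4.tar.gz")["sha256"] == "4281f966..."
```

pip can also enforce this automatically via `--require-hashes` with a pinned requirements file.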
