Typed PyTorch training utilities for checkpoints, scheduling, freezing, memory planning, and structured losses.

Project description

trainkit-vp

trainkit-vp collects typed training utilities for PyTorch projects: checkpointing, staged freezing, loss aggregation, scheduler setup, and memory planning.

PyPI package name:

pip install trainkit-vp

Import name:

import trainkit

The package is intentionally generic: it is not tied to a single model family, and it works equally well in research code and in reusable application packages.

What is included

trainkit-vp currently provides:

  • checkpoint save/load helpers
  • staged freezing helpers for hierarchical models
  • a structured multi-level sequence loss
  • cosine warmup scheduling
  • accelerator-aware memory estimation and batch-size planning

Installation

Requirements:

  • Python >=3.14
  • PyTorch >=2.0

Install from PyPI:

pip install trainkit-vp

Checkpoint helpers

Public functions:

  • save_checkpoint
  • load_checkpoint

These wrap torch.save and torch.load behind a small, typed interface that downstream packages can standardize on.

Typical usage:

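from pathlib import Path

import torch

from trainkit import load_checkpoint, save_checkpoint

# Placeholder model and optimizer; substitute your own.
model = torch.nn.Linear(16, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

payload = {
    "epoch": 3,
    "model_state": model.state_dict(),
    "optimizer_state": optimizer.state_dict(),
}

path = Path("artifacts/best.pt")
# Make sure the target directory exists before writing.
path.parent.mkdir(parents=True, exist_ok=True)
save_checkpoint(path, payload)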
restored = load_checkpoint(path)
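One way a project can standardize the payload that passes through save_checkpoint and load_checkpoint is to describe it with a TypedDict and validate whatever comes back from disk. This is a sketch under assumptions, not part of trainkit: CheckpointPayload and validate_payload are hypothetical names, and the keys simply mirror the example payload above.

```python
from typing import Any, TypedDict, cast


class CheckpointPayload(TypedDict):
    """Hypothetical checkpoint schema mirroring the example payload above."""

    epoch: int
    model_state: dict[str, Any]
    optimizer_state: dict[str, Any]


def validate_payload(obj: object) -> CheckpointPayload:
    """Check that a loaded object matches CheckpointPayload (hypothetical helper)."""
    if not isinstance(obj, dict):
        raise TypeError(f"expected a dict payload, got {type(obj).__name__}")
    # __required_keys__ holds every key declared on a total TypedDict.
    missing = CheckpointPayload.__required_keys__.difference(obj)
    if missing:
        raise KeyError(f"checkpoint is missing keys: {sorted(missing)}")
    if not isinstance(obj["epoch"], int):
        raise TypeError("'epoch' must be an int")
    # The runtime checks above justify narrowing the static type.
    return cast(CheckpointPayload, obj)
```

In a training script this would wrap the loader, for example payload = validate_payload(load_checkpoint(path)), so a malformed file fails at load time with a clear message rather than later inside the training loop.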

Freezing helpers

Public functions:

  • set_requires_grad
  • freeze_stages

set_requires_grad is a simple utility for enabling or disabling training on a module.

freeze_stages is aimed at hierarchical setups. It currently supports freezing the universal, family, language, or joint stages of a model that exposes a CELMoE-style level_modules layout.

That makes it useful for:

  • staged pretraining
  • late unfreezing
  • expert specialization phases
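A staged-pretraining or late-unfreezing plan can be written down as data: for each stage, the epoch at which it becomes trainable. The helper below is a hypothetical sketch, not part of trainkit; it only reuses the stage names listed above and decides which stages should still be frozen at a given epoch. How those names are then handed to freeze_stages depends on its signature, which is not documented here.

```python
from collections.abc import Mapping

# Stage names taken from the freeze_stages description above.
STAGES: tuple[str, ...] = ("universal", "family", "language", "joint")


def stages_to_freeze(epoch: int, unfreeze_at: Mapping[str, int]) -> list[str]:
    """Return the stages that should still be frozen at `epoch`.

    `unfreeze_at` maps a stage name to the first epoch in which it trains.
    Stages that are not listed are trainable from the start.
    """
    unknown = set(unfreeze_at) - set(STAGES)
    if unknown:
        raise ValueError(f"unknown stages: {sorted(unknown)}")
    return [s for s in STAGES if s in unfreeze_at and epoch < unfreeze_at[s]]


# Example plan: language and joint stages train first; shared stages unfreeze later.
plan = {"universal": 5, "family": 3}
for epoch in (0, 3, 5):
    print(epoch, stages_to_freeze(epoch, plan))
```

Keeping the plan as a plain mapping makes it easy to store in a config file and log alongside the run.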

Multi-level sequence loss

The main structured loss class is MultiLevelSequenceLoss.

It expects a model output dictionary containing:

  • logits
  • universal_logits
  • family_logits
  • language_logits

It returns a LossBreakdown dataclass with:

  • total
  • final
  • universal
  • family
  • language

Example:

from chartoken import PAD
from trainkit import MultiLevelSequenceLoss

criterion = MultiLevelSequenceLoss(
    pad_id=PAD,
    final_weight=1.0,
    universal_weight=0.25,
    family_weight=0.35,
    language_weight=0.5,
)

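# output: the model's output dict (logits, universal_logits, family_logits, language_logits)
# target_ids: the target token IDs for the batch
breakdown = criterion(output, target_ids)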
print(breakdown.total)
print(breakdown.as_scalars())
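The description above does not spell out how total is formed from the four components. A plausible reading, and the assumption behind this sketch, is a weighted sum using the constructor weights; in the real class each component would presumably be a token-level cross-entropy that skips padding, as torch.nn.functional.cross_entropy(..., ignore_index=pad_id) does. LossParts, LevelWeights, weighted_total, and breakdown_scalars are hypothetical names, and plain floats stand in for tensors so the arithmetic is easy to check.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LossParts:
    """Per-level losses for one batch (plain floats stand in for tensors)."""

    final: float
    universal: float
    family: float
    language: float


@dataclass(frozen=True)
class LevelWeights:
    """Weights copied from the example above, not trainkit's own defaults."""

    final: float = 1.0
    universal: float = 0.25
    family: float = 0.35
    language: float = 0.5


def weighted_total(parts: LossParts, weights: LevelWeights) -> float:
    """Assumed aggregation: a weighted sum of the four level losses."""
    return (
        weights.final * parts.final
        + weights.universal * parts.universal
        + weights.family * parts.family
        + weights.language * parts.language
    )


def breakdown_scalars(parts: LossParts, weights: LevelWeights) -> dict[str, float]:
    """Flatten the breakdown into a loggable dict, similar in spirit to as_scalars()."""
    return {
        "total": weighted_total(parts, weights),
        "final": parts.final,
        "universal": parts.universal,
        "family": parts.family,
        "language": parts.language,
    }
```

With final=2.0, universal=4.0, family=2.0, and language=1.0, the weighted total under these weights is 2.0 + 1.0 + 0.7 + 0.5 = 4.2.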

Scheduler helper

create_cosine_schedule builds a linear-warmup cosine-decay schedule on top of torch.optim.lr_scheduler.LambdaLR.

Example:

from trainkit import create_cosine_schedule

# optimizer: an existing torch.optim.Optimizer
scheduler = create_cosine_schedule(
    optimizer,
    warmup_steps=500,
    total_steps=12000,
)
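The multiplier behind a linear-warmup cosine-decay LambdaLR can be written as a plain function of the step index. This sketch uses the standard formula and assumes the multiplier rises linearly from 0 to 1 over warmup_steps and then decays to 0 at total_steps; create_cosine_schedule may differ in details such as a minimum learning-rate floor. make_warmup_cosine_lambda is a hypothetical name.

```python
import math
from collections.abc import Callable


def make_warmup_cosine_lambda(warmup_steps: int, total_steps: int) -> Callable[[int], float]:
    """Return a step -> learning-rate multiplier suitable for LambdaLR."""

    def lr_lambda(step: int) -> float:
        if step < warmup_steps:
            # Linear warmup from 0.0 to 1.0.
            return step / max(1, warmup_steps)
        # Cosine decay from 1.0 to 0.0 over the remaining steps.
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        progress = min(1.0, max(0.0, progress))  # hold at 0.0 after total_steps
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    return lr_lambda
```

Passing the returned function as torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=...) scales each parameter group's base learning rate by this multiplier. With warmup_steps=500 and total_steps=12000, the multiplier is 0.5 at step 250, 1.0 at step 500, and back to 0.5 halfway through the decay, at step 6250.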

Memory planning utilities

Public functions:

  • format_bytes
  • detect_memory
  • estimate_model_memory
  • estimate_training_vram
  • suggest_batch_size
  • plan_training

Main dataclasses:

  • MemoryProfile
  • TrainingPlan

These helpers are intended for early experiment planning rather than exact runtime accounting. They give fast estimates for:

  • available system RAM
  • available accelerator memory
  • approximate model memory footprint
  • approximate training VRAM requirements
  • a suggested batch size

The implementation includes typed wrappers around CUDA, XPU, and MPS memory queries so static analyzers remain happy in strict mode.
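The typed-wrapper idea can be illustrated for CUDA alone: route the accelerator query through one function with a precise return type, so callers always see a small named tuple or None. This is a hypothetical sketch (query_cuda_memory and DeviceMemory are not trainkit names); it relies on torch.cuda.mem_get_info, which returns free and total device memory in bytes, and treats a missing PyTorch install or an unavailable CUDA device as "nothing to report". XPU and MPS would need their own branches.

```python
from __future__ import annotations

from typing import NamedTuple


class DeviceMemory(NamedTuple):
    free_bytes: int
    total_bytes: int


def query_cuda_memory(device: int = 0) -> DeviceMemory | None:
    """Typed wrapper around torch.cuda.mem_get_info (hypothetical helper)."""
    try:
        import torch
    except ImportError:
        # PyTorch is not installed, so there is no CUDA memory to report.
        return None
    if not torch.cuda.is_available():
        return None
    free, total = torch.cuda.mem_get_info(device)
    # Convert explicitly so callers always receive plain ints.
    return DeviceMemory(free_bytes=int(free), total_bytes=int(total))
```

Importing torch inside the function keeps the module importable without PyTorch, while a type checker still sees a single, precise return type at every call site.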

Example workflow

from trainkit import detect_memory, plan_training

profile = detect_memory()
# model: the torch.nn.Module you plan to train
plan = plan_training(
    model,
    sequence_length=96,
    batch_size=64,
    optimizer_states_factor=2.0,
)

print(profile)
print(plan)
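The back-of-the-envelope arithmetic behind such estimates can be sketched as follows. The formula is an assumption, not trainkit's documented one: weights, gradients, and optimizer state each cost num_params * bytes_per_param, with optimizer state scaled by optimizer_states_factor (2.0 suits Adam-style optimizers, which keep two moment buffers per parameter). Activation memory is passed in per sample because it depends on the architecture and sequence length. format_size, estimate_training_bytes, and suggest_batch are hypothetical names.

```python
def format_size(num_bytes: float) -> str:
    """Human-readable size using binary (1024-based) units."""
    units = ("B", "KiB", "MiB", "GiB", "TiB")
    value = float(num_bytes)
    index = 0
    while abs(value) >= 1024 and index < len(units) - 1:
        value /= 1024
        index += 1
    return f"{value:.2f} {units[index]}"


def estimate_training_bytes(
    num_params: int,
    bytes_per_param: int = 4,
    optimizer_states_factor: float = 2.0,
) -> int:
    """Weights + gradients + optimizer state, excluding activations (assumed model)."""
    weights = num_params * bytes_per_param
    gradients = weights
    optimizer_state = optimizer_states_factor * weights
    return int(weights + gradients + optimizer_state)


def suggest_batch(
    available_bytes: int,
    fixed_bytes: int,
    activation_bytes_per_sample: int,
    safety_fraction: float = 0.9,
) -> int:
    """Largest power-of-two batch whose activations fit the remaining budget.

    Returns 0 when not even one sample fits.
    """
    budget = available_bytes * safety_fraction - fixed_bytes
    fits = int(budget // activation_bytes_per_sample) if budget > 0 else 0
    if fits < 1:
        return 0
    batch = 1
    while batch * 2 <= fits:
        batch *= 2
    return batch
```

For a 1M-parameter fp32 model this gives 16,000,000 bytes (about 15.26 MiB) before activations; with 8 GiB available and an illustrative 50 MB of activations per sample, the sketch suggests a batch of 128. Rounding down to a power of two is a design choice that leaves headroom for allocator fragmentation.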

Scope boundaries

trainkit-vp does not define:

  • your model architecture
  • your dataset format
  • your optimizer choice
  • distributed training logic

Instead, it provides small, reliable pieces that application packages can compose as they see fit.

Why publish it separately

Separating training utilities into their own package keeps:

  • model packages smaller
  • checkpoint format decisions reusable
  • type contracts stable across projects
  • experimentation code less duplicated

That matters in this repository because its libraries are published independently and should remain usable outside morphoformer.

Download files

Source Distribution

trainkit_vp-2.3.3.tar.gz (11.1 kB)

Uploaded Source

Built Distribution

trainkit_vp-2.3.3-py3-none-any.whl (10.0 kB)

Uploaded Python 3

File details

Details for the file trainkit_vp-2.3.3.tar.gz.

File metadata

  • Download URL: trainkit_vp-2.3.3.tar.gz
  • Upload date:
  • Size: 11.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.2

File hashes

Hashes for trainkit_vp-2.3.3.tar.gz
Algorithm Hash digest
SHA256 f54df10c17834aac7687f33bb4dfec468f7592248eacb2978f1c44a658799933
MD5 9a6a5bd1716957ca887f3d9ebc664c45
BLAKE2b-256 7286e6bb429eda594e5fb9c695650d53de15ae8695ac2d1d68025b15e5187651

File details

Details for the file trainkit_vp-2.3.3-py3-none-any.whl.

File metadata

  • Download URL: trainkit_vp-2.3.3-py3-none-any.whl
  • Upload date:
  • Size: 10.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.2

File hashes

Hashes for trainkit_vp-2.3.3-py3-none-any.whl
Algorithm Hash digest
SHA256 22766e5c5b188fff63c858b88f6ba16854979914ebcd52241924a6628671ec9e
MD5 5188ccbe6594ed18dacec451dc584a80
BLAKE2b-256 28c7e1f4ff069dbf16919e7c06388ad40bdb5fe3283d420aebe031b4ae6942d6
