Typed PyTorch training utilities for checkpoints, scheduling, freezing, memory planning, and structured losses.

trainkit-vp

trainkit-vp collects typed training utilities for PyTorch projects: checkpointing, staged freezing, loss aggregation, scheduler setup, and memory planning.

PyPI package name:

pip install trainkit-vp

Import name:

import trainkit

The package is intentionally generic. It is not tied to a single model family, and it is useful both in research code and in reusable application packages.

What is included

trainkit-vp currently provides:

  • checkpoint save/load helpers
  • staged freezing helpers for hierarchical models
  • a structured multi-level sequence loss
  • cosine warmup scheduling
  • accelerator-aware memory estimation and batch-size planning

Installation

Requirements:

  • Python >=3.14
  • PyTorch >=2.0

Install from PyPI:

pip install trainkit-vp

Checkpoint helpers

Public functions:

  • save_checkpoint
  • load_checkpoint

These wrap torch.save and torch.load behind a small, typed surface that downstream packages can standardize on.

Typical usage:

from pathlib import Path

from trainkit import load_checkpoint, save_checkpoint

payload = {
    "epoch": 3,
    "model_state": model.state_dict(),
    "optimizer_state": optimizer.state_dict(),
}

path = Path("artifacts/best.pt")
save_checkpoint(path, payload)
restored = load_checkpoint(path)
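Because the helpers are thin wrappers, the round-trip contract is easy to illustrate with a stand-in. The sketch below uses plain pickle in place of torch.save and torch.load (torch.save uses a pickle-based format); it is an illustration of the contract, not the package's implementation:

```python
import pickle
import tempfile
from pathlib import Path


def save_checkpoint(path, payload):
    # Stand-in for trainkit's helper: the real one presumably calls
    # torch.save. Plain pickle shows the same round-trip contract
    # without requiring PyTorch.
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_bytes(pickle.dumps(payload))


def load_checkpoint(path):
    return pickle.loads(Path(path).read_bytes())


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "artifacts" / "best.pt"
    save_checkpoint(target, {"epoch": 3})
    restored = load_checkpoint(target)

print(restored)  # {'epoch': 3}
```

Note that save_checkpoint creates missing parent directories, which is the kind of detail a shared wrapper keeps consistent across projects.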

Freezing helpers

Public functions:

  • set_requires_grad
  • freeze_stages

set_requires_grad is a simple utility for enabling or disabling training on a module.

freeze_stages is aimed at hierarchical setups. In the current code it knows how to freeze universal, family, language, or joint stages for a model exposing a CELMoE-style level_modules layout.

That makes it useful for:

  • staged pretraining
  • late unfreezing
  • expert specialization phases
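As a sketch of the underlying mechanics (not trainkit's actual code), set_requires_grad amounts to flipping requires_grad on every parameter of a module. A minimal stand-in for torch.nn.Module makes that runnable without PyTorch:

```python
class FakeParam:
    # Stand-in for torch.nn.Parameter: only the requires_grad flag matters here.
    def __init__(self):
        self.requires_grad = True


class FakeModule:
    # Stand-in for torch.nn.Module exposing a parameters() iterator.
    def __init__(self, n):
        self._params = [FakeParam() for _ in range(n)]

    def parameters(self):
        return iter(self._params)


def set_requires_grad(module, flag):
    # Sketch of what a helper like this plausibly does: enable or
    # disable gradient tracking on every parameter of the module.
    for p in module.parameters():
        p.requires_grad = flag


encoder = FakeModule(4)
set_requires_grad(encoder, False)  # freeze the module
print(all(not p.requires_grad for p in encoder._params))  # True
```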

Multi-level sequence loss

The main structured loss class is MultiLevelSequenceLoss.

It expects a model output dictionary containing:

  • logits
  • universal_logits
  • family_logits
  • language_logits

It returns a LossBreakdown dataclass with:

  • total
  • final
  • universal
  • family
  • language

Example:

from chartoken import PAD
from trainkit import MultiLevelSequenceLoss

criterion = MultiLevelSequenceLoss(
    pad_id=PAD,
    final_weight=1.0,
    universal_weight=0.25,
    family_weight=0.35,
    language_weight=0.5,
)

breakdown = criterion(output, target_ids)
print(breakdown.total)
print(breakdown.as_scalars())
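The total is presumably a weighted sum of the per-level losses, using the weights passed to the constructor. As plain arithmetic (the loss values below are made-up example numbers, not real outputs):

```python
# Weights as configured in the example above; loss values are invented
# for illustration only.
weights = {"final": 1.0, "universal": 0.25, "family": 0.35, "language": 0.5}
losses = {"final": 2.0, "universal": 3.0, "family": 2.5, "language": 1.8}

total = sum(weights[k] * losses[k] for k in weights)
print(round(total, 3))  # 1.0*2.0 + 0.25*3.0 + 0.35*2.5 + 0.5*1.8 = 4.525
```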

Scheduler helper

create_cosine_schedule builds a linear-warmup cosine-decay schedule on top of torch.optim.lr_scheduler.LambdaLR.

Example:

from trainkit import create_cosine_schedule

scheduler = create_cosine_schedule(
    optimizer,
    warmup_steps=500,
    total_steps=12000,
)
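The multiplier such a schedule feeds to LambdaLR can be sketched in plain Python. The exact shape is an assumption (trainkit may clamp or floor differently), but linear-warmup cosine-decay conventionally looks like this:

```python
import math


def cosine_warmup_factor(step, warmup_steps, total_steps):
    # Linear warmup from 0 to 1 over warmup_steps, then cosine decay
    # from 1 down to 0 over the remaining steps. This is the kind of
    # lambda a LambdaLR-based helper would multiply the base LR by.
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))


print(cosine_warmup_factor(250, 500, 12000))    # 0.5  (halfway through warmup)
print(cosine_warmup_factor(500, 500, 12000))    # 1.0  (warmup complete)
print(cosine_warmup_factor(12000, 500, 12000))  # 0.0  (fully decayed)
```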

Memory planning utilities

Public functions:

  • format_bytes
  • detect_memory
  • estimate_model_memory
  • estimate_training_vram
  • suggest_batch_size
  • plan_training

Main dataclasses:

  • MemoryProfile
  • TrainingPlan

These helpers are intended for early experiment planning rather than exact runtime accounting. They give fast estimates for:

  • available system RAM
  • available accelerator memory
  • approximate model memory footprint
  • approximate training VRAM requirements
  • a suggested batch size

The implementation includes typed wrappers around CUDA, XPU, and MPS memory queries so static analyzers remain happy in strict mode.
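For a feel of the smallest of these helpers: format_bytes plausibly renders a byte count in human-readable binary units. A minimal sketch, not the package's exact output format:

```python
def format_bytes(num_bytes):
    # Sketch: walk up the binary unit ladder until the value fits,
    # then render with one decimal place.
    units = ["B", "KiB", "MiB", "GiB", "TiB"]
    value = float(num_bytes)
    for unit in units:
        if value < 1024.0 or unit == units[-1]:
            return f"{value:.1f} {unit}"
        value /= 1024.0


print(format_bytes(3 * 1024**3))  # 3.0 GiB
```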

Example workflow

from trainkit import detect_memory, plan_training

profile = detect_memory()
plan = plan_training(
    model,
    sequence_length=96,
    batch_size=64,
    optimizer_states_factor=2.0,
)

print(profile)
print(plan)
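The arithmetic behind a training-VRAM estimate is roughly weights + gradients + optimizer states (+ an activation allowance). A back-of-envelope sketch with hypothetical factors, not the package's exact formula:

```python
def estimate_training_vram(n_params, bytes_per_param=4,
                           optimizer_states_factor=2.0,
                           activation_factor=1.0):
    # Rough accounting: fp32 weights, same-size gradients, optimizer
    # states scaled by a factor (Adam keeps two moments, hence ~2.0),
    # plus a crude activation allowance. All factors are assumptions.
    weights = n_params * bytes_per_param
    gradients = weights
    optimizer_states = weights * optimizer_states_factor
    activations = weights * activation_factor
    return weights + gradients + optimizer_states + activations


# A 100M-parameter fp32 model with Adam-style optimizer states:
print(round(estimate_training_vram(100_000_000) / 1024**3, 2))  # 1.86 (GiB)
```

Real estimators refine the activation term with sequence length and batch size, which is why plan_training takes both.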

Scope boundaries

trainkit-vp does not define:

  • your model architecture
  • your dataset format
  • your optimizer choice
  • distributed training logic

Instead, it provides small, reliable pieces that application packages can compose however they want.

Why publish it separately

Separating training utilities into their own package keeps:

  • model packages smaller
  • checkpoint format decisions reusable
  • type contracts stable across projects
  • experimentation code less duplicated

In this repository that matters because libraries are published independently and should remain usable outside morphoformer.

Download files

Source distribution: trainkit_vp-2.4.0.tar.gz (11.2 kB)

Built distribution: trainkit_vp-2.4.0-py3-none-any.whl (10.1 kB)

