
Typed PyTorch training utilities for checkpoints, scheduling, freezing, memory planning, and structured losses.


trainkit-vp

trainkit-vp collects typed training utilities for PyTorch projects: checkpointing, staged freezing, loss aggregation, scheduler setup, and memory planning.

PyPI package name:

pip install trainkit-vp

Import name:

import trainkit

The package is intentionally generic. It is not tied to a single model family, and it is useful both in research code and in reusable application packages.

What is included

trainkit-vp currently provides:

  • checkpoint save/load helpers
  • staged freezing helpers for hierarchical models
  • a structured multi-level sequence loss
  • cosine warmup scheduling
  • accelerator-aware memory estimation and batch-size planning

Installation

Requirements:

  • Python >=3.14
  • PyTorch >=2.0

Install from PyPI:

pip install trainkit-vp

Checkpoint helpers

Public functions:

  • save_checkpoint
  • load_checkpoint

These wrap torch.save and torch.load behind a small, typed surface that downstream packages can standardize on.

Typical usage:

from pathlib import Path

from trainkit import load_checkpoint, save_checkpoint

payload = {
    "epoch": 3,
    "model_state": model.state_dict(),
    "optimizer_state": optimizer.state_dict(),
}

path = Path("artifacts/best.pt")
save_checkpoint(path, payload)
restored = load_checkpoint(path)
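Under the hood the pair behaves roughly like the following sketch (an assumption for illustration, not the actual implementation):

```python
from pathlib import Path
from typing import Any

import torch


def save_checkpoint(path: Path, payload: dict[str, Any]) -> None:
    # Hypothetical sketch: create parent directories, then delegate to torch.save.
    path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(payload, path)


def load_checkpoint(path: Path) -> dict[str, Any]:
    # Load onto CPU so checkpoints restore on machines
    # without the original accelerator.
    return torch.load(path, map_location="cpu")
```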

Freezing helpers

Public functions:

  • set_requires_grad
  • freeze_stages

set_requires_grad is a simple utility for enabling or disabling gradient computation for all parameters of a module.

freeze_stages is aimed at hierarchical setups. In the current code it knows how to freeze universal, family, language, or joint stages for a model exposing a CELMoE-style level_modules layout.

That makes it useful for:

  • staged pretraining
  • late unfreezing
  • expert specialization phases
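A minimal sketch of the set_requires_grad behaviour (the signature is assumed; the packaged version may differ):

```python
import torch.nn as nn


def set_requires_grad(module: nn.Module, flag: bool) -> None:
    # Flip gradient tracking for every parameter under the module.
    for param in module.parameters():
        param.requires_grad = flag


# Freeze an encoder-style block while leaving the head trainable.
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
set_requires_grad(model[0], False)
```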

Multi-level sequence loss

The main structured loss class is MultiLevelSequenceLoss.

It expects a model output dictionary containing:

  • logits
  • universal_logits
  • family_logits
  • language_logits

It returns a LossBreakdown dataclass with:

  • total
  • final
  • universal
  • family
  • language

Example:

from chartoken import PAD
from trainkit import MultiLevelSequenceLoss

criterion = MultiLevelSequenceLoss(
    pad_id=PAD,
    final_weight=1.0,
    universal_weight=0.25,
    family_weight=0.35,
    language_weight=0.5,
)

breakdown = criterion(output, target_ids)
print(breakdown.total)
print(breakdown.as_scalars())
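Conceptually, the weighted total combines the per-level losses linearly. The sketch below illustrates that aggregation with plain floats standing in for scalar tensors, using the weights from the example above; the real class computes each term from logits with padding-masked cross-entropy:

```python
from dataclasses import dataclass


@dataclass
class LossBreakdown:
    # Field names mirror the dataclass described above.
    final: float
    universal: float
    family: float
    language: float

    @property
    def total(self) -> float:
        # Assumed aggregation: a weighted sum of the per-level terms.
        return (1.0 * self.final
                + 0.25 * self.universal
                + 0.35 * self.family
                + 0.5 * self.language)
```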

Scheduler helper

create_cosine_schedule builds a linear-warmup cosine-decay schedule on top of torch.optim.lr_scheduler.LambdaLR.

Example:

from trainkit import create_cosine_schedule

scheduler = create_cosine_schedule(
    optimizer,
    warmup_steps=500,
    total_steps=12000,
)
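The schedule's learning-rate multiplier typically follows this shape (a common linear-warmup, cosine-decay formulation; the package's exact curve may differ):

```python
import math


def cosine_factor(step: int, warmup_steps: int, total_steps: int) -> float:
    # Linear ramp from 0 to 1 during warmup...
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    # ...then cosine decay from 1 down to 0 over the remaining steps.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))
```

LambdaLR multiplies the optimizer's base learning rate by this factor at every step.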

Memory planning utilities

Public functions:

  • format_bytes
  • detect_memory
  • estimate_model_memory
  • estimate_training_vram
  • suggest_batch_size
  • plan_training

Main dataclasses:

  • MemoryProfile
  • TrainingPlan

These helpers are intended for early experiment planning rather than exact runtime accounting. They give fast estimates for:

  • available system RAM
  • available accelerator memory
  • approximate model memory footprint
  • approximate training VRAM requirements
  • a suggested batch size

The implementation includes typed wrappers around CUDA, XPU, and MPS memory queries so static analyzers remain happy in strict mode.

Example workflow

from trainkit import detect_memory, plan_training

profile = detect_memory()
plan = plan_training(
    model,
    sequence_length=96,
    batch_size=64,
    optimizer_states_factor=2.0,
)

print(profile)
print(plan)

Scope boundaries

trainkit-vp does not define:

  • your model architecture
  • your dataset format
  • your optimizer choice
  • distributed training logic

Instead it provides small, reliable pieces that application packages can compose however they want.

Why publish it separately

Separating training utilities into their own package keeps:

  • model packages smaller
  • checkpoint format decisions reusable
  • type contracts stable across projects
  • experimentation code less duplicated

In this repository that matters because libraries are published independently and should remain usable outside morphoformer.



Download files


Source Distribution

trainkit_vp-2.1.0.tar.gz (9.7 kB)


Built Distribution


trainkit_vp-2.1.0-py3-none-any.whl (8.6 kB)


File details

Details for the file trainkit_vp-2.1.0.tar.gz.

File metadata

  • Download URL: trainkit_vp-2.1.0.tar.gz
  • Upload date:
  • Size: 9.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.2

File hashes

Hashes for trainkit_vp-2.1.0.tar.gz:

  • SHA256: 06042da8c137ff480a1964e34a61235c0bfa1ee80196a77dfa8da9fe07400d9b
  • MD5: d64dfcc8e6f4197a62ab302238d77010
  • BLAKE2b-256: fcb6d7da0ce70feb458c10202d60812e35a2a9d109795e05b1eb83525d85d4e0


File details

Details for the file trainkit_vp-2.1.0-py3-none-any.whl.

File metadata

  • Download URL: trainkit_vp-2.1.0-py3-none-any.whl
  • Upload date:
  • Size: 8.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.2

File hashes

Hashes for trainkit_vp-2.1.0-py3-none-any.whl:

  • SHA256: d85fe69e15080a7e5111723f21c67c16994f777346ddc3abf0745e095fef77a3
  • MD5: ee2b3076687f0b8ad95e10a8ae38e675
  • BLAKE2b-256: 80091d8ff6b82f485c02b557cafa895c8ed8652f832a88972a2625f31823b640

