Typed PyTorch training utilities for checkpoints, scheduling, freezing, memory planning, and structured losses.

trainkit-vp

trainkit-vp collects typed training utilities for PyTorch projects: checkpointing, staged freezing, loss aggregation, scheduler setup, and memory planning.

PyPI package name:

pip install trainkit-vp

Import name:

import trainkit

The package is intentionally generic. It is not tied to a single model family, and it is useful both in research code and in reusable application packages.

What is included

trainkit-vp currently provides:

  • checkpoint save/load helpers
  • staged freezing helpers for hierarchical models
  • a structured multi-level sequence loss
  • cosine warmup scheduling
  • accelerator-aware memory estimation and batch-size planning

Installation

Requirements:

  • Python >=3.14
  • PyTorch >=2.0

Install from PyPI:

pip install trainkit-vp

Checkpoint helpers

Public functions:

  • save_checkpoint
  • load_checkpoint

These wrap torch.save and torch.load behind a small, typed surface that downstream packages can standardize on.
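
A minimal sketch of what such wrappers typically look like (an assumed implementation for illustration, not the library's actual code):

from pathlib import Path
from typing import Any

import torch

def save_checkpoint(path: Path, payload: dict[str, Any]) -> None:
    # Create parent directories so a fresh run can save immediately.
    path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(payload, path)

def load_checkpoint(path: Path) -> dict[str, Any]:
    # map_location="cpu" keeps loading device-agnostic; weights_only=False
    # because payloads may carry arbitrary metadata such as epoch counters.
    return torch.load(path, map_location="cpu", weights_only=False)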

Typical usage:

from pathlib import Path

from trainkit import load_checkpoint, save_checkpoint

payload = {
    "epoch": 3,
    "model_state": model.state_dict(),
    "optimizer_state": optimizer.state_dict(),
}

path = Path("artifacts/best.pt")
save_checkpoint(path, payload)
restored = load_checkpoint(path)

Freezing helpers

Public functions:

  • set_requires_grad
  • freeze_stages

set_requires_grad is a small utility for enabling or disabling gradient updates on a module's parameters.

freeze_stages is aimed at hierarchical setups. In the current code it can freeze the universal, family, language, or joint stages of a model that exposes a CELMoE-style level_modules layout; a usage sketch follows the list below.

That makes it useful for:

  • staged pretraining
  • late unfreezing
  • expert specialization phases
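
A hedged usage sketch, assuming the call signatures suggested by the names above; model and its embedding attribute are placeholders:

from trainkit import freeze_stages, set_requires_grad

# Disable gradient updates for an entire submodule during a warmup phase.
set_requires_grad(model.embedding, False)

# Assumed signature: freeze the named stage of a CELMoE-style model
# (the available stages are universal, family, language, and joint).
freeze_stages(model, "universal")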

Multi-level sequence loss

The main structured loss class is MultiLevelSequenceLoss.

It expects a model output dictionary containing:

  • logits
  • universal_logits
  • family_logits
  • language_logits

It returns a LossBreakdown dataclass with:

  • total
  • final
  • universal
  • family
  • language
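
Presumably total is the weighted sum of the per-level terms; a hedged sketch of that combination, with weight names matching the constructor arguments shown below:

# Assumed reduction; the exact formula lives inside the library.
total = (
    final_weight * final
    + universal_weight * universal
    + family_weight * family
    + language_weight * language
)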

Example:

from chartoken import PAD
from trainkit import MultiLevelSequenceLoss

criterion = MultiLevelSequenceLoss(
    pad_id=PAD,
    final_weight=1.0,
    universal_weight=0.25,
    family_weight=0.35,
    language_weight=0.5,
)

breakdown = criterion(output, target_ids)
print(breakdown.total)
print(breakdown.as_scalars())

Scheduler helper

create_cosine_schedule builds a linear-warmup cosine-decay schedule on top of torch.optim.lr_scheduler.LambdaLR.

Example:

from trainkit import create_cosine_schedule

scheduler = create_cosine_schedule(
    optimizer,
    warmup_steps=500,
    total_steps=12000,
)
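
A minimal sketch of the multiplier such a helper typically hands to LambdaLR (an assumed shape, based on the linear-warmup cosine-decay description):

import math

def lr_lambda(step: int, warmup_steps: int = 500, total_steps: int = 12000) -> float:
    # Linear warmup from 0 to 1 over the first warmup_steps steps.
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    # Cosine decay from 1 to 0 over the remaining steps.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))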

Memory planning utilities

Public functions:

  • format_bytes
  • detect_memory
  • estimate_model_memory
  • estimate_training_vram
  • suggest_batch_size
  • plan_training

Main dataclasses:

  • MemoryProfile
  • TrainingPlan

These helpers are intended for early experiment planning rather than exact runtime accounting. They give fast estimates for the following (a back-of-envelope sketch follows the list):

  • available system RAM
  • available accelerator memory
  • approximate model memory footprint
  • approximate training VRAM requirements
  • a suggested batch size
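
As a back-of-envelope illustration of the kind of arithmetic involved (an assumed model, not the library's exact accounting):

# Rough fp32 accounting: 4 bytes per parameter for weights and again
# for gradients, plus optimizer state scaled by a user-supplied factor.
n_params = sum(p.numel() for p in model.parameters())
weights_bytes = 4 * n_params
gradient_bytes = 4 * n_params
optimizer_bytes = 2.0 * weights_bytes  # e.g. Adam's two moment buffers
approx_vram = weights_bytes + gradient_bytes + optimizer_bytes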

The implementation includes typed wrappers around CUDA, XPU, and MPS memory queries so static analyzers remain happy in strict mode.
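
A sketch of the kind of device query such wrappers cover (simplified; the real helpers add typing shims, XPU support, and fallbacks):

import torch

def accelerator_memory_bytes() -> tuple[int, int] | None:
    """Return (free, total) accelerator memory in bytes, if any."""
    if torch.cuda.is_available():
        return torch.cuda.mem_get_info()
    if torch.backends.mps.is_available():
        # MPS (newer PyTorch) exposes a recommended working-set ceiling
        # rather than a free/total pair.
        total = torch.mps.recommended_max_memory()
        return (total - torch.mps.driver_allocated_memory(), total)
    return None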

Example workflow

from trainkit import detect_memory, plan_training

profile = detect_memory()
plan = plan_training(
    model,
    sequence_length=96,
    batch_size=64,
    optimizer_states_factor=2.0,
)

print(profile)
print(plan)

Scope boundaries

trainkit-vp does not define:

  • your model architecture
  • your dataset format
  • your optimizer choice
  • distributed training logic

Instead it provides small, reliable pieces that application packages can compose however they want.

Why publish it separately

Separating training utilities into their own package keeps:

  • model packages smaller
  • checkpoint format decisions reusable
  • type contracts stable across projects
  • experimentation code less duplicated

In this repository that matters because libraries are published independently and should remain usable outside morphoformer.
