
trainkit-vp

trainkit-vp collects typed training utilities for PyTorch projects: checkpointing, staged freezing, loss aggregation, scheduler setup, and memory planning.

PyPI package name:

pip install trainkit-vp

Import name:

import trainkit

The package is intentionally generic. It is not tied to a single model family, and it is useful both in research code and in reusable application packages.

What is included

trainkit-vp currently provides:

  • checkpoint save/load helpers
  • staged freezing helpers for hierarchical models
  • a structured multi-level sequence loss
  • cosine warmup scheduling
  • accelerator-aware memory estimation and batch-size planning

Installation

Requirements:

  • Python >=3.14
  • PyTorch >=2.0

Install from PyPI:

pip install trainkit-vp

Checkpoint helpers

Public functions:

  • save_checkpoint
  • load_checkpoint

These wrap torch.save and torch.load behind a small, typed surface that downstream packages can standardize around.

Typical usage:

from pathlib import Path

from trainkit import load_checkpoint, save_checkpoint

payload = {
    "epoch": 3,
    "model_state": model.state_dict(),
    "optimizer_state": optimizer.state_dict(),
}

path = Path("artifacts/best.pt")
save_checkpoint(path, payload)
restored = load_checkpoint(path)
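
Restoring the live objects afterwards is the usual PyTorch pattern (a sketch, assuming load_checkpoint returns the saved mapping unchanged):

model.load_state_dict(restored["model_state"])
optimizer.load_state_dict(restored["optimizer_state"])
start_epoch = restored["epoch"] + 1  # resume after the saved epoch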

Freezing helpers

Public functions:

  • set_requires_grad
  • freeze_stages

set_requires_grad is a simple utility for enabling or disabling gradient computation across a module's parameters.

freeze_stages is aimed at hierarchical setups. In the current code it knows how to freeze universal, family, language, or joint stages for a model exposing a CELMoE-style level_modules layout.

That makes it useful for:

  • staged pretraining
  • late unfreezing
  • expert specialization phases
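
A minimal sketch of how these helpers might slot into such a schedule (the stage names follow the level_modules layout described above; the exact call signatures are assumptions, so check the actual API before use):

from trainkit import freeze_stages, set_requires_grad

# Staged pretraining: keep earlier levels frozen while training later ones.
freeze_stages(model, "universal")  # assumed call shape: (module, stage name)
freeze_stages(model, "family")

# Late unfreezing: re-enable gradients everywhere for a final joint phase.
set_requires_grad(model, True)     # assumed call shape: (module, requires_grad)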

Multi-level sequence loss

The main structured loss class is MultiLevelSequenceLoss.

It expects a model output dictionary containing:

  • logits
  • universal_logits
  • family_logits
  • language_logits

It returns a LossBreakdown dataclass with:

  • total
  • final
  • universal
  • family
  • language

Example:

from chartoken import PAD
from trainkit import MultiLevelSequenceLoss

criterion = MultiLevelSequenceLoss(
    pad_id=PAD,
    final_weight=1.0,
    universal_weight=0.25,
    family_weight=0.35,
    language_weight=0.5,
)

breakdown = criterion(output, target_ids)
print(breakdown.total)
print(breakdown.as_scalars())
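
In a training step, the total field is the tensor to backpropagate, while the per-level fields are convenient for logging (a sketch, assuming total is a scalar torch.Tensor):

breakdown.total.backward()  # weighted sum of the final and per-level terms
optimizer.step()
optimizer.zero_grad()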

Scheduler helper

create_cosine_schedule builds a linear-warmup cosine-decay schedule on top of torch.optim.lr_scheduler.LambdaLR.

Example:

from trainkit import create_cosine_schedule

scheduler = create_cosine_schedule(
    optimizer,
    warmup_steps=500,
    total_steps=12000,
)
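
The schedule shape corresponds to the usual LambdaLR multiplier: a linear ramp over the warmup steps, then a half-cosine decay over the remainder. A sketch of the conventional formulation (not necessarily the exact implementation):

import math

def lr_multiplier(step: int, warmup_steps: int, total_steps: int) -> float:
    if step < warmup_steps:
        # Linear warmup from 0 up to the base learning rate.
        return step / max(1, warmup_steps)
    # Cosine decay from the base learning rate toward 0.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))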

Memory planning utilities

Public functions:

  • format_bytes
  • detect_memory
  • estimate_model_memory
  • estimate_training_vram
  • suggest_batch_size
  • plan_training

Main dataclasses:

  • MemoryProfile
  • TrainingPlan

These helpers are intended for early experiment planning rather than exact runtime accounting. They give fast estimates for:

  • available system RAM
  • available accelerator memory
  • approximate model memory footprint
  • approximate training VRAM requirements
  • a suggested batch size

The implementation includes typed wrappers around CUDA, XPU, and MPS memory queries so static analyzers remain happy in strict mode.

Example workflow

from trainkit import detect_memory, plan_training

profile = detect_memory()
plan = plan_training(
    model,
    sequence_length=96,
    batch_size=64,
    optimizer_states_factor=2.0,
)

print(profile)
print(plan)
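
The finer-grained helpers can also be called directly when only one number is needed (a sketch; keyword names beyond the function names listed above are assumptions):

from trainkit import estimate_model_memory, format_bytes, suggest_batch_size

model_bytes = estimate_model_memory(model)             # assumed: takes the module
print(format_bytes(model_bytes))                       # human-readable size string

batch = suggest_batch_size(model, sequence_length=96)  # assumed keyword
print(batch)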

Scope boundaries

trainkit-vp does not define:

  • your model architecture
  • your dataset format
  • your optimizer choice
  • distributed training logic

Instead, it provides small, reliable pieces that application packages can compose however they want.

Why publish it separately

Separating training utilities into their own package keeps:

  • model packages smaller
  • checkpoint format decisions reusable
  • type contracts stable across projects
  • experimentation code less duplicated

That matters in this repository because its libraries are published independently and should remain usable outside morphoformer.
