
CPU-first pure-Rust supervised trainer for Selective State Space Models with Hyperspherical Prototype Networks.


RUST Trainer


A CPU-first Rust training package implementing a Mamba SSM + Hyperspherical Prototype Network (HPN) architecture.

This is a concrete, working reference implementation — not a blank framework. The model is a stack of Mamba selective state-space layers with an HPN cosine-distance output head and learnable prototype matrix. Teams can use it as-is or fork and replace the layer/loss internals for their own architecture.

It works as both:

  • a library dependency in your Rust application
  • a ready-to-run CLI training binary

No Python runtime is required.


What this package gives you

  • End-to-end train loop binary
  • Library API for embedding custom model/training logic
  • Serializable optimizer state (AdamW)
  • Resume-safe checkpoints (model + optimizer + step)
  • JSONL metrics logging
  • Configurable layer expansion and freezing
  • Deterministic parity probe for save/load correctness
  • SIMD math kernels for high-throughput CPU training
  • Validation cadence + best-checkpoint tracking
  • Early stopping support
  • Gradient clipping controls
  • LR warmup + cosine decay controls
  • Non-finite update guardrails
  • Sharded streaming dataset support
  • Packed sequence batching on shard streams
  • Multi-worker sharded prefetch ingestion
  • Run-state resume for stream cursors
  • Atomic versioned checkpoints
  • Cross-framework parity harness (Rust vs Python/JAX)
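Gradient clipping and the non-finite update guardrails listed above fit naturally into a single pre-update check. The sketch below is an assumption about how such a check can work, modeled on the documented semantics of the --grad-clip-norm flag (0 disables clipping) and the --fail-on-non-finite flag (panic instead of skipping). UpdateOutcome, global_norm, and guard_gradients are hypothetical names, not this crate's API, and the real trainer may also inspect the loss or parameters.

```rust
/// Result of a guarded update (hypothetical type, not the crate's API).
#[derive(Debug, PartialEq)]
enum UpdateOutcome {
    /// Update may proceed; carries the global norm measured before clipping.
    Applied { pre_clip_norm: f32 },
    /// Non-finite gradients were detected and the update was dropped.
    Skipped,
}

/// Global L2 norm over every parameter group's gradient.
fn global_norm(grads: &[Vec<f32>]) -> f32 {
    grads.iter().flatten().map(|x| x * x).sum::<f32>().sqrt()
}

/// Scale gradients down to `max_norm` (0 disables clipping) and refuse
/// non-finite updates, either by skipping them or by panicking.
fn guard_gradients(grads: &mut [Vec<f32>], max_norm: f32, fail_on_non_finite: bool) -> UpdateOutcome {
    let norm = global_norm(grads);
    if !norm.is_finite() {
        // Any NaN or Inf gradient entry makes the global norm non-finite.
        if fail_on_non_finite {
            panic!("non-finite gradient norm: {norm}");
        }
        return UpdateOutcome::Skipped;
    }
    if max_norm > 0.0 && norm > max_norm {
        // Uniform rescaling preserves the gradient direction.
        let scale = max_norm / norm;
        for x in grads.iter_mut().flatten() {
            *x *= scale;
        }
    }
    UpdateOutcome::Applied { pre_clip_norm: norm }
}

fn main() {
    // Global norm is sqrt(3^2 + 4^2) = 5, so clipping to 1.0 scales every entry by 0.2.
    let mut grads = vec![vec![3.0_f32, 0.0], vec![4.0]];
    let outcome = guard_gradients(&mut grads, 1.0, false);
    println!("{outcome:?}, clipped grads: {grads:?}");
}
```

Clipping by the global norm (rather than per tensor) keeps the relative scale between parameter groups intact.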

Production readiness status

Current state: a production candidate for single-node CPU training, with production-critical data ingestion and parity validation implemented.

What is already robust:

  • deterministic resume behavior (checkpoint + optimizer state + step)
  • deterministic resume behavior for streaming shard cursors (run_state.json)
  • configurable expansion/freeze controls for staged training
  • validated SIMD and backward kernels with scalar parity probes
  • CI, release, and crate packaging automation

Operational note:

  • the cross-framework parity runner (cross_framework_parity) requires jax to be installed in the active Python environment

Detailed roadmap and release milestones are tracked in roadmap.md.


Design philosophy

  • Keep trainer internals explicit and hackable.
  • Favor reproducible runs and resumability.
  • Make the package easy to fork and specialize for custom architectures.
  • Keep data ingestion simple at first (integer token files), then scale to streaming pipelines.

Repository layout

src/
  lib.rs              - crate root and public exports
  generic_trainer.rs  - full trainer state, train step, checkpoint/resume
  trainer.rs          - parameter and expansion/freezing config types
  optim.rs            - AdamW optimizer primitives
  nn.rs               - layer norm and output-loss helpers
  simd_ops.rs         - SIMD kernels used by the model path
  layer.rs            - cached layer forward/backward helpers
  stack.rs            - stack-level supervised step helpers
src/bin/
  train_generic.rs    - main CLI trainer
  trainer_parity.rs   - deterministic parity/resume checker
  parity_lab.rs       - expansion/freeze behavior harness
  *_probe.rs          - low-level probes used for validation

Quick start

git clone https://github.com/npradeep357/rust_trainer
cd rust_trainer
cargo test

Run a short smoke training job:

cargo run --release --bin train_generic -- \
  --steps 200 \
  --batch-size 4 \
  --seq-len 32 \
  --out-dir runs/smoke

Run deterministic resume parity check:

cargo run --release --bin trainer_parity

Run Rust vs Python/JAX parity check:

cargo run --release --bin cross_framework_parity

Train on your own data

The default trainer accepts a whitespace-separated integer token file.

cargo run --release --bin train_generic -- \
  --token-file /path/to/your_tokens.txt \
  --out-dir runs/experiment_v1 \
  --steps 50000 \
  --batch-size 8 \
  --seq-len 64 \
  --d-model 512 \
  --d-state 16 \
  --base-layers 2 \
  --target-layers 6 \
  --placement specific:1,3,4,5 \
  --freeze first:2 \
  --lr 1e-4
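A token file in the whitespace-separated integer format described above can be produced and checked with a few lines of standard-library Rust. This is a minimal sketch: write_token_file and read_token_file are hypothetical helpers, and storing ids as i64 is an assumption chosen to match the Vec<i64> used in the library example further down.

```rust
use std::fs;
use std::io;
use std::path::Path;

/// Write token ids as whitespace-separated integers (single spaces here).
fn write_token_file(path: &Path, tokens: &[i64]) -> io::Result<()> {
    let text: Vec<String> = tokens.iter().map(|t| t.to_string()).collect();
    fs::write(path, text.join(" "))
}

/// Parse a whitespace-separated integer token file; any whitespace
/// (spaces, tabs, newlines) separates ids, and a bad token is an error.
fn read_token_file(path: &Path) -> io::Result<Vec<i64>> {
    let text = fs::read_to_string(path)?;
    text.split_whitespace()
        .map(|s| {
            s.parse::<i64>()
                .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
        })
        .collect()
}

fn main() -> io::Result<()> {
    let path = std::env::temp_dir().join("your_tokens.txt");
    let tokens: Vec<i64> = vec![5, 17, 3, 42, 0, 9];
    write_token_file(&path, &tokens)?;
    let back = read_token_file(&path)?;
    assert_eq!(back, tokens);
    println!("wrote {} tokens to {}", back.len(), path.display());
    Ok(())
}
```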

Resume training:

cargo run --release --bin train_generic -- \
  --resume runs/experiment_v1/latest.bincode \
  --out-dir runs/experiment_v1 \
  --steps 20000

CLI reference

Flag | Default | Description
--out-dir PATH | runs/ | Output directory for checkpoints and metrics
--steps N | 5000 | Number of train steps
--save-every N | 200 | Checkpoint interval
--log-every N | 20 | Metric logging interval
--batch-size N | 8 | Batch size
--seq-len N | 64 | Sequence length
--seed N | 42 | RNG seed
--base-layers N | 2 | Initial layer count before expansion
--target-layers N | 6 | Final layer count after expansion
--d-model N | 512 | Hidden width
--d-state N | 16 | State width
--d-conv N | 4 | Convolution kernel width
--placement STR | specific:1,3,4,5 | Expansion placement (see Placement values)
--freeze STR | first:2 | Freeze policy (see Freeze values)
--lr F | 1e-4 | AdamW learning rate
--ff-lr F | 1e-4 | Forward-Forward local learning rate for d_skip updates
--bp-cadence-steps N | 32 | Apply global backpropagation every N train steps (Forward-Forward runs every step)
--gradient-surgery-method STR | pcgrad | Gradient-conflict handling method: pcgrad, gradnorm, or cagradstep
--gradient-surgery-epsilon F | 1e-8 | Numerical-stability epsilon for surgery operations
--gradnorm-alpha F | 0.2 | GradNorm disagreement scaling factor
--cagrad-lambda F | 1.0 | CAGradStep conflict-aversion strength
--freeze-embedding 1 | false | Freeze the embedding table
--token-file PATH | none | Integer token dataset
--token-dir PATH | none | Directory of shard files for streaming training
--val-token-file PATH | none | Optional dedicated validation token dataset
--val-token-dir PATH | none | Optional validation shard directory
--shard-ext EXT | txt | Extension filter used with --token-dir / --val-token-dir
--shuffle-shards 1 | true | Shuffle shard order each epoch in streaming mode
--packed-sequences 1 | true | Use packed contiguous token windows in streaming mode
--prefetch-workers N | 0 | Number of worker threads for sharded prefetch (>1 enables multi-worker mode)
--prefetch-buffer N | 16 | Bounded channel capacity for prefetched worker batches
--resume PATH | none | Checkpoint to resume from
--vocab-size N | auto | Override vocab size
--val-ratio F | 0.05 | Validation split ratio when --val-token-file is not provided
--val-every N | 200 | Validation cadence in train steps
--eval-batches N | 8 | Number of validation batches per eval pass
--early-stopping-patience N | 0 | Stop when validation does not improve for N eval windows (0 disables)
--grad-clip-norm F | 0.0 | Global gradient clipping threshold (0 disables clipping)
--fail-on-non-finite 1 | false | Panic on NaN/Inf detection instead of skipping the update
--lr-warmup-steps N | 0 | Linear warmup length before decay
--lr-min-scale F | 0.1 | Minimum LR floor for cosine decay, as a fraction of the base LR
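The --lr, --lr-warmup-steps and --lr-min-scale flags describe a linear warmup followed by cosine decay down to a floor. A sketch of such a schedule follows, assuming warmup ramps linearly up to the base LR and the cosine decay runs from the end of warmup to --steps; lr_at is a hypothetical function, and the trainer's exact step indexing may differ.

```rust
use std::f64::consts::PI;

/// Learning rate at 0-based `step`: linear warmup over `warmup_steps`, then
/// cosine decay from `base_lr` to `min_scale * base_lr` at `total_steps`
/// (held at the floor afterwards). Hypothetical helper, not the crate's API.
fn lr_at(step: usize, base_lr: f64, warmup_steps: usize, total_steps: usize, min_scale: f64) -> f64 {
    if warmup_steps > 0 && step < warmup_steps {
        // Linear ramp that reaches base_lr on the last warmup step.
        return base_lr * (step + 1) as f64 / warmup_steps as f64;
    }
    // Guard against a zero-length decay phase.
    let decay_steps = total_steps.saturating_sub(warmup_steps).max(1);
    let progress = ((step - warmup_steps) as f64 / decay_steps as f64).min(1.0);
    // Cosine factor goes from 1.0 (start of decay) to 0.0 (end of decay).
    let cosine = 0.5 * (1.0 + (PI * progress).cos());
    base_lr * (min_scale + (1.0 - min_scale) * cosine)
}

fn main() {
    // CLI defaults: --lr 1e-4, --lr-warmup-steps 0, --steps 5000, --lr-min-scale 0.1.
    for step in [0, 1250, 2500, 3750, 5000] {
        println!("step {step:>4}: lr = {:.3e}", lr_at(step, 1e-4, 0, 5000, 0.1));
    }
}
```

With the defaults, the rate starts at 1e-4, passes 5.5e-5 halfway, and settles at the 1e-5 floor.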

Debug + recovery artifacts

  • latest.bincode: latest atomic, versioned checkpoint (model + optimizer + step)
  • best.bincode: best validation checkpoint
  • run_state.json: resumable data-pipeline state (in-memory cursor or shard stream cursor)
  • metrics.jsonl: train/validation metrics stream for dashboards and debugging
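One common way to get the atomic, versioned checkpoint behavior described for latest.bincode is to prefix the payload with a format version, write it to a temporary file in the same directory, and rename it over the target. The sketch below shows that pattern with the standard library only; CHECKPOINT_VERSION, save_atomic, and load_checked are hypothetical, and this is not necessarily how the crate serializes its bincode checkpoints.

```rust
use std::fs::{self, File};
use std::io::{self, Read, Write};
use std::path::Path;

/// Hypothetical on-disk format version; bump it whenever the layout changes.
const CHECKPOINT_VERSION: u32 = 1;

/// Write `payload` behind a version header via temp file + rename,
/// so readers never observe a half-written checkpoint.
fn save_atomic(path: &Path, payload: &[u8]) -> io::Result<()> {
    // Same directory as the target, e.g. latest.bincode -> latest.tmp.
    let tmp = path.with_extension("tmp");
    {
        let mut f = File::create(&tmp)?;
        f.write_all(&CHECKPOINT_VERSION.to_le_bytes())?;
        f.write_all(payload)?;
        f.sync_all()?; // flush file contents to disk before the rename
    }
    fs::rename(&tmp, path)
}

/// Read a checkpoint and reject unknown versions instead of misparsing them.
fn load_checked(path: &Path) -> io::Result<Vec<u8>> {
    let mut bytes = Vec::new();
    File::open(path)?.read_to_end(&mut bytes)?;
    if bytes.len() < 4 {
        return Err(io::Error::new(io::ErrorKind::InvalidData, "missing version header"));
    }
    let version = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
    if version != CHECKPOINT_VERSION {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!("unsupported checkpoint version {version}"),
        ));
    }
    Ok(bytes[4..].to_vec())
}

fn main() -> io::Result<()> {
    let path = std::env::temp_dir().join("latest.bincode");
    save_atomic(&path, b"model+optimizer+step bytes")?;
    let payload = load_checked(&path)?;
    println!("restored {} payload bytes", payload.len());
    Ok(())
}
```

On POSIX filesystems the rename is atomic, so a crash leaves either the old or the new checkpoint in place. Fully durable writes would also fsync the parent directory, which this sketch omits.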

Placement values

Value | Meaning
append | Add new layers at the end
prepend | Add new layers at the beginning
insert:N | Insert all new layers starting at index N
specific:1,3,4,5 | Place each new layer at specific final indices
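To make the placement semantics concrete, the following sketch computes the final stack layout. It assumes the specific:... positions and the insert:N range index the expanded stack, and that base layers fill the remaining slots in their original order. The Placement enum and expand function are hypothetical stand-ins, not the crate's ExpansionPlacement type.

```rust
/// Hypothetical mirror of the placement values (not the crate's ExpansionPlacement).
#[derive(Debug)]
enum Placement {
    Append,
    Prepend,
    Insert(usize),
    Specific(Vec<usize>),
}

/// Lay out the expanded stack: `Some(i)` marks a slot holding base layer `i`,
/// `None` marks a newly added layer.
fn expand(base: usize, target: usize, placement: &Placement) -> Result<Vec<Option<usize>>, String> {
    let n_new = target.checked_sub(base).ok_or("target layers must be >= base layers")?;
    let new_slots: Vec<usize> = match placement {
        Placement::Append => (base..target).collect(),
        Placement::Prepend => (0..n_new).collect(),
        Placement::Insert(n) => (*n..*n + n_new).collect(),
        Placement::Specific(idx) if idx.len() == n_new => idx.clone(),
        Placement::Specific(idx) => {
            return Err(format!("expected {n_new} positions, got {}", idx.len()));
        }
    };
    // Mark new-layer slots, rejecting out-of-range or duplicate positions.
    let mut is_new = vec![false; target];
    for &s in &new_slots {
        if s >= target || is_new[s] {
            return Err(format!("invalid or duplicate position {s}"));
        }
        is_new[s] = true;
    }
    // Remaining slots keep the base layers in their original order.
    let mut next_base = 0;
    Ok(is_new
        .iter()
        .map(|&new| {
            if new {
                None
            } else {
                next_base += 1;
                Some(next_base - 1)
            }
        })
        .collect())
}

fn main() {
    // CLI defaults: --base-layers 2 --target-layers 6 --placement specific:1,3,4,5
    let layout = expand(2, 6, &Placement::Specific(vec![1, 3, 4, 5])).unwrap();
    println!("{layout:?}"); // prints [Some(0), None, Some(1), None, None, None]
}
```

Under these assumptions, the default configuration keeps the two original layers at final indices 0 and 2.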

Freeze values

Value | Meaning
first:N | Freeze first N layers
indices:0,2,5 | Freeze explicit layer indices
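A freeze policy string can be turned into a per-layer mask as below. This is a hypothetical parser (freeze_mask is not part of the crate), and it assumes the mask covers the final expanded stack.

```rust
/// Parse a freeze policy (`first:N` or `indices:a,b,c`) into a per-layer mask,
/// where `true` means the layer's parameters are not updated.
fn freeze_mask(spec: &str, n_layers: usize) -> Result<Vec<bool>, String> {
    let (kind, arg) = spec
        .split_once(':')
        .ok_or_else(|| format!("expected KIND:ARG, got {spec:?}"))?;
    let mut mask = vec![false; n_layers];
    match kind {
        "first" => {
            let n: usize = arg.parse().map_err(|e| format!("bad count {arg:?}: {e}"))?;
            if n > n_layers {
                return Err(format!("cannot freeze {n} of {n_layers} layers"));
            }
            mask[..n].fill(true);
        }
        "indices" => {
            for part in arg.split(',') {
                let i: usize = part
                    .trim()
                    .parse()
                    .map_err(|e| format!("bad index {part:?}: {e}"))?;
                if i >= n_layers {
                    return Err(format!("layer index {i} out of range for {n_layers} layers"));
                }
                mask[i] = true;
            }
        }
        other => return Err(format!("unknown freeze kind {other:?}")),
    }
    Ok(mask)
}

fn main() {
    // With --freeze first:2 and 6 target layers, layers 0 and 1 are frozen.
    println!("{:?}", freeze_mask("first:2", 6));
    println!("{:?}", freeze_mask("indices:0,2,5", 6));
}
```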

Use as a library

Add dependency:

[dependencies]
rust_trainer = "0.1"

If you fork the crate, use the package name declared in your fork's Cargo.toml.

Minimal integration example:

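use rust_trainer::generic_trainer::{
    GenericTrainer, default_trainer_config, make_batch_from_tokens,
};
use rust_trainer::{ExpansionPlacement, FreezeSelection, LayerSpec};

fn main() {
    // Per-layer dimensions (same values as the CLI defaults).
    let spec = LayerSpec { d_model: 512, d_state: 16, d_conv: 4 };

    // Argument comments are inferred from the matching CLI flags;
    // confirm the signature in src/generic_trainer.rs.
    let cfg = default_trainer_config(
        8192,                                                    // vocab size
        spec,                                                    // layer dimensions
        6,                                                       // target layers
        ExpansionPlacement::SpecificPositions(vec![1, 3, 4, 5]), // placement
        FreezeSelection::FirstN(2),                              // freeze policy
        false,                                                   // freeze embedding
        1e-4,                                                    // AdamW learning rate
    );

    // Presumably 2 base layers and RNG seed 42 (cf. --base-layers, --seed).
    let mut trainer = GenericTrainer::new_random(cfg, 2, 42);

    // Synthetic ids 0..8191, all inside the 8192-token vocabulary.
    let tokens: Vec<i64> = (0..8192).collect();
    // Presumably offset 0, batch size 8, sequence length 64.
    let (ids, targets) = make_batch_from_tokens(&tokens, 0, 8, 64);
    let stats = trainer.train_step(&ids, &targets);
    println!("loss: {}", stats.loss);

    trainer.save_checkpoint("checkpoint.bincode").unwrap();
}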
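The README does not document what make_batch_from_tokens returns. Assuming a next-token objective and arguments of (tokens, offset, batch_size, seq_len), packed contiguous windows can be cut as sketched below. packed_batch and the Batch alias are hypothetical; the crate may return flattened buffers or use a different target convention.

```rust
/// (input rows, target rows); hypothetical alias for readability.
type Batch = (Vec<Vec<i64>>, Vec<Vec<i64>>);

/// Cut `batch_size` back-to-back windows of `seq_len + 1` tokens starting at
/// `offset`. Inputs drop each window's last token; targets drop its first,
/// so targets[b][t] is the token that follows ids[b][t]. Consecutive rows
/// share one boundary token, so a batch consumes batch_size * seq_len + 1
/// tokens. Returns None if the stream is too short.
fn packed_batch(tokens: &[i64], offset: usize, batch_size: usize, seq_len: usize) -> Option<Batch> {
    let needed = batch_size * seq_len + 1;
    let span = tokens.get(offset..offset + needed)?;
    let mut ids = Vec::with_capacity(batch_size);
    let mut targets = Vec::with_capacity(batch_size);
    for b in 0..batch_size {
        let window = &span[b * seq_len..b * seq_len + seq_len + 1];
        ids.push(window[..seq_len].to_vec());
        targets.push(window[1..].to_vec());
    }
    Some((ids, targets))
}

fn main() {
    let tokens: Vec<i64> = (0..20).collect();
    if let Some((ids, targets)) = packed_batch(&tokens, 0, 2, 4) {
        println!("ids = {ids:?}\ntargets = {targets:?}");
    }
}
```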

Architecture

The default model uses:

Component | Implementation
Sequence layers | Mamba SSM (causal conv1d + SiLU + discretized state scan)
Output head | Hyperspherical Prototype Network (HPN)
Loss | Squared cosine distance to nearest prototype
Optimizer | AdamW with serializable moment buffers
Inference path | CPU-only; no GPU required
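For intuition about the HPN head, the sketch below computes the squared cosine distance (1 - cos(h, p))^2 between an output vector h and a prototype p, and predicts by picking the prototype with the highest cosine similarity. It assumes, as in the original Hyperspherical Prototype Networks formulation, that the training loss is measured against the target's prototype. The function names are hypothetical, and the crate's exact formulation may differ.

```rust
/// Cosine similarity; the small epsilon avoids division by zero.
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let na = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let nb = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    dot / (na * nb + 1e-8)
}

/// Squared cosine distance between an output vector and the target prototype.
/// Ranges from 0 (same direction) to 4 (opposite direction).
fn hpn_loss(output: &[f32], prototypes: &[Vec<f32>], target: usize) -> f32 {
    let d = 1.0 - cosine(output, &prototypes[target]);
    d * d
}

/// Predict the class whose prototype is nearest, i.e. has the largest cosine.
fn nearest_prototype(output: &[f32], prototypes: &[Vec<f32>]) -> usize {
    prototypes
        .iter()
        .enumerate()
        .map(|(k, p)| (k, cosine(output, p)))
        .max_by(|a, b| a.1.total_cmp(&b.1))
        .map(|(k, _)| k)
        .expect("at least one prototype")
}

fn main() {
    // Two hypothetical 2-D prototypes on the unit circle.
    let prototypes = vec![vec![1.0_f32, 0.0], vec![0.0, 1.0]];
    let output = [0.6_f32, 0.8];
    println!("predicted: {}", nearest_prototype(&output, &prototypes));
    println!("loss vs prototype 1: {:.4}", hpn_loss(&output, &prototypes, 1));
}
```

Because only direction matters, the output's magnitude does not affect either the loss or the prediction.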

Customize for your own architecture

The package is designed to be forked for other architectures. Replace or extend:

  1. Layer forward/backward path in src/layer.rs — swap Mamba for Transformer, LSTM, etc.
  2. Output loss/head logic in src/nn.rs — swap HPN for cross-entropy, contrastive loss, etc.
  3. Trainer state wiring in src/generic_trainer.rs — add or remove parameter groups
  4. Data loading logic in src/bin/train_generic.rs

The checkpointing, optimizer state, logging, expansion, and freeze infrastructure are all architecture-independent and can be kept as-is.
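Because the optimizer state is architecture-independent, AdamW reduces to a few flat buffers per parameter group. The sketch below is a textbook AdamW step (decoupled weight decay, bias-corrected moments) with assumed default hyperparameters. The AdamW struct here is hypothetical and is not the layout used in src/optim.rs.

```rust
/// AdamW state for one flat parameter group. Plain Vec buffers plus the step
/// counter are all that must be serialized to resume deterministically.
/// (Hypothetical struct; hyperparameter defaults below are assumptions.)
#[derive(Debug, Clone, PartialEq)]
struct AdamW {
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    t: u64,     // number of updates applied so far
    m: Vec<f32>, // first-moment (mean) estimate
    v: Vec<f32>, // second-moment (uncentered variance) estimate
}

impl AdamW {
    fn new(n_params: usize, lr: f32) -> Self {
        Self {
            lr,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            weight_decay: 0.01,
            t: 0,
            m: vec![0.0; n_params],
            v: vec![0.0; n_params],
        }
    }

    /// One update: bias-corrected Adam step plus decoupled weight decay
    /// (decay is applied to the weights directly, not folded into the gradient).
    fn step(&mut self, params: &mut [f32], grads: &[f32]) {
        assert_eq!(params.len(), grads.len());
        self.t += 1;
        let bc1 = 1.0 - self.beta1.powi(self.t as i32);
        let bc2 = 1.0 - self.beta2.powi(self.t as i32);
        for i in 0..params.len() {
            let g = grads[i];
            self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * g;
            self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * g * g;
            let m_hat = self.m[i] / bc1;
            let v_hat = self.v[i] / bc2;
            params[i] -= self.lr * (m_hat / (v_hat.sqrt() + self.eps) + self.weight_decay * params[i]);
        }
    }
}

fn main() {
    // Minimize f(x) = x^2 (gradient 2x) starting from x = 3.0.
    let mut x = vec![3.0_f32];
    let mut opt = AdamW::new(1, 0.1);
    for _ in 0..5 {
        let grad = [2.0 * x[0]];
        opt.step(&mut x, &grad);
    }
    println!("x after {} steps: {:.4}", opt.t, x[0]);
}
```

Cloning the state and replaying the same gradients reproduces the same parameters exactly, which is the property resume-safe checkpoints rely on.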


Release flow

Releases are tag-driven via GitHub Actions.

# bump version in Cargo.toml, commit, then:
git tag v0.2.0
git push origin v0.2.0

The release workflow runs tests, builds binaries, creates a GitHub Release, and can publish to crates.io when credentials are configured.


License

Apache-2.0. See LICENSE.

Download files

No source distribution is available for version 0.1.4. Built distributions (wheels):

  • rust_trainer-0.1.4-cp313-cp313-win_amd64.whl (267.8 kB): CPython 3.13, Windows x86-64
  • rust_trainer-0.1.4-cp313-cp313-manylinux_2_28_x86_64.whl (373.2 kB): CPython 3.13, manylinux glibc 2.28+ x86-64
  • rust_trainer-0.1.4-cp313-cp313-manylinux_2_28_aarch64.whl (363.3 kB): CPython 3.13, manylinux glibc 2.28+ ARM64
  • rust_trainer-0.1.4-cp313-cp313-macosx_11_0_arm64.whl (330.2 kB): CPython 3.13, macOS 11.0+ ARM64
  • rust_trainer-0.1.4-cp313-cp313-macosx_10_12_x86_64.whl (349.7 kB): CPython 3.13, macOS 10.12+ x86-64
  • rust_trainer-0.1.4-cp312-cp312-win_amd64.whl (267.9 kB): CPython 3.12, Windows x86-64
  • rust_trainer-0.1.4-cp312-cp312-manylinux_2_28_x86_64.whl (373.3 kB): CPython 3.12, manylinux glibc 2.28+ x86-64
  • rust_trainer-0.1.4-cp312-cp312-manylinux_2_28_aarch64.whl (363.3 kB): CPython 3.12, manylinux glibc 2.28+ ARM64
  • rust_trainer-0.1.4-cp312-cp312-macosx_11_0_arm64.whl (330.3 kB): CPython 3.12, macOS 11.0+ ARM64
  • rust_trainer-0.1.4-cp312-cp312-macosx_10_12_x86_64.whl (349.8 kB): CPython 3.12, macOS 10.12+ x86-64
  • rust_trainer-0.1.4-cp311-cp311-win_amd64.whl (267.8 kB): CPython 3.11, Windows x86-64
  • rust_trainer-0.1.4-cp311-cp311-manylinux_2_28_x86_64.whl (373.3 kB): CPython 3.11, manylinux glibc 2.28+ x86-64
  • rust_trainer-0.1.4-cp311-cp311-manylinux_2_28_aarch64.whl (363.5 kB): CPython 3.11, manylinux glibc 2.28+ ARM64
  • rust_trainer-0.1.4-cp311-cp311-macosx_11_0_arm64.whl (330.4 kB): CPython 3.11, macOS 11.0+ ARM64
  • rust_trainer-0.1.4-cp311-cp311-macosx_10_12_x86_64.whl (350.0 kB): CPython 3.11, macOS 10.12+ x86-64
