CPU-first pure-Rust supervised trainer for Selective State Space Models with Hyperspherical Prototype Networks.
RUST Trainer
A CPU-first Rust training package implementing a Mamba SSM + Hyperspherical Prototype Network (HPN) architecture.
This is a concrete, working reference implementation, not a blank framework. The model is a stack of Mamba selective state-space layers topped by an HPN output head that scores outputs by cosine distance against a learnable prototype matrix. Teams can use it as-is, or fork it and replace the layer/loss internals with their own architecture.
It works as both:
- a library dependency in your Rust application
- a ready-to-run CLI training binary
No Python runtime is required for training; Python (with JAX) is needed only for the optional cross-framework parity check.
What this package gives you
| Capability | Status |
|---|---|
| End-to-end train loop binary | ✅ |
| Library API for embedding custom model/training logic | ✅ |
| Serializable optimizer state (AdamW) | ✅ |
| Resume-safe checkpoints (model + optimizer + step) | ✅ |
| JSONL metrics logging | ✅ |
| Configurable layer expansion and freezing | ✅ |
| Deterministic parity probe for save/load correctness | ✅ |
| SIMD math kernels for high-throughput CPU training | ✅ |
| Validation cadence + best-checkpoint tracking | ✅ |
| Early stopping support | ✅ |
| Gradient clipping controls | ✅ |
| LR warmup + cosine decay controls | ✅ |
| Non-finite update guardrails | ✅ |
| Sharded streaming dataset support | ✅ |
| Packed sequence batching on shard streams | ✅ |
| Multi-worker sharded prefetch ingestion | ✅ |
| Run-state resume for stream cursors | ✅ |
| Atomic versioned checkpoints | ✅ |
| Cross-framework parity harness (Rust vs Python/JAX) | ✅ |
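The gradient clipping and non-finite guardrail rows amount to a small decision made before every optimizer update. The sketch below assumes global L2-norm clipping and a skip-or-panic policy, inferred only from the `--grad-clip-norm` and `--fail-on-non-finite` flag descriptions; `global_grad_norm`, `guard_update`, and `UpdateDecision` are hypothetical names, not this crate's API.

```rust
/// Outcome of the pre-update guardrails (hypothetical type for illustration).
#[derive(Debug, PartialEq)]
enum UpdateDecision {
    /// Apply the update; gradients were multiplied by `scale`.
    Apply { scale: f32 },
    /// Skip the update because the gradients contained NaN/Inf.
    Skip,
}

/// Global L2 norm across every parameter group's gradient.
fn global_grad_norm(grads: &[Vec<f32>]) -> f32 {
    grads
        .iter()
        .flat_map(|g| g.iter())
        .map(|x| x * x)
        .sum::<f32>()
        .sqrt()
}

/// Clip gradients in place to `clip_norm` (<= 0.0 disables clipping) and
/// skip or panic when the norm is not finite. Any NaN/Inf element makes
/// the norm non-finite, so one check covers every gradient value.
fn guard_update(grads: &mut [Vec<f32>], clip_norm: f32, fail_on_non_finite: bool) -> UpdateDecision {
    let norm = global_grad_norm(grads);
    if !norm.is_finite() {
        if fail_on_non_finite {
            panic!("non-finite gradient norm: {norm}");
        }
        return UpdateDecision::Skip;
    }
    let scale = if clip_norm > 0.0 && norm > clip_norm { clip_norm / norm } else { 1.0 };
    for g in grads.iter_mut() {
        for x in g.iter_mut() {
            *x *= scale;
        }
    }
    UpdateDecision::Apply { scale }
}
```

Checking the norm once, after the backward pass, keeps the guard cheap: a single reduction decides both whether to clip and whether the step is safe to apply.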
Production readiness status
Current state: production candidate for single-node CPU training. Production-critical ingestion (sharded streaming, packed batching, multi-worker prefetch) and cross-framework parity validation are implemented.
What is already robust:
- deterministic resume behavior (checkpoint + optimizer state + step)
- deterministic resume behavior for streaming shard cursors (`run_state.json`)
- configurable expansion/freeze controls for staged training
- validated SIMD and backward kernels with scalar parity probes
- CI, release, and crate packaging automation
Operational note:
- cross-framework parity runner requires `jax` to be installed in the active Python environment
Detailed roadmap and release milestones are tracked in roadmap.md.
Design philosophy
- Keep trainer internals explicit and hackable.
- Favor reproducible runs and resumability.
- Make the package easy to fork and specialize for custom architectures.
- Keep data ingestion simple at first (integer token files), then scale to streaming pipelines.
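Scaling ingestion to streaming pipelines, as the `--prefetch-workers` and `--prefetch-buffer` flags describe, is essentially a bounded producer/consumer queue. The following is a minimal std-only sketch under the assumption that each worker owns a round-robin subset of shards and pushes fixed-size token chunks into a `sync_channel`; `spawn_prefetch` is a hypothetical name, and unlike the crate's stream cursors this sketch makes no attempt at deterministic ordering or resumable state.

```rust
use std::sync::mpsc::{sync_channel, Receiver};
use std::thread;

/// Spawn `workers` threads that read disjoint shards (round-robin by index)
/// and push `batch_tokens`-sized chunks into a channel holding at most
/// `buffer` items. Producers block when the buffer is full (backpressure).
fn spawn_prefetch(
    shards: Vec<Vec<i64>>,
    workers: usize,
    buffer: usize,
    batch_tokens: usize,
) -> Receiver<Vec<i64>> {
    assert!(batch_tokens > 0, "batch_tokens must be positive");
    let (tx, rx) = sync_channel::<Vec<i64>>(buffer);
    let workers = workers.max(1);
    for w in 0..workers {
        let tx = tx.clone();
        // Worker w takes shards w, w + workers, w + 2 * workers, ...
        let mine: Vec<Vec<i64>> = shards.iter().skip(w).step_by(workers).cloned().collect();
        thread::spawn(move || {
            for shard in mine {
                for chunk in shard.chunks(batch_tokens) {
                    if tx.send(chunk.to_vec()).is_err() {
                        return; // the consumer hung up; stop producing
                    }
                }
            }
        });
    }
    // Drop the original sender so the receiver ends once every worker finishes.
    drop(tx);
    rx
}
```

With more than one worker, batches from different shards interleave nondeterministically, which is why a production pipeline needs explicit cursor state to make resume reproducible.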
Repository layout
```text
src/
  lib.rs              - crate root and public exports
  generic_trainer.rs  - full trainer state, train step, checkpoint/resume
  trainer.rs          - parameter and expansion/freezing config types
  optim.rs            - AdamW optimizer primitives
  nn.rs               - layer norm and output-loss helpers
  simd_ops.rs         - SIMD kernels used by the model path
  layer.rs            - cached layer forward/backward helpers
  stack.rs            - stack-level supervised step helpers
src/bin/
  train_generic.rs    - main CLI trainer
  trainer_parity.rs   - deterministic parity/resume checker
  parity_lab.rs       - expansion/freeze behavior harness
  *_probe.rs          - low-level probes used for validation
```
Quick start
```bash
git clone https://github.com/npradeep357/rust_trainer
cd rust_trainer
cargo test
```
Run a short smoke training job:
```bash
cargo run --release --bin train_generic -- \
  --steps 200 \
  --batch-size 4 \
  --seq-len 32 \
  --out-dir runs/smoke
```
Run deterministic resume parity check:
```bash
cargo run --release --bin trainer_parity
```
Run Rust vs Python/JAX parity check:
```bash
cargo run --release --bin cross_framework_parity
```
Train your own model data
The default trainer accepts a whitespace-separated integer token file.
```bash
cargo run --release --bin train_generic -- \
  --token-file /path/to/your_tokens.txt \
  --out-dir runs/experiment_v1 \
  --steps 50000 \
  --batch-size 8 \
  --seq-len 64 \
  --d-model 512 \
  --d-state 16 \
  --base-layers 2 \
  --target-layers 6 \
  --placement specific:1,3,4,5 \
  --freeze first:2 \
  --lr 1e-4
```
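To show what the token-file format implies, here is a std-only sketch that parses a whitespace-separated integer file and cuts it into contiguous windows with next-token targets. `read_token_file`, `infer_vocab_size`, and `packed_windows` are hypothetical helpers; treating `--vocab-size auto` as "largest id + 1" and shifting targets by one token are assumptions, and the crate's own `make_batch_from_tokens` may lay out batches differently.

```rust
use std::fs;
use std::io;
use std::path::Path;

/// Parse a whitespace-separated integer token file (spaces, tabs, or newlines).
fn read_token_file(path: &Path) -> io::Result<Vec<i64>> {
    let text = fs::read_to_string(path)?;
    text.split_whitespace()
        .map(|t| {
            t.parse::<i64>().map_err(|e| {
                io::Error::new(io::ErrorKind::InvalidData, format!("bad token {t:?}: {e}"))
            })
        })
        .collect()
}

/// Assumed meaning of `--vocab-size auto`: largest token id + 1.
fn infer_vocab_size(tokens: &[i64]) -> usize {
    tokens.iter().copied().max().map_or(0, |m| (m + 1) as usize)
}

/// Cut `batch` contiguous windows of `seq_len` tokens starting at `offset`.
/// Each target window is its input window shifted forward by one token.
fn packed_windows(
    tokens: &[i64],
    offset: usize,
    batch: usize,
    seq_len: usize,
) -> (Vec<Vec<i64>>, Vec<Vec<i64>>) {
    let n = tokens.len();
    assert!(n > seq_len, "need more tokens than seq_len");
    let mut ids = Vec::with_capacity(batch);
    let mut targets = Vec::with_capacity(batch);
    for b in 0..batch {
        // Wrap so the shifted target window never runs past the end.
        let start = (offset + b * seq_len) % (n - seq_len);
        ids.push(tokens[start..start + seq_len].to_vec());
        targets.push(tokens[start + 1..start + seq_len + 1].to_vec());
    }
    (ids, targets)
}
```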
Resume training:
```bash
cargo run --release --bin train_generic -- \
  --resume runs/experiment_v1/latest.bincode \
  --out-dir runs/experiment_v1 \
  --steps 20000
```
CLI reference
| Flag | Default | Description |
|---|---|---|
| `--out-dir PATH` | runs/ | Output directory for checkpoints and metrics |
| `--steps N` | 5000 | Number of train steps |
| `--save-every N` | 200 | Checkpoint interval |
| `--log-every N` | 20 | Metric logging interval |
| `--batch-size N` | 8 | Batch size |
| `--seq-len N` | 64 | Sequence length |
| `--seed N` | 42 | RNG seed |
| `--base-layers N` | 2 | Initial layer count before expansion |
| `--target-layers N` | 6 | Final layer count after expansion |
| `--d-model N` | 512 | Hidden width |
| `--d-state N` | 16 | State width |
| `--d-conv N` | 4 | Convolution kernel width |
| `--placement STR` | specific:1,3,4,5 | Expansion placement |
| `--freeze STR` | first:2 | Freeze policy |
| `--lr F` | 1e-4 | AdamW learning rate |
| `--ff-lr F` | 1e-4 | Forward-Forward local learning rate for d_skip updates |
| `--bp-cadence-steps N` | 32 | Apply global BP every N train steps (FF runs each step) |
| `--gradient-surgery-method STR` | pcgrad | Conflict handling method: pcgrad, gradnorm, cagradstep |
| `--gradient-surgery-epsilon F` | 1e-8 | Numerical stability epsilon for surgery operations |
| `--gradnorm-alpha F` | 0.2 | GradNorm disagreement scaling factor |
| `--cagrad-lambda F` | 1.0 | CAGradStep conflict-aversion strength |
| `--freeze-embedding 1` | false | Freeze embedding table |
| `--token-file PATH` | none | Integer token dataset |
| `--token-dir PATH` | none | Directory of shard files for streaming training |
| `--val-token-file PATH` | none | Optional dedicated validation token dataset |
| `--val-token-dir PATH` | none | Optional validation shard directory |
| `--shard-ext EXT` | txt | Extension filter used with `--token-dir` / `--val-token-dir` |
| `--shuffle-shards 1` | true | Shuffle shard order each epoch in streaming mode |
| `--packed-sequences 1` | true | Use packed contiguous token windows in streaming mode |
| `--prefetch-workers N` | 0 | Number of worker threads for sharded prefetch (>1 enables multi-worker mode) |
| `--prefetch-buffer N` | 16 | Bounded channel capacity for prefetched worker batches |
| `--resume PATH` | none | Checkpoint to resume from |
| `--vocab-size N` | auto | Override vocab size |
| `--val-ratio F` | 0.05 | Validation split ratio when `--val-token-file` is not provided |
| `--val-every N` | 200 | Validation cadence in train steps |
| `--eval-batches N` | 8 | Number of validation batches per eval pass |
| `--early-stopping-patience N` | 0 | Stop when validation does not improve for N eval windows (0 disables) |
| `--grad-clip-norm F` | 0.0 | Global gradient clipping threshold (0 disables clipping) |
| `--fail-on-non-finite 1` | false | Panic on NaN/Inf detection instead of skipping the update |
| `--lr-warmup-steps N` | 0 | Linear warmup length before decay |
| `--lr-min-scale F` | 0.1 | Minimum LR floor as fraction of base LR for cosine decay |
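The `--lr-warmup-steps` and `--lr-min-scale` flags describe a common schedule: linear warmup, then cosine decay toward a floor. The sketch below is one standard way to write it, not necessarily the crate's exact formula; `lr_at` is a hypothetical name, and using `--steps` as the decay horizon is an assumption.

```rust
use std::f64::consts::PI;

/// Learning rate at 0-based `step` for a run of `total_steps`:
/// linear warmup over `warmup_steps`, then cosine decay from `base_lr`
/// that reaches `base_lr * min_scale` once `step >= total_steps`.
fn lr_at(step: u64, total_steps: u64, base_lr: f64, warmup_steps: u64, min_scale: f64) -> f64 {
    if step < warmup_steps {
        // Ramp linearly so the last warmup step lands exactly on base_lr.
        return base_lr * (step + 1) as f64 / warmup_steps as f64;
    }
    let decay_steps = total_steps.saturating_sub(warmup_steps).max(1);
    let progress = ((step - warmup_steps) as f64 / decay_steps as f64).min(1.0);
    // Cosine factor goes from 1.0 (start of decay) to 0.0 (end of decay).
    let cosine = 0.5 * (1.0 + (PI * progress).cos());
    base_lr * (min_scale + (1.0 - min_scale) * cosine)
}
```

Expressing the floor as a fraction of the base rate, as `--lr-min-scale` does, keeps the schedule valid when only `--lr` is changed.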
Debug + recovery artifacts
- `latest.bincode`: latest atomic, versioned checkpoint (model + optimizer + step)
- `best.bincode`: best validation checkpoint
- `run_state.json`: resumable data-pipeline state (in-memory cursor or shard stream cursor)
- `metrics.jsonl`: train/validation metrics stream for dashboards and debugging
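"Atomic, versioned" usually means two things: a crash mid-save never leaves a half-written `latest.bincode`, and a reader can reject files from an incompatible format. A minimal std-only sketch of that pattern follows, assuming a write-to-temp-then-rename strategy and a 4-byte little-endian version header; `write_checkpoint_atomic`, `read_checkpoint`, and `CHECKPOINT_VERSION` are hypothetical and do not describe the crate's actual file layout.

```rust
use std::fs::{self, File};
use std::io::{self, Write};
use std::path::Path;

/// Hypothetical on-disk format version written ahead of the payload.
const CHECKPOINT_VERSION: u32 = 1;

/// Write `payload` atomically: write a sibling temp file, fsync it, then
/// rename it over the destination so readers never observe a partial file.
fn write_checkpoint_atomic(path: &Path, payload: &[u8]) -> io::Result<()> {
    let tmp = path.with_extension("tmp");
    {
        let mut f = File::create(&tmp)?;
        f.write_all(&CHECKPOINT_VERSION.to_le_bytes())?;
        f.write_all(payload)?;
        f.sync_all()?; // flush data to disk before the rename publishes it
    }
    fs::rename(&tmp, path)
}

/// Read a checkpoint back, rejecting truncated files and other versions.
fn read_checkpoint(path: &Path) -> io::Result<Vec<u8>> {
    let bytes = fs::read(path)?;
    if bytes.len() < 4 {
        return Err(io::Error::new(io::ErrorKind::InvalidData, "truncated checkpoint"));
    }
    let version = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
    if version != CHECKPOINT_VERSION {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!("unsupported checkpoint version {version}"),
        ));
    }
    Ok(bytes[4..].to_vec())
}
```

The rename is only atomic when the temp file and destination share a filesystem, which is why the temp file is created next to the target rather than in a system temp directory. On POSIX systems, fully durable saves also fsync the containing directory after the rename.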
Placement values
| Value | Meaning |
|---|---|
| `append` | Add new layers at the end |
| `prepend` | Add new layers at the beginning |
| `insert:N` | Insert all new layers starting at index N |
| `specific:1,3,4,5` | Place each new layer at specific final indices |
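To make the placement values concrete, the sketch below computes the final layer layout when growing from `--base-layers` to `--target-layers`, with original layers keeping their relative order. `Placement` and `expanded_layout` are hypothetical mirrors of the CLI values; the crate's own `ExpansionPlacement` may differ in variant names and edge-case handling.

```rust
/// Hypothetical mirror of the CLI placement values.
enum Placement {
    Append,
    Prepend,
    Insert(usize),
    Specific(Vec<usize>),
}

/// Final layout after growing from `base` to `target` layers:
/// `Some(i)` is original layer i, `None` is a freshly initialised layer.
fn expanded_layout(base: usize, target: usize, placement: &Placement) -> Vec<Option<usize>> {
    let added = target - base;
    let new_slots: Vec<usize> = match placement {
        Placement::Append => (base..target).collect(),
        Placement::Prepend => (0..added).collect(),
        Placement::Insert(n) => (*n..*n + added).collect(),
        Placement::Specific(v) => v.clone(),
    };
    assert_eq!(new_slots.len(), added, "placement must name exactly target - base slots");
    // Fill the remaining slots with the original layers, in order.
    let mut originals = 0..base;
    (0..target)
        .map(|i| if new_slots.contains(&i) { None } else { originals.next() })
        .collect()
}
```

With the defaults (2 base layers, 6 target layers, `specific:1,3,4,5`), the two original layers end up at final indices 0 and 2.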
Freeze values
| Value | Meaning |
|---|---|
| `first:N` | Freeze first N layers |
| `indices:0,2,5` | Freeze explicit layer indices |
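A freeze policy reduces to a per-layer boolean mask consulted at update time. The sketch below assumes the indices refer to the final (post-expansion) layout and uses a plain SGD step as a stand-in for AdamW; `Freeze`, `freeze_mask`, and `apply_update` are hypothetical names, not this crate's `FreezeSelection` API.

```rust
/// Hypothetical mirror of the CLI freeze values.
enum Freeze {
    FirstN(usize),
    Indices(Vec<usize>),
}

/// Per-layer mask: `true` means the layer's parameters receive no updates.
fn freeze_mask(num_layers: usize, freeze: &Freeze) -> Vec<bool> {
    (0..num_layers)
        .map(|i| match freeze {
            Freeze::FirstN(n) => i < *n,
            Freeze::Indices(idx) => idx.contains(&i),
        })
        .collect()
}

/// Plain SGD step (stand-in for AdamW) that leaves frozen layers untouched.
fn apply_update(params: &mut [Vec<f32>], grads: &[Vec<f32>], mask: &[bool], lr: f32) {
    for ((p, g), &frozen) in params.iter_mut().zip(grads).zip(mask) {
        if frozen {
            continue;
        }
        for (w, dw) in p.iter_mut().zip(g) {
            *w -= lr * dw;
        }
    }
}
```

With AdamW, a frozen layer should also skip its moment-buffer updates, so that unfreezing it later does not apply stale momentum.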
Use as a library
Add dependency:
```toml
[dependencies]
rust_trainer = "0.1"
```
Use the package name exactly as it is declared in your own Cargo.toml.
Minimal integration example:
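```rust
use rust_trainer::generic_trainer::{
    GenericTrainer, default_trainer_config, make_batch_from_tokens,
};
use rust_trainer::{ExpansionPlacement, FreezeSelection, LayerSpec};

fn main() {
    // Layer dimensions, matching the CLI defaults for --d-model/--d-state/--d-conv.
    let spec = LayerSpec { d_model: 512, d_state: 16, d_conv: 4 };
    // Positional arguments, as inferred from the matching CLI defaults.
    let cfg = default_trainer_config(
        8192, // vocab size
        spec,
        6, // target layer count after expansion
        ExpansionPlacement::SpecificPositions(vec![1, 3, 4, 5]),
        FreezeSelection::FirstN(2),
        false, // freeze embedding table
        1e-4,  // learning rate
    );
    // 2 base layers before expansion, RNG seed 42.
    let mut trainer = GenericTrainer::new_random(cfg, 2, 42);

    // Toy token stream; real data would come from a token file.
    let tokens: Vec<i64> = (0..8192).collect();
    // Offset 0, batch size 8, sequence length 64.
    let (ids, targets) = make_batch_from_tokens(&tokens, 0, 8, 64);
    let stats = trainer.train_step(&ids, &targets);
    println!("loss: {}", stats.loss);

    trainer.save_checkpoint("checkpoint.bincode").unwrap();
}
```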
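When driving `train_step` from your own loop, you may want the validation cadence and early-stopping behavior the CLI offers. The tracker below is a small std-only sketch of the semantics described for `--early-stopping-patience` (stop after N consecutive non-improving evals; 0 disables). `EarlyStopping` is a hypothetical type; we do not know whether the library already exposes an equivalent.

```rust
/// Tracks the best validation loss and signals when training should stop.
struct EarlyStopping {
    patience: usize, // 0 disables stopping
    best: f64,
    bad_evals: usize,
}

impl EarlyStopping {
    fn new(patience: usize) -> Self {
        Self { patience, best: f64::INFINITY, bad_evals: 0 }
    }

    /// Record one eval pass and return `(improved, should_stop)`.
    /// A NaN loss never compares as an improvement, so it counts as a bad eval.
    fn observe(&mut self, val_loss: f64) -> (bool, bool) {
        if val_loss < self.best {
            self.best = val_loss;
            self.bad_evals = 0;
            return (true, false);
        }
        self.bad_evals += 1;
        (false, self.patience > 0 && self.bad_evals >= self.patience)
    }
}
```

In a loop, you would evaluate every `val_every` steps, call `observe`, write a best checkpoint whenever it reports an improvement, and break once it reports `should_stop`.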
Architecture
The default model uses:
| Component | Implementation |
|---|---|
| Sequence layers | Mamba SSM (causal conv1d + SiLU + discretized state scan) |
| Output head | Hyperspherical Prototype Network (HPN) |
| Loss | Squared cosine distance to nearest prototype |
| Optimizer | AdamW with serializable moment buffers |
| Inference path | CPU-only; no GPU required |
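The "discretized state scan" is the core of each Mamba layer. Below is a scalar, single-channel sketch of a selective scan using the discretization found in the Mamba authors' reference code: zero-order hold for A (`exp(delta * a)`) and the simplified `delta * b` term for B. The function name and argument layout are assumptions. Treating `d_skip` as the D skip term is inferred from the `--ff-lr` description. This is not this crate's SIMD kernel, and the causal conv1d + SiLU stage that precedes the scan is omitted.

```rust
/// Selective scan for ONE channel. `a` is the diagonal of A (length
/// N = d_state, normally negative); `delta`, `b`, and `c` vary per time
/// step, which is what makes the scan "selective".
fn selective_scan_channel(
    x: &[f32],
    delta: &[f32],
    a: &[f32],
    b: &[Vec<f32>],
    c: &[Vec<f32>],
    d_skip: f32,
) -> Vec<f32> {
    let n = a.len();
    let mut h = vec![0.0f32; n]; // hidden state, starts at zero
    let mut y = Vec::with_capacity(x.len());
    for t in 0..x.len() {
        let mut acc = 0.0f32;
        for i in 0..n {
            // h_t = exp(delta_t * a) * h_{t-1} + delta_t * b_t * x_t
            let a_bar = (delta[t] * a[i]).exp();
            h[i] = a_bar * h[i] + delta[t] * b[t][i] * x[t];
            // y_t = c_t . h_t (accumulated over the state dimension)
            acc += c[t][i] * h[i];
        }
        y.push(acc + d_skip * x[t]); // plus the direct skip term D * x_t
    }
    y
}
```

Because `a` is negative and `delta` positive, `a_bar` lies in (0, 1), so each state component decays geometrically unless fresh input replenishes it.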
Customize for your own architecture
The package is designed to be forked for other architectures. Replace or extend:
- Layer forward/backward path in `src/layer.rs`: swap Mamba for Transformer, LSTM, etc.
- Output loss/head logic in `src/nn.rs`: swap HPN for cross-entropy, contrastive loss, etc.
- Trainer state wiring in `src/generic_trainer.rs`: add or remove parameter groups
- Data loading logic in `src/bin/train_generic.rs`
The checkpointing, optimizer state, logging, expansion, and freeze infrastructure are all architecture-independent and can be kept as-is.
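As a reference point before replacing the head in `src/nn.rs`, here is a minimal sketch of the HPN objective listed in the architecture table. We assume, following the usual hyperspherical-prototype formulation, that the loss is the squared cosine distance `(1 - cos)^2` to the target's prototype and that the prediction is the nearest prototype by cosine similarity. `normalize`, `cosine`, and `hpn_loss` are hypothetical names.

```rust
/// L2-normalise a vector; a zero vector maps to zeros.
fn normalize(v: &[f32]) -> Vec<f32> {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm == 0.0 {
        return vec![0.0; v.len()];
    }
    v.iter().map(|x| x / norm).collect()
}

/// Cosine similarity between two vectors.
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    normalize(a).iter().zip(normalize(b)).map(|(x, y)| x * y).sum()
}

/// Returns (squared cosine distance to the target prototype,
/// index of the nearest prototype used as the prediction).
fn hpn_loss(output: &[f32], prototypes: &[Vec<f32>], target: usize) -> (f32, usize) {
    let sims: Vec<f32> = prototypes.iter().map(|p| cosine(output, p)).collect();
    let loss = (1.0 - sims[target]).powi(2);
    let pred = sims
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(i, _)| i)
        .unwrap_or(0);
    (loss, pred)
}
```

Swapping in cross-entropy would replace the cosine similarities with logits from a linear projection; the rest of the trainer only needs a scalar loss and its gradient.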
Release flow
Releases are tag-driven via GitHub Actions.
```bash
# bump version in Cargo.toml, commit, then:
git tag v0.2.0
git push origin v0.2.0
```
The release workflow runs tests, builds binaries, creates a GitHub Release, and can publish to crates.io when credentials are configured.
License
Apache-2.0. See LICENSE.