calm-mamba
GPU-accelerated training for CALM (Catastrophically Abridged Language Models)
Python companion to the CALM (Catastrophically Abridged Language Models) system from Lilush. Provides GPU-accelerated training via PyTorch while maintaining bit-level weight compatibility with the embedded C inference engine through the CWGT binary format.
CALM is a small decoder-only language model that powers shell completion and semantic search in the Lilush shell. It uses a Mamba selective state space model (SSM) as its sequence mixer, with input-dependent gating and O(1) per-token generation via cached state.
Architecture
input_ids (B, L)
|
token_emb [vocab_size=320, d_model] (weight-tied with lm_head)
|
for each Block (n_layers):
|-- LayerNorm -> Mixer (Mamba SSM) -> residual add
'-- LayerNorm -> FFN(GELU, no bias) -> residual add
|
LayerNorm ──┬── lm_head ── logits (B, L, 320) [forward / loss]
└── pool + L2 norm ── emb (B, d) [embed]
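Each Block applies two pre-norm residual branches, as the diagram shows. A minimal sketch of the per-block dataflow (the function and module names here are illustrative, not the actual calm API):

```python
def block_forward(x, ln1, mixer, ln2, ffn):
    # Pre-norm residual block, matching the diagram above.
    x = x + mixer(ln1(x))   # LayerNorm -> Mamba SSM mixer -> residual add
    x = x + ffn(ln2(x))     # LayerNorm -> GELU FFN -> residual add
    return x
```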
Mamba operator
The selective SSM mixer. For each block:
- in_proj splits input into SSM branch (x') and gate branch (z)
- Causal depthwise conv1d (kernel=d_conv, default 4) + SiLU on x'
- x_proj produces input-dependent dt, B, C (selectivity)
- dt_proj + softplus gives discretized time steps
- SSM scan: h[t] = exp(dt*A)*h[t-1] + dt*B[t]*x[t]; y[t] = C[t]·h[t] + D*x[t]
- Gate: y = y * SiLU(z)
- Output projection (bias-free)
SSM parameters: d_state (state dim, default 16), d_conv (conv kernel, default 4), expand (inner expansion factor, default 2), dt_rank (delta projection bottleneck, default ceil(d_model/16)).
A_log and D are excluded from weight decay during training.
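Written out per time step, the scan recurrence above is straightforward. A sketch of the reference (unfused) scan for a single sequence, using NumPy and the shapes noted in the comments (an illustration, not the calm implementation):

```python
import numpy as np

def selective_scan(x, dt, A, B, C, D):
    # x, dt: [L, d_inner]; A: [d_inner, d_state]
    # B, C: [L, d_state]; D: [d_inner]
    L, d_inner = x.shape
    d_state = A.shape[1]
    h = np.zeros((d_inner, d_state))
    ys = np.empty((L, d_inner))
    for t in range(L):
        dA = np.exp(dt[t][:, None] * A)                  # discretised transition exp(dt*A)
        dBx = dt[t][:, None] * B[t][None, :] * x[t][:, None]  # dt*B[t]*x[t]
        h = dA * h + dBx                                  # h[t] = exp(dt*A)*h[t-1] + dt*B[t]*x[t]
        ys[t] = h @ C[t] + D * x[t]                       # y[t] = C[t]·h[t] + D*x[t]
    return ys
```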
Tokenizer
Byte-level identity mapping with 320 tokens total:
- 0-255: raw byte values
- 256-269: 14 special tokens for context framing (PAD, BOS, EOS, ATN, CWD, GIT, HIST, EXIT, CMD, ENV, COMP, FILE, NEXT, END)
- 270-276: 7 dictionary-domain tokens (WORD, POS, NOTE, IPA, DEF, QUOTE, BY)
- 277-319: reserved
Context window format used for shell completion:
<BOS> <CWD>/home/user<END> <GIT>main<END> <HIST>ls<EXIT>0<END> <ATN> <CMD>git c...
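Because the mapping is a byte-level identity, encoding plain text needs no vocabulary lookup. A sketch, assuming the special tokens are numbered in the order listed above starting at 256 (consistent with <BOS>=257 and <EOS>=258 used in the embeddings example later):

```python
# Tokens 0-255 are raw byte values; specials are assumed to follow in listed order.
SPECIALS = ["PAD", "BOS", "EOS", "ATN", "CWD", "GIT", "HIST", "EXIT",
            "CMD", "ENV", "COMP", "FILE", "NEXT", "END"]
SPECIAL_IDS = {name: 256 + i for i, name in enumerate(SPECIALS)}

def encode_text(text: str) -> list[int]:
    # Plain text encodes to its UTF-8 bytes, one token per byte.
    return list(text.encode("utf-8"))
```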
Model sizes
Presets:
| Preset | d_model | layers | expand | ffn_expand | params |
|---|---|---|---|---|---|
| nano | 64 | 3 | 2 | 2 | ~168K |
| micro | 96 | 5 | 2 | 3 | ~649K |
| mini | 128 | 6 | 3 | 4 | ~1.9M |
| small | 192 | 8 | 4 | 4 | ~6.4M |
Installation
uv add calm-mamba
Or for development:
git clone <repo-url> && cd calm
uv sync --dev
Requires Python 3.12+ and PyTorch with CUDA. Triton is included for fused GPU kernels that significantly accelerate training.
Training
Training uses CTDS (CALM Training DataSet) binary files as input. CTDS
files contain packed token sequences with per-sequence ATN positions for
loss masking -- loss is computed on all tokens after <ATN> (or the
full sequence if <ATN> is absent).
Generate CTDS files using the Lilush calm dataset builtin, or using the dataset tools in the calm_datasets repo together with calm-text-to-ctds.
Using a preset
calm-train --preset nano --dataset train.ctds --save-dir ./runs/nano
calm-train --preset micro --dataset train.ctds --val-dataset val.ctds --save-dir ./runs/micro
calm-train --preset mini --dataset train.ctds --save-dir ./runs/mini
Using a config file
calm-train --config my_config.yaml --dataset train.ctds --save-dir ./runs/mini
CLI options
--preset NAME Use a preset (nano/micro/mini/small)
--config PATH Use a YAML config file
--dataset PATH CTDS training dataset (required)
--val-dataset PATH CTDS validation dataset
--save-dir PATH Output directory (default: ./runs/calm)
--epochs N Override number of epochs
--lr FLOAT Override learning rate
--batch-size N Override batch size
--resume PATH Resume from safetensors checkpoint
--init-cwgt PATH Initialize weights from CWGT file
--shuffle / --no-shuffle Shuffle training data each epoch
--domain NAME Model domain (default: shell)
--template SPEC Template DSL spec (overrides domain default)
--stop-conditions SPEC Stop conditions (overrides domain default)
The --domain flag determines the prompt template, stop conditions,
and default sampler parameters baked into the exported CWGT file.
Known domains: shell, dictionary, netref. When using
--init-cwgt, the source model's metadata is preserved unless
explicitly overridden.
Training produces:
- checkpoints/ -- safetensors checkpoints with optimizer state
- model.cwgt -- CWGT weight file for deployment in Lilush
Config file format
model:
d_model: 128
n_layers: 6
l_max: 768
ffn_expand: 4
mamba_expand: 3
mamba_d_state: 16
mamba_d_conv: 4
training:
epochs: 10
batch_size: 16
learning_rate: 0.0003
warmup_steps: 200
clip_grad_norm: 1.0
patience: 5
weight_decay: 0.01
amp: true # mixed precision (bf16/fp16), default true on CUDA
gradient_checkpointing: false # trade compute for VRAM
Continual learning (EWC)
Add a continual section to the config to enable Online Elastic Weight
Consolidation, which prevents catastrophic forgetting when fine-tuning
on new data:
continual:
method: online_ewc
lambda: 10.0
decay: 0.99
fisher_update_interval_steps: 1000
fisher_batches: 16
buffer_batches: 1024
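The regulariser this adds to the training loss is the usual quadratic EWC penalty. A sketch (the actual calm.ewc interface may differ; fisher and anchor here stand for the running Fisher estimate and the anchor weights):

```python
import torch

def ewc_penalty(model, fisher, anchor, lam=10.0):
    # Online EWC penalty: (lambda/2) * sum_i F_i * (theta_i - theta*_i)^2,
    # where F is the decayed running Fisher estimate and theta* the anchor weights.
    penalty = torch.zeros(())
    for name, p in model.named_parameters():
        if name in fisher:
            penalty = penalty + (fisher[name] * (p - anchor[name]) ** 2).sum()
    return 0.5 * lam * penalty
```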
Embeddings
CalmLM.embed() extracts dense vector representations from the model.
It runs the full block stack and final LayerNorm but skips the LM head,
then pools across positions and optionally L2-normalises. This matches
calm_model_embed() in the C engine.
import torch
from calm.checkpoint import load_cwgt
model, _, _, _, _ = load_cwgt("model.cwgt")
model.eval()
ids = torch.tensor([[257, 104, 101, 108, 108, 111, 258]])  # <BOS>hello<EOS>
emb = model.embed(ids, pool_mode=0, normalize=True)  # [1, d_model]
Pool modes: 0 = mean over all positions (default), 1 = last
token only. When processing padded batches, pass an attention_mask
(bool [B, L], True for real tokens) so the pooling ignores padding.
Normalisation: enabled by default -- output vectors have unit L2 norm, suitable for cosine similarity via dot product.
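Masked mean pooling with L2 normalisation can be sketched as follows (a sketch mirroring the behaviour described above, not the model's internal code):

```python
import torch

def masked_mean_pool(hidden, mask):
    # hidden: [B, L, d] block-stack outputs; mask: bool [B, L], True for real tokens.
    mask = mask.unsqueeze(-1).float()
    summed = (hidden * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1)   # avoid division by zero on all-pad rows
    emb = summed / counts
    return torch.nn.functional.normalize(emb, dim=-1)  # unit L2 norm
```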
Contrastive training (InfoNCE)
Fine-tunes a pretrained CALM model to produce high-quality embeddings
using symmetric InfoNCE contrastive loss. This matches the C-side
calm_contrastive_step().
Pair data format
Training pairs use a text format. Each pair has a <QUERY> (the search
text) and a <REF> (the target passage), separated by <ATN>. Pairs
are separated by blank lines:
<QUERY>how does TCP handle retransmission
<ATN>
<REF>When a TCP sender detects segment loss using a retransmission
timer or duplicate acknowledgments, it retransmits the lost segment.
<QUERY>TLS 1.3 key exchange
<ATN>
<REF>The handshake protocol negotiates the cryptographic parameters
using Diffie-Hellman key exchange in a single round trip.
Multi-line passages are supported (continuation lines after <REF>).
Pairs with query < 5 chars or passage < 20 chars are filtered out.
Training
# Fine-tune a pretrained model
calm-train-contrastive \
--init-cwgt pretrained.cwgt \
--pairs rfc_pairs.txt \
--save-dir ./runs/rfc_embed \
--epochs 5 --batch-size 16 --lr 1e-4
# Train from scratch with a preset
calm-train-contrastive \
--preset mini \
--pairs rfc_pairs.txt \
--save-dir ./runs/embed_mini
# Resume from checkpoint
calm-train-contrastive \
--resume ./runs/rfc_embed/checkpoints/epoch0003 \
--pairs rfc_pairs.txt \
--save-dir ./runs/rfc_embed
CLI options
--preset NAME Use a preset (nano/micro/mini/small)
--config PATH Use a YAML config file
--pairs PATH Text pair file (required)
--save-dir PATH Output directory (default: ./runs/contrastive)
--epochs N Override number of epochs
--lr FLOAT Override learning rate (default: 1e-4)
--batch-size N Override batch size (default: 16)
--temperature FLOAT InfoNCE temperature (default: 0.07)
--pool-mode {mean,last} Embedding pooling mode (default: mean)
--init-cwgt PATH Initialize weights from CWGT file
--resume PATH Resume from checkpoint directory
--domain NAME Model domain (default: netref)
--template SPEC Template DSL spec
--stop-conditions SPEC Stop conditions spec
InfoNCE loss
The loss function (calm.contrastive.infonce_loss) computes symmetric
InfoNCE over in-batch negatives:
- Compute the B x B similarity matrix: sim = queries @ positives.T / temperature
- Query-side cross-entropy: each query targets its corresponding positive
- Positive-side cross-entropy (symmetric): each positive targets its query
- Average: (loss_q2p + loss_p2q) / 2
All 2B samples (queries + positives) are batched in a single
model.embed() call for efficient GPU utilisation.
from calm.contrastive import infonce_loss
# q_emb, p_emb: [B, d] L2-normalised embeddings
loss = infonce_loss(q_emb, p_emb, temperature=0.07)
loss.backward()
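The four steps above fit in a few lines. A minimal reference implementation (a sketch; the shipped infonce_loss may differ in details such as device handling):

```python
import torch
import torch.nn.functional as F

def infonce_sketch(q, p, temperature=0.07):
    # q, p: [B, d] L2-normalised embeddings; row i of p is the positive for row i of q.
    sim = q @ p.T / temperature                  # B x B similarity matrix
    targets = torch.arange(q.size(0))            # diagonal entries are the positives
    loss_q2p = F.cross_entropy(sim, targets)     # query-side cross-entropy
    loss_p2q = F.cross_entropy(sim.T, targets)   # symmetric positive-side term
    return (loss_q2p + loss_p2q) / 2
```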
Inference
Mirrors the calm generate builtin from Lilush. Two modes: template-driven
(default) and raw. Output includes model info, scored candidates,
and generation stats unless --quiet is used.
Template mode (default)
Introspects the model's template metadata to build context sequences.
For shell models, the template defines frames for CWD, GIT, HIST, etc.
For multi-field templates, input is parsed for field:value patterns
(e.g. pos:n. headword:cat). Stop conditions come from model metadata.
calm-inference -m model.cwgt -i "git c" -k 5
calm-inference -m model.cwgt -i "ls -" -t 0
calm-inference -m model.cwgt -i "echo " -k 10 -n 3
calm-inference -m model.cwgt -i "cwd:/home/user git:main input:git c"
Raw mode
Parses inline special token patterns (<NAME> or <:ID:>) with no
automatic framing and no stop conditions. The caller controls the
entire token sequence.
calm-inference -m model.cwgt -r "<BOS><WORD>cat<POS>n.<END><ATN>" --max-tokens 500
calm-inference -m model.cwgt -r "<BOS>The quick brown fox" --max-tokens 100
calm-inference -m model.cwgt -r "<:257:>def fibonacci(" --max-tokens 100
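Parsing the raw input format can be sketched with a single regex pass (a sketch; the special-token IDs shown are assumed from the tokenizer table, and the dict is truncated for brevity):

```python
import re

# Assumed IDs, consistent with the tokenizer section (specials start at 256).
SPECIALS = {"PAD": 256, "BOS": 257, "EOS": 258, "ATN": 259}
PATTERN = re.compile(r"<([A-Z]+)>|<:(\d+):>")

def parse_raw(text: str) -> list[int]:
    # Split the input on <NAME> / <:ID:> markers; plain text becomes raw byte tokens.
    ids, pos = [], 0
    for m in PATTERN.finditer(text):
        ids.extend(text[pos:m.start()].encode("utf-8"))
        ids.append(SPECIALS[m.group(1)] if m.group(1) else int(m.group(2)))
        pos = m.end()
    ids.extend(text[pos:].encode("utf-8"))
    return ids
```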
CLI options
-m, --model PATH Path to CWGT weight file (required)
-i, --input TEXT Input text (or read from stdin)
-r, --raw TEXT Raw input with <NAME>/<:ID:> patterns
-k, --top-k N Top-K sampling (default: 5)
-p, --top-p FLOAT Nucleus (top-p) sampling threshold (0 = disabled)
--min-p FLOAT Min-P relative probability threshold (0 = disabled)
-t, --temperature STR Sampling temperature (default: 0.8)
--max-tokens N Maximum tokens to generate (default: 256)
-n, --candidates N Number of completions to generate (default: 1)
-q, --quiet Output completion text only, no stats
--full Output prompt + completion concatenated
-s, --special-tokens Render special tokens as <BOS> etc in output
--show-fields Show model template fields and exit
Weight format (CWGT v5)
The CWGT binary format enables direct weight exchange between Python and the Lilush C runtime. All weights are float32, little-endian.
[Header: 48 bytes, packed, little-endian]
magic "CWGT" 4 bytes
arch_version uint16 (5)
flags uint16 (bit 0: tied, bit 1: EWC)
vocab_size uint16
d_model uint16
n_layers uint8
ffn_expand uint8
expand uint8 (offset 14)
d_state uint8 (offset 15)
l_max uint16
param_count uint32
def_temperature uint16 (x1000, e.g. 800 = 0.8)
def_top_k uint16
def_top_p uint16 (x1000)
def_min_p uint16 (x1000)
def_max_tokens uint16
def_candidates uint8
d_conv uint8 (offset 33)
meta_size uint32 (byte count of metadata blob, 0 if none)
dt_rank uint8 (offset 38)
reserved 9 bytes
[Metadata blob: meta_size bytes, 3 newline-terminated UTF-8 lines]
domain\n e.g. "shell"
template\n e.g. "BOS;CWD:cwd;GIT:git;...;ATN;CMD:input"
stop_conditions\n e.g. "| ; && ||"
[Optional EWC data: 2 x param_count x float32]
[Weights: param_count x float32]
Weight order: token_emb, then per-layer (ln1, in_proj, conv1d_weight, x_proj, dt_proj weight+bias, A_log, D, out_proj, ln2, ffn), then final ln.
Linear layer weights are transposed between PyTorch [out, in] and
CWGT [in, out] layout during save/load.
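The 48-byte header layout above maps directly onto a struct format string. A stdlib-only sketch of parsing it (an illustration of the layout, not the calm.checkpoint code):

```python
import struct

# 48 bytes, little-endian, packed; field order follows the header table above.
CWGT_HEADER = struct.Struct("<4sHHHHBBBBHIHHHHHBBIB9x")

def parse_cwgt_header(buf: bytes) -> dict:
    (magic, arch_version, flags, vocab_size, d_model, n_layers, ffn_expand,
     expand, d_state, l_max, param_count, def_temperature, def_top_k,
     def_top_p, def_min_p, def_max_tokens, def_candidates, d_conv,
     meta_size, dt_rank) = CWGT_HEADER.unpack(buf[:48])
    assert magic == b"CWGT" and arch_version == 5
    return {
        "vocab_size": vocab_size, "d_model": d_model, "n_layers": n_layers,
        "ffn_expand": ffn_expand, "expand": expand, "d_state": d_state,
        "l_max": l_max, "param_count": param_count,
        "temperature": def_temperature / 1000,  # stored x1000
        "top_k": def_top_k, "max_tokens": def_max_tokens,
        "d_conv": d_conv, "dt_rank": dt_rank, "meta_size": meta_size,
        "tied": bool(flags & 1), "has_ewc": bool(flags & 2),
    }
```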
Programmatic usage
from calm.checkpoint import load_cwgt, save_cwgt
# Load -- returns (model, config, meta, ewc_fisher, ewc_anchor)
model, config, meta, _, _ = load_cwgt("model.cwgt")
print(meta.domain, meta.template, meta.stop_conditions)
# Save with domain metadata for deployment in Lilush
save_cwgt("exported.cwgt", model,
domain="shell",
template="BOS;CWD:cwd;GIT:git;...;ATN;CMD:input",
stop_conditions="| ; && ||",
sampler_defaults={"temperature": 0.8, "top_k": 5})
Dataset format (CTDS)
CTDS (CALM Training DataSet) is a binary format for packed training sequences with per-sequence ATN positions.
[Header: 14 bytes, packed]
magic "CTDS" 4 bytes
vocab_version uint32
count uint32
max_len uint16
[Lengths: count x uint16]
[ATN positions: count x uint16]
[Tokens: sum(lengths) x uint16]
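The layout above can be serialised with the struct module alone. A stdlib-only sketch of a writer (the packaged write_ctds additionally handles validation and vocab versioning):

```python
import struct

def write_ctds_sketch(path, sequences, atn_positions, vocab_version=1):
    # sequences: list of uint16 token-id lists; atn_positions: one uint16 per sequence.
    max_len = max(len(s) for s in sequences)
    with open(path, "wb") as f:
        # 14-byte header: magic, vocab_version, count, max_len
        f.write(struct.pack("<4sIIH", b"CTDS", vocab_version, len(sequences), max_len))
        f.write(struct.pack(f"<{len(sequences)}H", *(len(s) for s in sequences)))
        f.write(struct.pack(f"<{len(sequences)}H", *atn_positions))
        for s in sequences:
            f.write(struct.pack(f"<{len(s)}H", *s))
```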
Creating CTDS files from Python
from calm.tokenizer import CalmTokenizer
from calm.dataset import write_ctds
tok = CalmTokenizer()
sequences = [
tok.build_sequence(cwd="/home/user", history=[("ls", 0)], partial_cmd="git status"),
tok.build_sequence(cwd="/tmp", partial_cmd="echo hello"),
]
cmd_positions = [tok.find_cmd_pos(s) for s in sequences]
write_ctds("train.ctds", sequences, cmd_positions)
GPU acceleration (Triton kernels)
When Triton is installed and tensors are on CUDA, the Mamba operator
automatically dispatches to fused GPU kernels in calm/triton_kernels/:
- Selective scan -- replaces the sequential Python loop with a single fused kernel (one program per batch x d_inner tile, full sequence loop in-kernel). Backward uses forward-recompute strategy.
- Causal conv1d + SiLU -- fuses left-pad, depthwise conv1d, and SiLU activation in BLD layout, eliminating two transpose copies.
- Gated SiLU -- fuses y * SiLU(z) without materialising the SiLU intermediate.
All kernels fall back to pure PyTorch on CPU or when Triton is absent. Combined with AMP (enabled by default), expect ~3-5x training speedup.
Testing
uv run pytest tests/ -v
Project structure
calm/
__init__.py Package root (CalmLM, CalmTokenizer, etc.)
compat.py Constants, CWGT/CTDS headers, param count formulas
tokenizer.py Byte-level CALM tokenizer (320 vocab)
template.py Template DSL parser and sequence builder
dataset.py CTDS binary dataset reader/writer + collator
contrastive_dataset.py Text pair dataset reader + collator for InfoNCE
lm.py CalmLM model (forward, loss, embed, incremental decode)
mamba_operator.py Mamba operator (selective SSM, sequential scan)
contrastive.py Symmetric InfoNCE loss function
checkpoint.py Safetensors checkpoints + CWGT v5 serialization
ewc.py Online Elastic Weight Consolidation
triton_kernels/ Fused Triton GPU kernels (auto-dispatch on CUDA)
configs/ YAML model preset configs (nano/micro/mini/small)
cli/
train.py calm-train entry point
train_contrastive.py calm-train-contrastive entry point
inference.py calm-inference entry point
text_to_ctds.py calm-text-to-ctds entry point
scripts/ Domain training shell scripts
tests/ Pytest test suite
pyproject.toml Package metadata and build config
Dataset conversion and preparation scripts live in the calm_datasets repo.