
GPU-accelerated training for CALM (Catastrophically Abridged Language Models)


calm-mamba


Python companion to the CALM (Catastrophically Abridged Language Models) system from Lilush. Provides GPU-accelerated training via PyTorch while maintaining bit-level weight compatibility with the embedded C inference engine through the CWGT binary format.

CALM is a small decoder-only language model that powers shell completion and semantic search in the Lilush shell. Its sequence mixer is a Mamba selective state space model (SSM) with input-dependent gating and O(1) per-token, state-cached generation.

Architecture

input_ids (B, L)
    |
token_emb [vocab_size=320, d_model]     (weight-tied with lm_head)
    |
for each Block (n_layers):
    |-- LayerNorm -> Mixer (Mamba SSM) -> residual add
    '-- LayerNorm -> FFN(GELU, no bias) -> residual add
    |
LayerNorm ──┬── lm_head ── logits (B, L, 320)    [forward / loss]
            └── pool + L2 norm ── emb (B, d)      [embed]

Mamba operator

The selective SSM mixer. For each block:

  1. in_proj splits input into SSM branch (x') and gate branch (z)
  2. Causal depthwise conv1d (kernel=d_conv, default 4) + SiLU on x'
  3. x_proj produces input-dependent dt, B, C (selectivity)
  4. dt_proj + softplus gives discretized time steps
  5. SSM scan: h[t] = exp(dt*A)*h[t-1] + dt*B[t]*x[t]; y[t] = C[t]·h[t] + D*x[t]
  6. Gate: y = y * SiLU(z)
  7. Output projection (bias-free)

SSM parameters: d_state (state dim, default 16), d_conv (conv kernel, default 4), expand (inner expansion factor, default 2), dt_rank (delta projection bottleneck, default ceil(d_model/16)). A_log and D are excluded from weight decay during training.
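The recurrence in step 5 can be sketched as a reference sequential scan in pure PyTorch (this is an illustrative sketch of the math, not the package's implementation; tensor shapes are assumptions chosen to match the formula):

```python
import torch

def selective_scan(x, dt, A, B, C, D):
    """Reference scan: h[t] = exp(dt*A)*h[t-1] + dt*B[t]*x[t];
    y[t] = C[t]·h[t] + D*x[t].  Assumed shapes:
      x, dt: (batch, L, d_inner)   A: (d_inner, d_state)
      B, C:  (batch, L, d_state)   D: (d_inner,)
    """
    b, L, d_inner = x.shape
    d_state = A.shape[1]
    h = torch.zeros(b, d_inner, d_state)
    ys = []
    for t in range(L):
        dA = torch.exp(dt[:, t, :, None] * A)                       # discretized decay
        dBx = dt[:, t, :, None] * B[:, t, None, :] * x[:, t, :, None]
        h = dA * h + dBx                                            # state update
        y = (h * C[:, t, None, :]).sum(-1) + D * x[:, t]            # readout + skip
        ys.append(y)
    return torch.stack(ys, dim=1)                                   # (batch, L, d_inner)
```

The fused Triton kernel (see the GPU acceleration section) performs this same loop in-kernel rather than step-by-step in Python.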

Tokenizer

Byte-level identity mapping with 320 tokens total:

  • 0-255: raw byte values
  • 256-269: 14 special tokens for context framing (PAD, BOS, EOS, ATN, CWD, GIT, HIST, EXIT, CMD, ENV, COMP, FILE, NEXT, END)
  • 270-276: 7 dictionary-domain tokens (WORD, POS, NOTE, IPA, DEF, QUOTE, BY)
  • 277-319: reserved

Context window format used for shell completion:

<BOS> <CWD>/home/user<END> <GIT>main<END> <HIST>ls<EXIT>0<END> <ATN> <CMD>git c...
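Because the byte range is an identity mapping, encoding needs no vocabulary lookup. A minimal sketch of building such a context sequence (token IDs follow the table above; the helper names are illustrative, not the package's API):

```python
# Special-token IDs 256-269, in the order listed above.
SPECIALS = {name: 256 + i for i, name in enumerate(
    ["PAD", "BOS", "EOS", "ATN", "CWD", "GIT", "HIST", "EXIT",
     "CMD", "ENV", "COMP", "FILE", "NEXT", "END"])}

def encode_bytes(text: str) -> list[int]:
    """Raw bytes map 1:1 onto token IDs 0-255."""
    return list(text.encode("utf-8"))

def frame(tag: str, payload: str) -> list[int]:
    """Wrap a payload as <TAG>...<END>, as in the context window above."""
    return [SPECIALS[tag]] + encode_bytes(payload) + [SPECIALS["END"]]

# <BOS> <CWD>/home/user<END> <ATN> <CMD>git c
seq = ([SPECIALS["BOS"]] + frame("CWD", "/home/user")
       + [SPECIALS["ATN"], SPECIALS["CMD"]] + encode_bytes("git c"))
```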

Model sizes

Presets:

Preset   d_model   layers   expand   ffn_expand   params
nano     64        3        2        2            ~168K
micro    96        5        2        3            ~649K
mini     128       6        3        4            ~1.9M
small    192       8        4        4            ~6.4M

Installation

uv add calm-mamba

Or for development:

git clone <repo-url> && cd calm
uv sync --dev

Requires Python 3.12+ and PyTorch with CUDA. Triton is included for fused GPU kernels that significantly accelerate training.

Training

Training uses CTDS (CALM Training DataSet) binary files as input. CTDS files contain packed token sequences with per-sequence ATN positions for loss masking -- loss is computed on all tokens after <ATN> (or the full sequence if <ATN> is absent).
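The ATN-based loss masking can be sketched as follows (an illustrative form, not the trainer's actual code):

```python
import torch
import torch.nn.functional as F

def masked_lm_loss(logits, targets, atn_pos):
    """Cross-entropy over tokens after <ATN> only.
    logits: (B, L, V), targets: (B, L), atn_pos: (B,) position of <ATN>
    (pass -1 when <ATN> is absent so the full sequence contributes)."""
    B, L, V = logits.shape
    positions = torch.arange(L).expand(B, L)
    mask = positions > atn_pos.unsqueeze(1)          # True for tokens after <ATN>
    loss = F.cross_entropy(logits.reshape(-1, V), targets.reshape(-1),
                           reduction="none").reshape(B, L)
    return (loss * mask).sum() / mask.sum().clamp(min=1)
```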

Generate CTDS files using the Lilush calm dataset builtin, or using the dataset tools in the calm_datasets repo together with calm-text-to-ctds.

Using a preset

calm-train --preset nano --dataset train.ctds --save-dir ./runs/nano
calm-train --preset micro --dataset train.ctds --val-dataset val.ctds --save-dir ./runs/micro
calm-train --preset mini --dataset train.ctds --save-dir ./runs/mini

Using a config file

calm-train --config my_config.yaml --dataset train.ctds --save-dir ./runs/mini

CLI options

--preset NAME                      Use a preset (nano/micro/mini/small)
--config PATH                      Use a YAML config file
--dataset PATH                     CTDS training dataset (required)
--val-dataset PATH                 CTDS validation dataset
--save-dir PATH                    Output directory (default: ./runs/calm)
--epochs N                         Override number of epochs
--lr FLOAT                         Override learning rate
--batch-size N                     Override batch size
--resume PATH                      Resume from safetensors checkpoint
--init-cwgt PATH                   Initialize weights from CWGT file
--shuffle / --no-shuffle           Shuffle training data each epoch
--domain NAME                      Model domain (default: shell)
--template SPEC                    Template DSL spec (overrides domain default)
--stop-conditions SPEC             Stop conditions (overrides domain default)

The --domain flag determines the prompt template, stop conditions, and default sampler parameters baked into the exported CWGT file. Known domains: shell, dictionary, netref. When using --init-cwgt, the source model's metadata is preserved unless explicitly overridden.

Training produces:

  • checkpoints/ -- safetensors checkpoints with optimizer state
  • model.cwgt -- CWGT weight file for deployment in Lilush

Config file format

model:
  d_model: 128
  n_layers: 6
  l_max: 768
  ffn_expand: 4
  mamba_expand: 3
  mamba_d_state: 16
  mamba_d_conv: 4

training:
  epochs: 10
  batch_size: 16
  learning_rate: 0.0003
  warmup_steps: 200
  clip_grad_norm: 1.0
  patience: 5
  weight_decay: 0.01
  amp: true                    # mixed precision (bf16/fp16), default true on CUDA
  gradient_checkpointing: false  # trade compute for VRAM

Continual learning (EWC)

Add a continual section to the config to enable Online Elastic Weight Consolidation, which prevents catastrophic forgetting when fine-tuning on new data:

continual:
  method: online_ewc
  lambda: 10.0
  decay: 0.99
  fisher_update_interval_steps: 1000
  fisher_batches: 16
  buffer_batches: 1024
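For intuition, online EWC adds a Fisher-weighted quadratic penalty pulling weights toward their consolidated anchors. A sketch of the assumed form (parameter names mirror the config; the shipped calm/ewc.py may differ in detail):

```python
import torch

def ewc_penalty(params, anchors, fisher, lam):
    """Online EWC penalty: lam/2 * sum_i F_i * (theta_i - theta*_i)^2.
    `fisher` is the decayed running Fisher estimate; `anchors` are
    the consolidated weights from earlier data."""
    total = torch.zeros(())
    for p, a, f in zip(params, anchors, fisher):
        total = total + (f * (p - a) ** 2).sum()
    return 0.5 * lam * total

def decay_fisher(fisher, new_fisher, decay=0.99):
    """Online update every fisher_update_interval_steps: F <- decay*F + F_new."""
    return [decay * f + nf for f, nf in zip(fisher, new_fisher)]
```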

Embeddings

CalmLM.embed() extracts dense vector representations from the model. It runs the full block stack and final LayerNorm but skips the LM head, then pools across positions and optionally L2-normalises. This matches calm_model_embed() in the C engine.

from calm.checkpoint import load_cwgt

model, _, _, _, _ = load_cwgt("model.cwgt")
model.eval()

import torch
ids = torch.tensor([[257, 104, 101, 108, 108, 111, 258]])  # <BOS>hello<EOS>
emb = model.embed(ids, pool_mode=0, normalize=True)         # [1, d_model]

Pool modes: 0 = mean over all positions (default), 1 = last token only. When processing padded batches, pass an attention_mask (bool [B, L], True for real tokens) so the pooling ignores padding.

Normalisation: enabled by default -- output vectors have unit L2 norm, suitable for cosine similarity via dot product.
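The pooling and normalisation described above can be sketched as (illustrative, mirroring the description rather than the package's internals):

```python
import torch

def pool(hidden, pool_mode=0, attention_mask=None, normalize=True):
    """hidden: (B, L, d) post-LayerNorm hidden states.
    pool_mode 0 = masked mean over positions, 1 = last real token."""
    if attention_mask is None:
        attention_mask = torch.ones(hidden.shape[:2], dtype=torch.bool)
    m = attention_mask.unsqueeze(-1).float()
    if pool_mode == 0:
        emb = (hidden * m).sum(1) / m.sum(1).clamp(min=1)       # padding-aware mean
    else:
        last = attention_mask.long().sum(1) - 1                 # index of last real token
        emb = hidden[torch.arange(hidden.shape[0]), last]
    if normalize:
        emb = torch.nn.functional.normalize(emb, dim=-1)        # unit L2 norm
    return emb
```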

Contrastive training (InfoNCE)

Fine-tunes a pretrained CALM model to produce high-quality embeddings using symmetric InfoNCE contrastive loss. This matches the C-side calm_contrastive_step().

Pair data format

Training pairs use a text format. Each pair has a <QUERY> (the search text) and a <REF> (the target passage), separated by <ATN>. Pairs are separated by blank lines:

<QUERY>how does TCP handle retransmission
<ATN>
<REF>When a TCP sender detects segment loss using a retransmission
timer or duplicate acknowledgments, it retransmits the lost segment.

<QUERY>TLS 1.3 key exchange
<ATN>
<REF>The handshake protocol negotiates the cryptographic parameters
using Diffie-Hellman key exchange in a single round trip.

Multi-line passages are supported (continuation lines after <REF>). Pairs with query < 5 chars or passage < 20 chars are filtered out.
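A minimal parser for this format might look like the following (illustrative; the package's own reader lives in calm/contrastive_dataset.py and may differ in detail):

```python
def parse_pairs(text: str, min_query=5, min_passage=20):
    """Split on blank lines, extract <QUERY>/<REF> with multi-line
    passages, and drop pairs below the length thresholds above."""
    pairs = []
    for block in text.split("\n\n"):
        query, ref_lines, in_ref = "", [], False
        for line in block.strip().splitlines():
            if line.startswith("<QUERY>"):
                query = line[len("<QUERY>"):]
                in_ref = False
            elif line.startswith("<REF>"):
                ref_lines = [line[len("<REF>"):]]
                in_ref = True
            elif line.startswith("<ATN>"):
                in_ref = False
            elif in_ref:
                ref_lines.append(line)          # continuation line of the passage
        passage = "\n".join(ref_lines)
        if len(query) >= min_query and len(passage) >= min_passage:
            pairs.append((query, passage))
    return pairs
```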

Training

# Fine-tune a pretrained model
calm-train-contrastive \
    --init-cwgt pretrained.cwgt \
    --pairs rfc_pairs.txt \
    --save-dir ./runs/rfc_embed \
    --epochs 5 --batch-size 16 --lr 1e-4

# Train from scratch with a preset
calm-train-contrastive \
    --preset mini \
    --pairs rfc_pairs.txt \
    --save-dir ./runs/embed_mini

# Resume from checkpoint
calm-train-contrastive \
    --resume ./runs/rfc_embed/checkpoints/epoch0003 \
    --pairs rfc_pairs.txt \
    --save-dir ./runs/rfc_embed

CLI options

--preset NAME              Use a preset (nano/micro/mini/small)
--config PATH              Use a YAML config file
--pairs PATH               Text pair file (required)
--save-dir PATH            Output directory (default: ./runs/contrastive)
--epochs N                 Override number of epochs
--lr FLOAT                 Override learning rate (default: 1e-4)
--batch-size N             Override batch size (default: 16)
--temperature FLOAT        InfoNCE temperature (default: 0.07)
--pool-mode {mean,last}    Embedding pooling mode (default: mean)
--init-cwgt PATH           Initialize weights from CWGT file
--resume PATH              Resume from checkpoint directory
--domain NAME              Model domain (default: netref)
--template SPEC            Template DSL spec
--stop-conditions SPEC     Stop conditions spec

InfoNCE loss

The loss function (calm.contrastive.infonce_loss) computes symmetric InfoNCE over in-batch negatives:

  1. Compute B x B similarity matrix: sim = queries @ positives.T / temperature
  2. Query-side cross-entropy: each query targets its corresponding positive
  3. Positive-side cross-entropy (symmetric): each positive targets its query
  4. Average: (loss_q2p + loss_p2q) / 2

All 2B samples (queries + positives) are batched in a single model.embed() call for efficient GPU utilisation.
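The four steps map directly onto a few lines of PyTorch. A reference sketch of symmetric InfoNCE (the shipped calm.contrastive.infonce_loss may differ in detail):

```python
import torch
import torch.nn.functional as F

def infonce(q_emb, p_emb, temperature=0.07):
    """q_emb, p_emb: (B, d) L2-normalised embeddings."""
    sim = q_emb @ p_emb.T / temperature          # (B, B) similarity matrix
    targets = torch.arange(q_emb.shape[0])       # i-th query matches i-th positive
    loss_q2p = F.cross_entropy(sim, targets)     # query-side cross-entropy
    loss_p2q = F.cross_entropy(sim.T, targets)   # positive-side (symmetric)
    return (loss_q2p + loss_p2q) / 2
```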

from calm.contrastive import infonce_loss

# q_emb, p_emb: [B, d] L2-normalised embeddings
loss = infonce_loss(q_emb, p_emb, temperature=0.07)
loss.backward()

Inference

Mirrors calm generate from Lilush. Two modes: template-driven (default) and raw. Output includes model info, scored candidates, and generation stats unless --quiet is used.

Template mode (default)

Introspects the model's template metadata to build context sequences. For shell models, the template defines frames for CWD, GIT, HIST, etc. For multi-field templates, input is parsed for field:value patterns (e.g. pos:n. headword:cat). Stop conditions come from model metadata.

calm-inference -m model.cwgt -i "git c" -k 5
calm-inference -m model.cwgt -i "ls -" -t 0
calm-inference -m model.cwgt -i "echo " -k 10 -n 3
calm-inference -m model.cwgt -i "cwd:/home/user git:main input:git c"

Raw mode

Parses inline special token patterns (<NAME> or <:ID:>) with no automatic framing and no stop conditions. The caller controls the entire token sequence.

calm-inference -m model.cwgt -r "<BOS><WORD>cat<POS>n.<END><ATN>" --max-tokens 500
calm-inference -m model.cwgt -r "<BOS>The quick brown fox" --max-tokens 100
calm-inference -m model.cwgt -r "<:257:>def fibonacci(" --max-tokens 100

CLI options

-m, --model PATH          Path to CWGT weight file (required)
-i, --input TEXT          Input text (or read from stdin)
-r, --raw TEXT            Raw input with <NAME>/<:ID:> patterns
-k, --top-k N             Top-K sampling (default: 5)
-p, --top-p FLOAT         Nucleus (top-p) sampling threshold (0 = disabled)
--min-p FLOAT             Min-P relative probability threshold (0 = disabled)
-t, --temperature STR     Sampling temperature (default: 0.8)
--max-tokens N            Maximum tokens to generate (default: 256)
-n, --candidates N        Number of completions to generate (default: 1)
-q, --quiet               Output completion text only, no stats
--full                    Output prompt + completion concatenated
-s, --special-tokens      Render special tokens as <BOS> etc in output
--show-fields             Show model template fields and exit

Weight format (CWGT v5)

The CWGT binary format enables direct weight exchange between Python and the Lilush C runtime. All weights are float32, little-endian.

[Header: 48 bytes, packed, little-endian]
  magic "CWGT"        4 bytes
  arch_version        uint16 (5)
  flags               uint16 (bit 0: tied, bit 1: EWC)
  vocab_size          uint16
  d_model             uint16
  n_layers            uint8
  ffn_expand          uint8
  expand              uint8 (offset 14)
  d_state             uint8 (offset 15)
  l_max               uint16
  param_count         uint32
  def_temperature     uint16 (x1000, e.g. 800 = 0.8)
  def_top_k           uint16
  def_top_p           uint16 (x1000)
  def_min_p           uint16 (x1000)
  def_max_tokens      uint16
  def_candidates      uint8
  d_conv              uint8 (offset 33)
  meta_size           uint32 (byte count of metadata blob, 0 if none)
  dt_rank             uint8 (offset 38)
  reserved            9 bytes

[Metadata blob: meta_size bytes, 3 newline-terminated UTF-8 lines]
  domain\n            e.g. "shell"
  template\n          e.g. "BOS;CWD:cwd;GIT:git;...;ATN;CMD:input"
  stop_conditions\n   e.g. "| ; && ||"

[Optional EWC data: 2 x param_count x float32]
[Weights: param_count x float32]

Weight order: token_emb, then per-layer (ln1, in_proj, conv1d_weight, x_proj, dt_proj weight+bias, A_log, D, out_proj, ln2, ffn), then final ln.

Linear layer weights are transposed between PyTorch [out, in] and CWGT [in, out] layout during save/load.
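The 48-byte header maps onto a fixed little-endian struct. A hedged sketch of parsing it (field order and the annotated offsets come from the table above; the exact packing, including the trailing 9 reserved bytes, is an assumption consistent with those annotations):

```python
import struct
from typing import NamedTuple

# 4s magic, then u16/u8 fields in table order, 9 reserved pad bytes; 48 bytes total.
CWGT_HEADER = "<4sHHHHBBBBHIHHHHHBBIB9x"

class CwgtHeader(NamedTuple):
    magic: bytes; arch_version: int; flags: int; vocab_size: int
    d_model: int; n_layers: int; ffn_expand: int; expand: int
    d_state: int; l_max: int; param_count: int
    def_temperature: int; def_top_k: int; def_top_p: int
    def_min_p: int; def_max_tokens: int; def_candidates: int
    d_conv: int; meta_size: int; dt_rank: int

def parse_cwgt_header(blob: bytes) -> CwgtHeader:
    hdr = CwgtHeader(*struct.unpack(CWGT_HEADER, blob[:48]))
    assert hdr.magic == b"CWGT" and hdr.arch_version == 5
    return hdr
```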

Programmatic usage

from calm.checkpoint import load_cwgt, save_cwgt

# Load -- returns (model, config, meta, ewc_fisher, ewc_anchor)
model, config, meta, _, _ = load_cwgt("model.cwgt")
print(meta.domain, meta.template, meta.stop_conditions)

# Save with domain metadata for deployment in Lilush
save_cwgt("exported.cwgt", model,
          domain="shell",
          template="BOS;CWD:cwd;GIT:git;...;ATN;CMD:input",
          stop_conditions="| ; && ||",
          sampler_defaults={"temperature": 0.8, "top_k": 5})

Dataset format (CTDS)

CTDS (CALM Training DataSet) is a binary format for packed training sequences with per-sequence ATN positions.

[Header: 14 bytes, packed]
  magic "CTDS"      4 bytes
  vocab_version     uint32
  count             uint32
  max_len           uint16

[Lengths: count x uint16]
[ATN positions: count x uint16]
[Tokens: sum(lengths) x uint16]
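Reading this layout back takes a few struct/array calls. A sketch that parses raw CTDS bytes (illustrative; the package's reader is calm.dataset):

```python
import struct
from array import array

def parse_ctds(blob: bytes):
    """Parse CTDS bytes into (vocab_version, token sequences, ATN positions).
    Assumes a little-endian host for the uint16 array reads."""
    magic, vocab_version, count, max_len = struct.unpack_from("<4sIIH", blob, 0)
    assert magic == b"CTDS"
    off = 14
    lengths = array("H"); lengths.frombytes(blob[off:off + 2 * count]); off += 2 * count
    atn_pos = array("H"); atn_pos.frombytes(blob[off:off + 2 * count]); off += 2 * count
    tokens = array("H"); tokens.frombytes(blob[off:off + 2 * sum(lengths)])
    seqs, t = [], 0
    for n in lengths:                       # unpack the packed token stream
        seqs.append(list(tokens[t:t + n])); t += n
    return vocab_version, seqs, list(atn_pos)
```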

Creating CTDS files from Python

from calm.tokenizer import CalmTokenizer
from calm.dataset import write_ctds

tok = CalmTokenizer()
sequences = [
    tok.build_sequence(cwd="/home/user", history=[("ls", 0)], partial_cmd="git status"),
    tok.build_sequence(cwd="/tmp", partial_cmd="echo hello"),
]
cmd_positions = [tok.find_cmd_pos(s) for s in sequences]
write_ctds("train.ctds", sequences, cmd_positions)

GPU acceleration (Triton kernels)

When Triton is installed and tensors are on CUDA, the Mamba operator automatically dispatches to fused GPU kernels in calm/triton_kernels/:

  • Selective scan -- replaces the sequential Python loop with a single fused kernel (one program per batch x d_inner tile, full sequence loop in-kernel). Backward uses forward-recompute strategy.
  • Causal conv1d + SiLU -- fuses left-pad, depthwise conv1d, and SiLU activation in BLD layout, eliminating two transpose copies.
  • Gated SiLU -- fuses y * SiLU(z) without materialising the SiLU intermediate.

All kernels fall back to pure PyTorch on CPU or when Triton is absent. Combined with AMP (enabled by default), expect ~3-5x training speedup.

Testing

uv run pytest tests/ -v

Project structure

calm/
  __init__.py             Package root (CalmLM, CalmTokenizer, etc.)
  compat.py               Constants, CWGT/CTDS headers, param count formulas
  tokenizer.py            Byte-level CALM tokenizer (320 vocab)
  template.py             Template DSL parser and sequence builder
  dataset.py              CTDS binary dataset reader/writer + collator
  contrastive_dataset.py  Text pair dataset reader + collator for InfoNCE
  lm.py                   CalmLM model (forward, loss, embed, incremental decode)
  mamba_operator.py       Mamba operator (selective SSM, sequential scan)
  contrastive.py          Symmetric InfoNCE loss function
  checkpoint.py           Safetensors checkpoints + CWGT v5 serialization
  ewc.py                  Online Elastic Weight Consolidation
  triton_kernels/         Fused Triton GPU kernels (auto-dispatch on CUDA)
  configs/                YAML model preset configs (nano/micro/mini/small)
  cli/
    train.py              calm-train entry point
    train_contrastive.py  calm-train-contrastive entry point
    inference.py          calm-inference entry point
    text_to_ctds.py       calm-text-to-ctds entry point
scripts/                  Domain training shell scripts
tests/                    Pytest test suite
pyproject.toml            Package metadata and build config

Dataset conversion and preparation scripts live in the calm_datasets repo.
