DantinoX

"Nel mezzo del cammin di nostra vita mi ritrovai per una selva oscura..."

A decoder-only Transformer built from scratch in JAX and Flax NNX — complete with a training pipeline, autoregressive generation, hyperparameter sweeps, and a benchmarking suite.



Documentation · Coverage Report · API Reference


Overview

DantinoX is a research-grade library for building, training, and benchmarking decoder-only Transformers in pure JAX. It is designed as a transparent, modular codebase for studying how architectural choices — attention mechanism, positional encoding, MoE routing — affect convergence, memory footprint, and inference throughput.

The library ships as an installable Python package (pip install dantinox) with a unified CLI, a programmatic Python API, a typed configuration dataclass, and a full test suite.

Implemented Architectures

Component              Variants
Attention              Multi-Head (MHA) · Grouped-Query (GQA) · Multi-Head Latent (MLA)
Feed-Forward           Dense MLP (SwiGLU / GELU) · Sparse Mixture-of-Experts (Top-K)
Positional Encoding    Rotary (RoPE) · Absolute Sinusoidal · Learned
Attention Masking      Causal · Sliding Window
Memory Optimizations   Gradient Checkpointing (nnx.remat) · Weight Tying · Static KV-Cache
Training               Gradient Accumulation · AdamW / Adafactor / Lion · Cosine LR Schedule
Tokenizers             Character-level · Byte-Pair Encoding (BPE)
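
For intuition on how GQA shrinks the KV-cache relative to MHA, each group of query heads shares one key/value head. A generic shape sketch (not DantinoX's internal attention code):

import jax.numpy as jnp

# 16 query heads share 4 KV heads, so each KV head serves a group of 4.
B, T, n_heads, kv_heads, head_size = 2, 8, 16, 4, 32
q = jnp.ones((B, T, n_heads, head_size))
k = jnp.ones((B, T, kv_heads, head_size))    # KV-cache is 4x smaller than MHA
v = jnp.ones((B, T, kv_heads, head_size))

# Broadcast each KV head across its query group, then attend as usual:
group = n_heads // kv_heads                  # 4 query heads per KV head
k_full = jnp.repeat(k, group, axis=2)        # (B, T, 16, 32)
scores = jnp.einsum("bqhd,bkhd->bhqk", q, k_full) / jnp.sqrt(head_size)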

Installation

From PyPI

pip install dantinox                       # core only
pip install "dantinox[data]"               # + HuggingFace datasets
pip install "dantinox[benchmark]"          # + pandas / matplotlib / scipy
pip install "dantinox[all]"                # everything including dev tools

From Source

git clone https://github.com/winstonsmith1897/DantinoX.git
cd DantinoX

conda create -n dantinox python=3.12 -y
conda activate dantinox

make install      # installs JAX + all extras in editable mode

GPU support: replace the JAX CPU wheels with pip install -U "jax[cuda12]" after running make install.
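
To confirm JAX actually sees the accelerator after swapping wheels:

python -c "import jax; print(jax.devices())"   # should list GPU devices, not CpuDevice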


Quick Start

CLI

DantinoX registers a single dantinox entry-point with five subcommands:

# Train from a YAML config
dantinox train --config configs/default_config.yaml --data_path data/corpus.txt

# Override any config field from the command line
dantinox train --config configs/default_config.yaml --data_path data/corpus.txt \
    --lr 3e-4 --use_moe true --num_blocks 8

# Generate text from a saved checkpoint
dantinox generate --run_dir runs/run_20260101_120000 \
    --prompt "Nel mezzo del cammin " --max_new_tokens 200 --temperature 1.2

# Run a W&B Bayesian hyperparameter sweep
dantinox sweep --sweep_config configs/sweep.yaml --data_path data/corpus.txt

# Benchmark all run directories and save metrics to CSV
dantinox benchmark --runs_dir runs --out_csv benchmark_results.csv

# Generate plots from benchmark results
dantinox plot --in_csv benchmark_results.csv --out_dir plots

Python API

from dantinox import Trainer, Generator, BenchmarkRunner
from core.config import Config

# --- Training ---
config = Config(
    dim=512, n_heads=16, head_size=32, kv_heads=4,
    num_blocks=12, max_context=512,
    use_moe=True, n_experts=4, top_k_mlp=2,
    lr=3e-4, batch_size=64, grad_accum=4, epochs=100,
)
trainer = Trainer(config)
run_dir = trainer.fit("data/corpus.txt")

# --- Generation ---
gen = Generator(run_dir)
text = gen.generate(
    "Nel mezzo del cammin ",
    max_new_tokens=200,
    temperature=1.2,
    top_p=0.9,
    use_cache=True,
)
print(text)

# --- Benchmarking ---
runner = BenchmarkRunner("runs")
df = runner.run(out_csv="benchmark_results.csv")
print(df[["run", "type", "params_m", "prefill_ms"]].to_string())

Configuration

All architecture and training settings live in a single typed dataclass. YAML files are fully supported and can be partially overridden from the CLI.

# configs/default_config.yaml

model:
  dim: 512                      # Hidden dimension (must equal n_heads × head_size)
  n_heads: 16                   # Query heads
  kv_heads: 4                   # Key/value heads — set < n_heads to enable GQA
  head_size: 32                 # Per-head dimension
  num_blocks: 12                # Transformer depth
  max_context: 512              # Maximum sequence length
  weight_tying: true            # Tie embedding ↔ LM-head weights
  activation: gelu              # Activation function (gelu | silu)
  use_swiglu: true              # Replace MLP activation with SwiGLU gate
  gradient_checkpointing: true  # Recompute activations to reduce VRAM
  dropout_rate: 0.15

moe:
  use_moe: false                # Toggle Sparse MoE (true) vs Dense MLP (false)
  n_experts: 4                  # Total number of experts
  top_k_mlp: 2                  # Active experts per token
  expansion: 4                  # Expert hidden-dimension multiplier
  alpha_balance: 0.1            # Load-balancing loss weight

attention:
  use_rotary_pos: true          # Rotary Positional Embedding (RoPE)
  sliding_window: false         # Restrict attention to a local window
  context_window: 4             # Window size (if sliding_window: true)
  no_sink: true                 # Sigmoid attention gate for training stability

  # Multi-Head Latent Attention (MLA)
  mla: false
  down_dim_q: 256               # Query compression dimension
  down_dim_kv: 256              # Key/Value compression dimension
  rope_dim: 32                  # RoPE dimensions for decoupled key encoding

tokenizer:
  tokenizer_type: char          # char | bpe
  tokenizer_path: null

data:
  dataset_source: local         # local | huggingface
  dataset_name: ""

training:
  lr: 0.005
  batch_size: 128
  grad_accum: 16
  optimizer: adamw              # adamw | adafactor | lion
  epochs: 1000
  warmup_steps: 420
  seed: 42
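
For reference, the training block above maps onto standard optax building blocks. A sketch under stated assumptions (helpers.py may wire these up differently, and decay_steps below is assumed, not a config field):

import optax

schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=5e-3,       # lr: 0.005
    warmup_steps=420,      # warmup_steps: 420
    decay_steps=100_000,   # assumed total step budget
)
optimizer = optax.adamw(learning_rate=schedule)               # optimizer: adamw
optimizer = optax.MultiSteps(optimizer, every_k_schedule=16)  # grad_accum: 16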

Config Validation

The Config dataclass enforces constraints at instantiation:

Config(dim=512, n_heads=16, head_size=32)   # OK — 16 × 32 = 512
Config(dim=512, n_heads=16, head_size=31)   # ConfigError: dim must equal n_heads × head_size
Config(dim=512, n_heads=16, kv_heads=3)     # ConfigError: n_heads must be divisible by kv_heads
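
Violations are raised as ConfigError (see the exception hierarchy below), and a valid config round-trips through to_dict / from_dict, which the test suite covers. A minimal sketch, assuming the dataclass's default equality:

from core.config import Config
from dantinox.exceptions import ConfigError

cfg = Config(dim=512, n_heads=16, head_size=32)
assert Config.from_dict(cfg.to_dict()) == cfg   # round-trip serialization

try:
    Config(dim=512, n_heads=16, head_size=31)
except ConfigError as err:
    print(f"rejected: {err}")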

CLI Reference

dantinox train

Argument          Default                       Description
--config          configs/default_config.yaml   YAML config file
--data_path                                     Path to plain-text corpus
--run_dir         auto-generated                Output directory for weights and logs
--wandb_project                                 W&B project name for live logging
--<field>         config value                  Override any Config field directly

dantinox generate

Argument          Default                  Description
--run_dir         required                 Run directory with config.yaml + model_weights.msgpack
--prompt          "Nel mezzo del cammin "  Input text prefix
--max_new_tokens  150                      Number of tokens to generate
--temperature     1.0                      Sampling temperature
--top_p           null                     Nucleus sampling threshold
--top_k           null                     Top-K sampling limit
--greedy          false                    Deterministic greedy decoding
--no_cache        false                    Disable KV-cache (slower, for debugging)
--seed            42                       RNG seed
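
For intuition on --top_p: nucleus sampling keeps the smallest set of highest-probability tokens whose cumulative probability reaches the threshold, then samples from that set. A generic sketch of the filtering step (not DantinoX's internal sampler):

import jax
import jax.numpy as jnp

def top_p_filter(logits: jnp.ndarray, p: float = 0.9) -> jnp.ndarray:
    """Mask logits outside the nucleus to -inf."""
    order = jnp.argsort(logits)[::-1]        # token ids, descending by logit
    probs = jax.nn.softmax(logits)[order]
    cum = jnp.cumsum(probs)
    keep = cum - probs < p                   # nucleus: prior cumulative mass < p
    kept = jnp.where(keep, logits[order], -jnp.inf)
    return jnp.full_like(logits, -jnp.inf).at[order].set(kept)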

dantinox benchmark

Argument    Default   Description
--runs_dir  runs      Directory containing run sub-directories
--runs      all       Specific run names to benchmark
--out_csv             Save results to this CSV path

dantinox sweep

Argument         Default                       Description
--sweep_config   configs/sweep.yaml            W&B sweep YAML
--config         configs/default_config.yaml   Base model config (overridden by sweep)
--data_path      required                      Training corpus
--wandb_project  DantinoX                      W&B project
--count          unlimited                     Maximum sweep runs

Project Structure

DantinoX/
├── core/                        # Neural network primitives
│   ├── config.py                # Config dataclass — single source of truth
│   ├── model.py                 # Transformer: embedding → blocks → LM head
│   ├── attention.py             # MHA / GQA / MLA + RoPE + KV-cache
│   ├── block.py                 # Transformer block (Attention + FFN + LayerNorm)
│   ├── mlp.py                   # Dense MLP (SwiGLU / GELU)
│   ├── moe.py                   # Sparse Mixture-of-Experts with load-balancing loss
│   └── generation.py            # Autoregressive decode loop (fori_loop + vmap)
│
├── dantinox/                    # Installable library package
│   ├── cli.py                   # `dantinox` entry-point (train/generate/sweep/benchmark/plot)
│   ├── trainer.py               # Trainer — JIT training loop, logging, checkpointing
│   ├── generator.py             # Generator — checkpoint loading + text generation
│   ├── bench.py                 # BenchmarkRunner — latency / throughput / FLOPs
│   ├── plotting.py              # Plotter — automated figure generation
│   └── exceptions.py            # Exception hierarchy (DantinoXError → sub-classes)
│
├── utils/
│   ├── tokenizer.py             # CharTokenizer · BPETokenizer · Tokenizer Protocol
│   └── helpers.py               # Loss · batch sampling · LR schedule
│
├── configs/                     # YAML configuration files
│   ├── default_config.yaml
│   └── sweep.yaml
│
├── tests/                       # Pytest test suite (22 tests)
│   ├── conftest.py              # Session-scoped Config fixtures
│   ├── test_model.py            # Forward pass · GQA · MoE · weight tying · JIT
│   └── test_mla.py              # MLA training · inference cache · RoPE constraints
│
├── pyproject.toml               # Package metadata, deps, ruff, mypy, pytest config
├── Makefile                     # Development targets
└── mkdocs.yml                   # Documentation site configuration

Exception Hierarchy

DantinoXError
├── ConfigError        — invalid or inconsistent Config fields
├── CheckpointError    — missing run directory, config, or weights
├── BenchmarkError     — failure loading or running a benchmark
└── PlotError          — missing CSV or plot module import failure
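
Every class derives from DantinoXError, so callers can catch narrowly or fall back to the base class. A minimal sketch, assuming the names above are importable from dantinox.exceptions:

from dantinox import Generator
from dantinox.exceptions import CheckpointError, DantinoXError

try:
    gen = Generator("runs/nonexistent_run")
except CheckpointError as err:
    print(f"missing run dir, config, or weights: {err}")
except DantinoXError as err:
    print(f"any other library failure: {err}")   # base-class catch-all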

Development

All common workflows are exposed through make:

make install      # Install package + all dev/doc dependencies (editable)
make test         # Run test suite with coverage report
make lint         # Ruff static analysis
make typecheck    # Mypy type checking
make check        # lint + typecheck + test  (run before every push)
make build        # Build sdist + wheel into dist/
make publish      # Upload dist/ to PyPI via twine
make clean        # Remove build artefacts and __pycache__

Running Tests

make test

# Or directly:
JAX_PLATFORM_NAME=cpu python -m pytest tests/ -v

The suite runs on CPU (no GPU required) and covers:

  • Forward-pass output shapes for MHA, GQA, and MLA
  • KV-cache correctness and accumulation
  • MoE load-balancing loss propagation
  • Weight tying between embedding and LM head
  • JIT compilation stability
  • Config validation (dim constraints, GQA divisibility, MLA rope_dim)
  • Config round-trip serialization (to_dict / from_dict)

Coverage output is written to docs/coverage/ and published automatically with the documentation site.

Code Quality

The project enforces a strict quality baseline:

Tool     Configuration    What it checks
ruff     pyproject.toml   Style (E/W), imports (I), pyupgrade (UP), bugbear (B), simplify (SIM)
mypy     pyproject.toml   Full type annotation coverage across dantinox/, core/, utils/
pytest   pyproject.toml   22 unit tests, CPU-only, session-scoped fixtures

Training Artifacts

Each training run writes an isolated artifact directory:

runs/run_20260101_120000/
├── config.yaml              # Exact config used for the run
├── model_summary.json       # Parameter counts per component
├── training_log.csv         # step, train_loss, val_loss, bal_loss, ms_per_step
└── model_weights.msgpack    # Serialized model state (Flax msgpack format)

The training loop logs to the console via tqdm with a live loss postfix, and optionally streams metrics to Weights & Biases when --wandb_project is specified.
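
Because training_log.csv has a fixed schema (step, train_loss, val_loss, bal_loss, ms_per_step), post-hoc analysis is straightforward with pandas, for example:

import pandas as pd

log = pd.read_csv("runs/run_20260101_120000/training_log.csv")
print(log[["step", "train_loss", "val_loss"]].tail())   # most recent eval points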


Benchmarking

BenchmarkRunner measures latency and throughput across a matrix of sequence lengths and batch sizes using XLA cost analysis for FLOPs:

from dantinox import BenchmarkRunner
from dantinox.plotting import Plotter

df = BenchmarkRunner("runs").run(out_csv="benchmark_results.csv")
Plotter("benchmark_results.csv", out_dir="plots").run()

Reported metrics per run:

Metric                 Description
params_m               Total trainable parameters (millions)
theoretical_cache_mb   KV-cache memory at max_context (MB)
prefill_ms             Prefill latency for a 256-token prompt
tps_{64,128,256,512}   Decode throughput (tok/s) at each sequence length
tps_bs{1,4,16,64,...}  Decode throughput at each batch size
decode_gflops          FLOPs per decode step (XLA cost analysis)
prefill_arith_int      Arithmetic intensity of the prefill kernel
val_loss               Final validation loss from training_log.csv
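
theoretical_cache_mb follows from the usual KV-cache accounting. A worked example with the Quick Start config (the exact formula DantinoX uses is an assumption here):

# 2 tensors (K and V) × num_blocks × max_context × kv_heads × head_size × bytes/elem
num_blocks, max_context, kv_heads, head_size = 12, 512, 4, 32
bytes_per_elem = 4                        # float32
cache = 2 * num_blocks * max_context * kv_heads * head_size * bytes_per_elem
print(cache / 2**20)                      # 6.0 MB for this config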

Empirical Results

Ablation studies were conducted via W&B Bayesian sweeps over 100+ configurations. Key findings:

  • MLA vs GQA vs MHA: MLA achieves lower KV-cache memory with comparable perplexity when down_dim_kv ≤ dim / 4.
  • SwiGLU: Consistently outperforms GELU by ~0.05 val-loss across all model sizes.
  • Sliding Window: Improves training speed on long contexts with negligible perplexity loss when context_window ≥ 64.
  • Attention Gating (no_sink): Stabilizes training when combined with RoPE at high learning rates.
  • MoE (Top-2/4): Matches dense perplexity at 60% of the active-parameter count.

Full charts and analysis: Ablation Studies


Documentation

The full documentation is built with MkDocs Material and deployed to GitHub Pages:

# Rebuild and deploy
mkdocs gh-deploy --force


License

MIT — see LICENSE.

