NanoMoE — Mixture-of-Experts in JAX

A lightweight, educational Mixture-of-Experts (MoE) GPT-style language model built from scratch in JAX / Flax.

Inspired by nanoGPT, NanoMoE replaces the standard FFN in each transformer block with a sparse MoE layer. Only the top-k experts activate for each token, so model capacity grows with the number of experts while per-token compute stays close to the cost of running k expert FFNs.

Architecture

Input Tokens
    ↓
Token Embedding + Positional Embedding
    ↓
┌─────────────────────────────────────┐
│         Transformer Block ×N        │
│                                     │
│  LayerNorm → Causal Multi-Head Attn │
│      ↓ + Residual                   │
│  LayerNorm → MoE Layer              │
│      ↓ + Residual                   │
│                                     │
│  ┌─── MoE Layer ─────────────────┐  │
│  │ Router (Top-K Gating)         │  │
│  │   ├─ Expert 1 (FFN)           │  │
│  │   ├─ Expert 2 (FFN)           │  │
│  │   ├─ ...                      │  │
│  │   └─ Expert N (FFN)           │  │
│  │ → Weighted Sum of Top-K       │  │
│  └───────────────────────────────┘  │
└─────────────────────────────────────┘
    ↓
LayerNorm → Linear Head → Logits
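
The block wiring in the diagram can be sketched in plain JAX as follows. This is a minimal illustration, not the code in layers.py: the function names, the parameter-free LayerNorm, the bias-free attention projections, and the moe_fn stand-in are all assumptions made for brevity.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, eps=1e-5):
    # Parameter-free LayerNorm over the feature axis (assumption: no learned scale/bias).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def causal_self_attention(x, wq, wk, wv, wo, n_heads):
    # x: (T, d_model). Hypothetical bias-free projections wq, wk, wv, wo: (d_model, d_model).
    T, d = x.shape
    hd = d // n_heads
    q = (x @ wq).reshape(T, n_heads, hd).transpose(1, 0, 2)  # (H, T, hd)
    k = (x @ wk).reshape(T, n_heads, hd).transpose(1, 0, 2)
    v = (x @ wv).reshape(T, n_heads, hd).transpose(1, 0, 2)
    scores = q @ k.transpose(0, 2, 1) / jnp.sqrt(hd)         # (H, T, T)
    # Causal mask: position t may only attend to positions <= t.
    mask = jnp.tril(jnp.ones((T, T), dtype=bool))
    scores = jnp.where(mask, scores, -jnp.inf)
    weights = jax.nn.softmax(scores, axis=-1)
    out = (weights @ v).transpose(1, 0, 2).reshape(T, d)     # merge heads
    return out @ wo


def transformer_block(x, attn_params, moe_fn, n_heads):
    # LayerNorm -> attention -> residual, then LayerNorm -> MoE -> residual.
    x = x + causal_self_attention(layer_norm(x), *attn_params, n_heads=n_heads)
    x = x + moe_fn(layer_norm(x))
    return x
```

Here moe_fn is any callable mapping (T, d_model) to (T, d_model); in the real model that role is played by the MoE layer.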

Key Features

  • Sparse MoE Routing — Top-K gating with softmax; only a subset of experts runs per token
  • Load-Balancing Loss — Switch Transformer-style auxiliary loss for uniform expert utilisation
  • Pure JAX/Flax — No custom CUDA kernels; portable across CPU, GPU, and TPU
  • Autoregressive Generation — Temperature + top-k sampling for text generation
  • Fully JIT-compiled training and evaluation steps
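
As an illustration of the temperature + top-k sampling feature, a single decoding step might look like the sketch below. The function name sample_next_token and its signature are hypothetical; the project's own generate() may be organised differently.

```python
import jax
import jax.numpy as jnp


def sample_next_token(key, logits, temperature=1.0, top_k=None):
    # logits: (vocab_size,) unnormalised scores for the next token.
    # Temperature < 1 sharpens the distribution, > 1 flattens it.
    logits = logits / temperature
    if top_k is not None:
        # jax.lax.top_k returns values sorted in descending order,
        # so the last one is the k-th largest logit.
        kth_value = jax.lax.top_k(logits, top_k)[0][-1]
        # Exclude everything below the k-th largest logit from sampling.
        logits = jnp.where(logits < kth_value, -jnp.inf, logits)
    # Draw one token index from softmax(logits).
    return jax.random.categorical(key, logits)
```

With top_k=1 this reduces to greedy decoding, since only the largest logit keeps a finite score.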

Quick Start

Install

git clone https://github.com/carrycooldude/MoE-JAX.git
cd MoE-JAX
pip install -r requirements.txt

Note: For GPU support, install the appropriate jaxlib CUDA wheel; see the JAX installation guide.

Train on Tiny Shakespeare

python examples/train_shakespeare.py

This downloads Tiny Shakespeare (~1 MB), trains a character-level NanoMoE, and generates sample text.

Run Tests

python -m pytest tests/ -v

Project Structure

MoE-JAX/
├── nano_moe/
│   ├── __init__.py        # Public API
│   ├── config.py          # Hyperparameter dataclass
│   ├── layers.py          # ExpertFFN, Router, MoELayer, Attention, TransformerBlock
│   ├── model.py           # NanoMoE model + generate()
│   ├── train.py           # Training loop, JIT-compiled steps
│   └── utils.py           # Param counting, batching, data loading
├── examples/
│   └── train_shakespeare.py
├── tests/
│   ├── test_layers.py
│   └── test_model.py
├── requirements.txt
└── README.md

Default Hyperparameters

Parameter        Value   Description
d_model          128     Hidden dimension
n_layers         4       Transformer blocks
n_heads          4       Attention heads
d_ff             512     Expert FFN inner dim
n_experts        4       Experts per MoE layer
top_k            2       Active experts per token
block_size       128     Max context length
aux_loss_coeff   0.01    Load-balancing loss weight
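
To make the capacity versus compute trade-off concrete, the short calculation below counts expert weights for one MoE layer under the defaults above. It assumes each expert is two bias-free matrices (d_model × d_ff and d_ff × d_model); if the actual experts include biases, the totals are slightly higher.

```python
d_model, d_ff, n_experts, top_k = 128, 512, 4, 2

# Weights in one expert FFN (d_model -> d_ff -> d_model), ignoring biases.
params_per_expert = 2 * d_model * d_ff

total_expert_params = n_experts * params_per_expert   # stored per MoE layer
active_expert_params = top_k * params_per_expert      # used per token

print(params_per_expert)     # 131072
print(total_expert_params)   # 524288
print(active_expert_params)  # 262144
```

With top_k = 2 of 4 experts, each token touches half of the expert weights held by the layer.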

How It Works

  1. Router projects each token to n_experts logits and selects the top-k experts
  2. Experts are independent 2-layer FFNs (d_model → d_ff → d_model, GELU activation)
  3. Weighted Sum combines the top-k expert outputs using normalised softmax gates
  4. Auxiliary Loss penalises uneven routing to prevent expert collapse
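
The four steps above can be sketched in a single JAX function. This is an illustrative sketch under stated assumptions, not the code in layers.py: the params layout and key names are hypothetical, gates are taken from a softmax over all experts and then renormalised over the selected k, every expert is evaluated densely for simplicity, and the auxiliary loss uses one common top-k generalisation of the Switch Transformer formula.

```python
import jax
import jax.numpy as jnp


def moe_layer(params, x, top_k):
    # x: (T, d_model). Hypothetical params layout:
    #   w_router: (d_model, E), w_in: (E, d_model, d_ff), w_out: (E, d_ff, d_model)
    n_experts = params["w_router"].shape[1]

    # 1. Router: per-token probabilities over experts, then pick the top-k.
    router_probs = jax.nn.softmax(x @ params["w_router"], axis=-1)  # (T, E)
    top_probs, top_idx = jax.lax.top_k(router_probs, top_k)         # (T, k)

    # Renormalise the selected gates so they sum to 1 for each token.
    gates = top_probs / top_probs.sum(axis=-1, keepdims=True)

    # 2. Experts: bias-free 2-layer FFNs with GELU, all evaluated densely here.
    #    (Simple to read, but it does not save compute like true sparse dispatch.)
    hidden = jax.nn.gelu(jnp.einsum("td,edf->etf", x, params["w_in"]))
    expert_out = jnp.einsum("etf,efd->etd", hidden, params["w_out"])  # (E, T, d)

    # 3. Weighted sum: scatter gates into a (T, E) matrix that is zero
    #    for unselected experts, then combine the expert outputs.
    one_hot = jax.nn.one_hot(top_idx, n_experts)         # (T, k, E)
    combine = jnp.einsum("tk,tke->te", gates, one_hot)   # (T, E)
    y = jnp.einsum("te,etd->td", combine, expert_out)    # (T, d)

    # 4. Auxiliary loss: E * sum_i f_i * P_i, where f_i is the fraction of
    #    routing assignments sent to expert i and P_i is its mean router probability.
    f = one_hot.sum(axis=1).mean(axis=0) / top_k
    p = router_probs.mean(axis=0)
    aux_loss = n_experts * jnp.sum(f * p)
    return y, aux_loss
```

In training, the total objective would typically be the language-modelling loss plus aux_loss_coeff times aux_loss. In this sketch the loss evaluates to 1.0 when the router probabilities are uniform and grows when both the assignments and the probabilities concentrate on the same experts.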

License

MIT
