Nano Mixture-of-Experts language model in JAX/Flax — a lightweight, educational MoE implementation
Project description
NanoMoE — Mixture-of-Experts in JAX
A lightweight, educational Mixture-of-Experts (MoE) GPT-style language model built from scratch in JAX / Flax.
Inspired by nanoGPT, NanoMoE replaces the standard FFN in each transformer block with a sparse MoE layer: only the top-k experts run for each token, so the parameter count grows with the number of experts while per-token compute grows only with k.
Architecture
Input Tokens
↓
Token Embedding + Positional Embedding
↓
┌─────────────────────────────────────┐
│ Transformer Block ×N │
│ │
│ LayerNorm → Causal Multi-Head Attn │
│ ↓ + Residual │
│ LayerNorm → MoE Layer │
│ ↓ + Residual │
│ │
│ ┌─── MoE Layer ─────────────────┐ │
│ │ Router (Top-K Gating) │ │
│ │ ├─ Expert 1 (FFN) │ │
│ │ ├─ Expert 2 (FFN) │ │
│ │ ├─ ... │ │
│ │ └─ Expert N (FFN) │ │
│ │ → Weighted Sum of Top-K │ │
│ └───────────────────────────────┘ │
└─────────────────────────────────────┘
↓
LayerNorm → Linear Head → Logits
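The residual wiring in the diagram can be sketched in plain JAX as follows. This is a minimal illustration under stated assumptions, not the code in `layers.py`: the function names (`layer_norm`, `causal_attention`, `block`) are hypothetical, attention is single-head, LayerNorm has no learned scale or bias, and the MoE sublayer is passed in as a callable.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, eps=1e-5):
    # Normalise over the feature axis (no learned scale/bias in this sketch).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def causal_attention(x, wq, wk, wv, wo):
    # Single-head causal self-attention over x of shape (T, d_model).
    T, d = x.shape
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = (q @ k.T) / jnp.sqrt(d)
    mask = jnp.tril(jnp.ones((T, T), dtype=bool))  # position t sees positions <= t
    scores = jnp.where(mask, scores, -jnp.inf)
    return jax.nn.softmax(scores, axis=-1) @ v @ wo


def block(x, attn_params, moe_fn):
    # Pre-LayerNorm residual wiring: LayerNorm -> sublayer -> add residual.
    x = x + causal_attention(layer_norm(x), *attn_params)
    x = x + moe_fn(layer_norm(x))
    return x


d_model, T = 8, 5
k_x, *k_w = jax.random.split(jax.random.PRNGKey(0), 5)
attn_params = [0.1 * jax.random.normal(k, (d_model, d_model)) for k in k_w]
x = jax.random.normal(k_x, (T, d_model))
y = block(x, attn_params, moe_fn=jnp.tanh)  # jnp.tanh stands in for the MoE sublayer
```

Because of the causal mask, changing the last token of `x` leaves the outputs at all earlier positions unchanged.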
Key Features
- Sparse MoE Routing — Top-K gating with softmax; only a subset of experts runs per token
- Load-Balancing Loss — Switch Transformer-style auxiliary loss for uniform expert utilisation
- Pure JAX/Flax — No custom CUDA kernels; portable across CPU, GPU, and TPU
- Autoregressive Generation — Temperature + top-k sampling for text generation
- Fully JIT-compiled training and evaluation steps
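As an illustration of the generation feature, temperature plus top-k sampling for a single step might look like the sketch below. The function name `sample_next_token` and its signature are assumptions made for this example; the project's own `generate()` may be organised differently.

```python
import jax
import jax.numpy as jnp


def sample_next_token(key, logits, temperature=1.0, top_k=None):
    """Sample one token id from a (vocab_size,) vector of logits."""
    # temperature must be > 0; values < 1 sharpen, values > 1 flatten.
    logits = logits / temperature
    if top_k is not None:
        # Value of the k-th largest logit (top_k returns values in descending order).
        kth_largest = jax.lax.top_k(logits, top_k)[0][-1]
        # Mask everything strictly below it; ties with the k-th value are kept.
        logits = jnp.where(logits < kth_largest, -jnp.inf, logits)
    return jax.random.categorical(key, logits)


logits = jnp.array([2.0, 1.0, 0.5, -1.0, -3.0])
keys = jax.random.split(jax.random.PRNGKey(0), 200)
samples = jax.vmap(
    lambda k: sample_next_token(k, logits, temperature=0.8, top_k=2)
)(keys)
```

With `top_k=2`, every sample is drawn from the two highest-scoring tokens only (ids 0 and 1 here); with `top_k=1` the procedure reduces to greedy decoding.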
Quick Start
Install
git clone https://github.com/carrycooldude/MoE-JAX.git
cd MoE-JAX
pip install -r requirements.txt
Note: For GPU support, install the appropriate `jaxlib` CUDA wheel (see the JAX installation guide).
Train on Tiny Shakespeare
python examples/train_shakespeare.py
This downloads Tiny Shakespeare (~1 MB), trains a character-level NanoMoE, and generates sample text.
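To make "character-level" concrete, the sketch below builds a character vocabulary and samples (input, target) windows shifted by one position. It is an illustration only: the helper names (`encode`, `decode`, `get_batch`) are hypothetical, and a short inline string stands in for the downloaded corpus.

```python
import jax
import jax.numpy as jnp

text = "To be, or not to be, that is the question."  # stand-in for the corpus

# Character-level vocabulary: one integer id per distinct character.
chars = sorted(set(text))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for ch, i in stoi.items()}


def encode(s):
    return [stoi[c] for c in s]


def decode(ids):
    return "".join(itos[int(i)] for i in ids)


data = jnp.array(encode(text), dtype=jnp.int32)


def get_batch(key, data, batch_size, block_size):
    # Random window starts; each window needs block_size + 1 tokens,
    # so the exclusive upper bound is len(data) - block_size.
    starts = jax.random.randint(key, (batch_size,), 0, len(data) - block_size)
    offsets = jnp.arange(block_size)
    x = data[starts[:, None] + offsets]        # inputs
    y = data[starts[:, None] + offsets + 1]    # targets: inputs shifted by one
    return x, y


x, y = get_batch(jax.random.PRNGKey(0), data, batch_size=4, block_size=8)
```

Each row of `y` is the corresponding row of `x` shifted left by one character, which is the next-token prediction target.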
Run Tests
python -m pytest tests/ -v
Project Structure
MoE-JAX/
├── nano_moe/
│ ├── __init__.py # Public API
│ ├── config.py # Hyperparameter dataclass
│ ├── layers.py # ExpertFFN, Router, MoELayer, Attention, TransformerBlock
│ ├── model.py # NanoMoE model + generate()
│ ├── train.py # Training loop, JIT-compiled steps
│ └── utils.py # Param counting, batching, data loading
├── examples/
│ └── train_shakespeare.py
├── tests/
│ ├── test_layers.py
│ └── test_model.py
├── requirements.txt
└── README.md
Default Hyperparameters
| Parameter | Value | Description |
|---|---|---|
| `d_model` | 128 | Hidden dimension |
| `n_layers` | 4 | Transformer blocks |
| `n_heads` | 4 | Attention heads |
| `d_ff` | 512 | Expert FFN inner dim |
| `n_experts` | 4 | Experts per MoE layer |
| `top_k` | 2 | Active experts per token |
| `block_size` | 128 | Max context length |
| `aux_loss_coeff` | 0.01 | Load-balancing loss weight |
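For a sense of scale, the defaults above can be written as a dataclass and used to count expert parameters. This is a sketch: the class name `MoEConfig` and the helper `expert_params` are hypothetical (the real dataclass lives in `config.py` and may differ), and the count assumes each expert is two dense layers with biases.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MoEConfig:  # hypothetical name; defaults copied from the table above
    d_model: int = 128
    n_layers: int = 4
    n_heads: int = 4
    d_ff: int = 512
    n_experts: int = 4
    top_k: int = 2
    block_size: int = 128
    aux_loss_coeff: float = 0.01


def expert_params(cfg):
    # Two dense layers with biases: d_model -> d_ff -> d_model (assumption).
    up = cfg.d_model * cfg.d_ff + cfg.d_ff
    down = cfg.d_ff * cfg.d_model + cfg.d_model
    return up + down


cfg = MoEConfig()
per_expert = expert_params(cfg)
stored_per_layer = cfg.n_experts * per_expert  # expert parameters held per layer
active_per_layer = cfg.top_k * per_expert      # expert parameters used per token
print(per_expert, stored_per_layer, active_per_layer)
# 131712 526848 263424
```

Under these assumptions, each token touches half of the expert parameters in a layer (2 of 4 experts), which is where the capacity-versus-compute trade-off comes from.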
How It Works
- Router projects each token to `n_experts` logits and selects the top-k experts
- Experts are independent 2-layer FFNs (d_model → d_ff → d_model, GELU activation)
- Weighted Sum combines the top-k expert outputs using normalised softmax gates
- Auxiliary Loss penalises uneven routing to prevent expert collapse
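The four steps above can be sketched in plain JAX as follows. This is an illustration under stated assumptions, not the contents of `layers.py`: the names `init_params` and `moe_layer` are hypothetical, every expert is evaluated on every token and then masked (simple, but it saves no compute), and the load-balancing term uses one common Switch-style form, E · Σ f_i · P_i, with f_i taken as each expert's share of routing assignments.

```python
import jax
import jax.numpy as jnp


def init_params(key, d_model, d_ff, n_experts):
    k_r, k_1, k_2 = jax.random.split(key, 3)
    return {
        "router": 0.02 * jax.random.normal(k_r, (d_model, n_experts)),
        "w1": 0.02 * jax.random.normal(k_1, (n_experts, d_model, d_ff)),
        "b1": jnp.zeros((n_experts, d_ff)),
        "w2": 0.02 * jax.random.normal(k_2, (n_experts, d_ff, d_model)),
        "b2": jnp.zeros((n_experts, d_model)),
    }


def moe_layer(params, x, top_k):
    """x: (T, d_model) -> (output of shape (T, d_model), scalar auxiliary loss)."""
    n_experts = params["router"].shape[1]

    # Router: per-token probabilities over experts, then keep the top-k.
    probs = jax.nn.softmax(x @ params["router"], axis=-1)   # (T, E)
    top_p, top_idx = jax.lax.top_k(probs, top_k)            # (T, k) each
    gates = top_p / top_p.sum(axis=-1, keepdims=True)       # renormalise to sum to 1

    # Scatter the gates into a (T, E) matrix that is zero for unselected experts.
    one_hot = jax.nn.one_hot(top_idx, n_experts)            # (T, k, E)
    dense_gates = jnp.einsum("tk,tke->te", gates, one_hot)

    # Experts: independent 2-layer FFNs with GELU, all evaluated densely here.
    h = jnp.einsum("td,edf->etf", x, params["w1"]) + params["b1"][:, None, :]
    out = jnp.einsum("etf,efd->etd", jax.nn.gelu(h), params["w2"])
    out = out + params["b2"][:, None, :]

    # Weighted sum of the selected experts' outputs.
    y = jnp.einsum("te,etd->td", dense_gates, out)

    # Load-balancing loss: E * sum_i f_i * P_i.
    f = one_hot.sum(axis=(0, 1)) / (x.shape[0] * top_k)     # share of assignments
    p = probs.mean(axis=0)                                  # mean router probability
    aux = n_experts * jnp.sum(f * p)
    return y, aux


params = init_params(jax.random.PRNGKey(0), d_model=16, d_ff=32, n_experts=4)
x = jax.random.normal(jax.random.PRNGKey(1), (10, 16))
y, aux = jax.jit(moe_layer, static_argnames="top_k")(params, x, top_k=2)
```

With this definition the auxiliary term equals 1 when routing is perfectly uniform and grows as routing concentrates on a few experts. During training it would typically be added to the language-modelling loss scaled by `aux_loss_coeff`.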
License
MIT
Download files
File details
Details for the file nano_moe_jax-0.1.0.tar.gz.
File metadata
- Download URL: nano_moe_jax-0.1.0.tar.gz
- Upload date:
- Size: 13.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 75f5044f265b0b144bad0aec56bbf123e357d318c3e885d66f62a86973512274 |
| MD5 | 2ca0882133ac6c2d6c4457fe118141da |
| BLAKE2b-256 | 17d8c9dc0808d2e839e5cf5cb21654cc7deeb43e2d39241f2bd9cf8ef8e7e817 |
Provenance
The following attestation bundles were made for nano_moe_jax-0.1.0.tar.gz:
Publisher: publish.yml on carrycooldude/Nano-MoE-JAX
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: nano_moe_jax-0.1.0.tar.gz
- Subject digest: 75f5044f265b0b144bad0aec56bbf123e357d318c3e885d66f62a86973512274
- Sigstore transparency entry: 985910044
- Sigstore integration time:
- Permalink: carrycooldude/Nano-MoE-JAX@79550a336465cb1ec06996df11017a2888dbd713
- Branch / Tag: refs/tags/v0.1.0
- Owner: https://github.com/carrycooldude
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@79550a336465cb1ec06996df11017a2888dbd713
- Trigger Event: release
File details
Details for the file nano_moe_jax-0.1.0-py3-none-any.whl.
File metadata
- Download URL: nano_moe_jax-0.1.0-py3-none-any.whl
- Upload date:
- Size: 12.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 77ae7fd51f363b3c8178e03959da02a5d2c7271694738dad0ebf468b1edb2b51 |
| MD5 | 57b802650d71fe4ddf900e5d097d38c8 |
| BLAKE2b-256 | e72caf191fb26744ab0c54368457c847cdce82ce5716e771cf1a4ed056523033 |
Provenance
The following attestation bundles were made for nano_moe_jax-0.1.0-py3-none-any.whl:
Publisher: publish.yml on carrycooldude/Nano-MoE-JAX
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: nano_moe_jax-0.1.0-py3-none-any.whl
- Subject digest: 77ae7fd51f363b3c8178e03959da02a5d2c7271694738dad0ebf468b1edb2b51
- Sigstore transparency entry: 985910107
- Sigstore integration time:
- Permalink: carrycooldude/Nano-MoE-JAX@79550a336465cb1ec06996df11017a2888dbd713
- Branch / Tag: refs/tags/v0.1.0
- Owner: https://github.com/carrycooldude
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@79550a336465cb1ec06996df11017a2888dbd713
- Trigger Event: release