Alloy
Hybrid SSM-Attention language model built on MLX for Apple Silicon.
Alloy interleaves Mamba-2 (selective state-space) blocks with Attention blocks in a single model, combining the linear-time efficiency of SSMs with the precise recall of Attention.
Features
- Mamba-2 block — selective scan with chunked parallel computation + fused Metal parallel scan kernel
- Attention block — MHA / GQA / sliding-window, with RoPE
- HybridLM — configurable interleaved architecture, supports both Alloy-native and Zamba2 modes
- Training — AdamW + cosine schedule, streaming JSONL dataloader with packing
- LoRA — freeze-and-inject adapter, save/load/merge
- Generation — autoregressive decoding with KV + SSM cache, top-p sampling, streaming output
- Weight conversion — load Zamba2 / Jamba weights from HuggingFace (verified on Zamba2-1.2B, with LoRA adapter merging)
- Metal kernels — fused conv1d+SiLU (8.3x speedup), parallel associative scan (2.2x speedup at chunk size cs=512)
- bfloat16 — mixed precision support (2x memory reduction, scan internals auto-promote to fp32)
- Autoresearch — autonomous architecture search harness (28 experiments, 22.6% improvement)
Quickstart
pip install -e ".[dev]"
pip install transformers huggingface_hub # for pretrained models
Chat with a model (easiest)
# Auto-downloads from HuggingFace, 4-bit quantized, only 1.3 GB memory
python -m alloy.chat --model Zyphra/Zamba2-1.2B-instruct --quantize 4
OpenAI-compatible API server
# Start server
python -m alloy.serve --model Zyphra/Zamba2-1.2B-instruct --quantize 4 --port 8000
# Use with any OpenAI client
curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{"messages": [{"role": "user", "content": "Hello!"}]}'
Convert models
# Download + convert + quantize + save
python -m alloy.convert_cli --model Zyphra/Zamba2-1.2B --quantize 4 --output models/zamba2-4bit
Evaluate
python -m alloy.eval --model Zyphra/Zamba2-1.2B --quantize 4
Python API
from alloy.convert import load_pretrained
from alloy.generate import generate
from transformers import AutoTokenizer
model = load_pretrained("path/to/Zamba2-1.2B")
model.to_bfloat16() # optional: halve memory (6.9 GB → 3.5 GB)
tokenizer = AutoTokenizer.from_pretrained("path/to/Zamba2-1.2B")
text = generate(model, tokenizer, "The capital of France is", max_tokens=100)
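generate draws each token with top-p (nucleus) sampling: it samples from the smallest set of tokens whose cumulative probability exceeds p. A self-contained NumPy sketch of the idea (illustrative only, not Alloy's internal implementation):

import numpy as np

def sample_top_p(logits, p=0.9, temperature=1.0):
    # Temperature-scaled softmax.
    z = logits / temperature
    probs = np.exp(z - z.max())
    probs /= probs.sum()
    # Keep the smallest descending-probability prefix with cumulative mass >= p.
    order = np.argsort(probs)[::-1]
    cutoff = int(np.searchsorted(np.cumsum(probs[order]), p)) + 1
    keep = order[:cutoff]
    # Renormalize over the kept tokens and sample.
    return int(np.random.choice(keep, p=probs[keep] / probs[keep].sum()))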
Train from scratch
# With climbmix data (auto-detected)
python prepare.py --num-shards 10
python -m alloy.train --config configs/toy.yaml --data climbmix --max-steps 2000
# With custom JSONL data
python -m alloy.train --config configs/toy.yaml --data data/train.jsonl
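Packing is what keeps batches dense: tokenized documents are concatenated with an EOS separator and the stream is sliced into fixed-length training sequences, so no batch capacity is spent on padding. A simplified sketch of the mechanism (the real loader in alloy/data/dataloader.py streams JSONL rather than materializing everything):

def pack_sequences(docs, seq_len, eos_id):
    # docs: iterable of token-id lists, one per document.
    buffer = []
    for tokens in docs:
        buffer.extend(tokens)
        buffer.append(eos_id)  # document boundary marker
        while len(buffer) >= seq_len:
            yield buffer[:seq_len]  # one densely packed training sequence
            buffer = buffer[seq_len:]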
LoRA fine-tune
from alloy.lora import linear_to_lora_layers, save_lora_weights, merge_lora_weights
model.freeze()
linear_to_lora_layers(model, lora_rank=16)
# ... train as usual, only LoRA params update ...
save_lora_weights(model, "adapter.npz")
merge_lora_weights(model) # fold adapter into base weights
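Because model.freeze() leaves only the injected adapter weights trainable, a standard MLX training step updates just the LoRA parameters. A minimal sketch (the batch iterator and hyperparameters are placeholders, and the model is assumed to return raw logits):

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

def loss_fn(model, inputs, targets):
    # Mean next-token cross-entropy over the batch.
    return nn.losses.cross_entropy(model(inputs), targets).mean()

optimizer = optim.AdamW(learning_rate=1e-4)
loss_and_grad = nn.value_and_grad(model, loss_fn)  # grads only for unfrozen (LoRA) params

for inputs, targets in batches:  # placeholder iterator of (batch, seq) token arrays
    loss, grads = loss_and_grad(model, inputs, targets)
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state)  # force the lazy computation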
Autoresearch (autonomous architecture search)
python prepare.py --num-shards 10
python train.py > run.log 2>&1 # 5-min budget per experiment
See program.md and docs/autoresearch-report.md for details.
Model configurations
| Config | d_model | Layers | Params | Use case |
|---|---|---|---|---|
| toy.yaml | 512 | 12 | ~100M | Architecture validation |
| small.yaml | 1024 | 24 | ~500M | Quick experiments |
| medium.yaml | 2048 | 32 | ~1.5B | Full training |
| autoresearch.yaml | 512 | 2 | ~15M | Optimal for the 5-min autoresearch budget |
Key findings from autoresearch
28 autonomous experiments validated core architectural decisions:
| Architecture | val_bpb | Notes |
|---|---|---|
| Hybrid (1M+1A) | 1.676 | Best — SSM + Attention complement each other |
| Pure Mamba (2M) | 1.999 | +0.32 worse, lacks precise recall |
| Pure Attention (2A) | 2.095 | +0.42 worse, despite more steps |
Key insights:
- Mamba first, Attention last — the reversed order is catastrophic (2.195)
- Shallow + wide wins under a fixed time budget (2L > 3L > 4L)
- GQA is effective even in hybrid models (n_kv_heads=2 helps)
- Batch size 2^13 is optimal, balancing gradient quality against step count
See docs/autoresearch-report.md for all 28 experiments.
Project structure
alloy/
├── alloy/
│ ├── models/
│ │ ├── mamba_block.py # Mamba-2 (Alloy + Zamba2 modes)
│ │ ├── mamba_kernels.py # Metal GPU kernels
│ │ ├── attention_block.py # MHA / GQA / sliding window
│ │ ├── hybrid_model.py # HybridLM + HybridBlock
│ │ └── cache.py # MambaCache / AttentionCache / Zamba2HybridLayerCache
│ ├── data/
│ │ └── dataloader.py # Streaming JSONL + packing
│ ├── chat.py # Interactive chat CLI
│ ├── serve.py # OpenAI-compatible HTTP API
│ ├── eval.py # Lightweight evaluation suite
│ ├── convert_cli.py # Model conversion CLI
│ ├── generate.py # Autoregressive generation
│ ├── train.py # Training loop + CLI
│ ├── lora.py # LoRA inject / save / merge
│ └── convert.py # HuggingFace weight conversion
├── configs/ # YAML model configs
├── tests/ # 88 tests
├── docs/
│ ├── roadmap.md # Project roadmap
│ ├── autoresearch-report.md # 28-experiment report
│ └── spec.md # Full project spec
├── prepare.py # Autoresearch data pipeline
├── train.py # Autoresearch training script
└── pyproject.toml
Tests
python -m pytest tests/ -v # 88 tests, ~0.5s
Architecture
Alloy mode (default)
Each HybridBlock follows the pre-norm residual pattern:
x → RMSNorm → [MambaBlock or AttentionBlock] → + → RMSNorm → FFN → + → out
↑_______________________________________________↑ ↑________________↑
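In code the pattern is roughly the following (a schematic sketch, not Alloy's actual HybridBlock):

import mlx.nn as nn

class PreNormBlock(nn.Module):
    def __init__(self, dims, mixer, hidden_dims):
        super().__init__()
        self.norm1 = nn.RMSNorm(dims)
        self.mixer = mixer  # a MambaBlock or an AttentionBlock
        self.norm2 = nn.RMSNorm(dims)
        self.ffn = nn.Sequential(
            nn.Linear(dims, hidden_dims), nn.SiLU(), nn.Linear(hidden_dims, dims)
        )

    def __call__(self, x):
        x = x + self.mixer(self.norm1(x))  # first residual: token mixing
        x = x + self.ffn(self.norm2(x))    # second residual: channel mixing
        return x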
Zamba2 mode (for pretrained Zamba2 models)
Hybrid layers contain both Mamba and Attention:
┌─ cat(x, emb) → Norm → Attention → Norm → FFN ─┐
x → shared_transformer ─────────────────────────────────────────────── linear
└─ (x + linear_out) → Norm → MambaDecoder → + → out ───────────────────┘
Performance
Metal kernel acceleration
| Operation | Pure MLX | Metal Kernel | Speedup |
|---|---|---|---|
| Conv1d + SiLU | 3.6ms | 0.4ms | 8.3x |
| Scan chunk (cs=512) | 11.6ms | 5.2ms | 2.2x |
| MambaBlock forward | 36.7ms | 26.7ms | 1.38x |
The scan implementation auto-selects a strategy by chunk size: a matmul formulation for cs < 256, where the hardware matrix engines are optimal, and the Metal parallel scan for cs ≥ 256, where O(cs·log cs) work beats the matmul's O(cs²).
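The parallelism comes from the fact that the SSM recurrence h[t] = a[t]·h[t-1] + b[t] composes associatively: (a2, b2) ∘ (a1, b1) = (a1·a2, a2·b1 + b2). A NumPy sketch of a Hillis-Steele-style scan over that operator (illustrative of the algorithm class, not the Metal kernel itself):

import numpy as np

def affine_scan(a, b):
    # Inclusive scan of h[t] = a[t]*h[t-1] + b[t] with h[-1] = 0,
    # in O(n log n) work across O(log n) parallel steps.
    a, b = a.copy(), b.copy()
    step = 1
    while step < len(a):
        a_prev, b_prev = a[:-step].copy(), b[:-step].copy()
        b[step:] = a[step:] * b_prev + b[step:]  # compose with the partial `step` back
        a[step:] = a[step:] * a_prev
        step *= 2
    return b  # b[t] now equals h[t]

A serial loop does the same in O(cs) sequential steps; the scan spends extra work to cut the dependency chain to O(log cs), which is what pays off on the GPU once cs ≥ 256.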
Zamba2-1.2B inference
| Mode | Speed | Memory |
|---|---|---|
| No cache | 5.3 tok/s | 6.9 GB (fp32) |
| KV + SSM cache | 24.6 tok/s | 6.9 GB (fp32) |
| KV + SSM cache + bf16 | 24.6 tok/s | 3.5 GB |
| KV + SSM cache + 4-bit | 66.7 tok/s | 1.3 GB |
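The cache wins come from incremental decoding: Attention appends one key/value pair per token, while the SSM carries a fixed-size recurrent state, so each step is O(1) in sequence length. Schematically, for one diagonal SSM channel (a simplified sketch, not Alloy's MambaCache):

import numpy as np

def ssm_decode_step(h, x_t, A, B, C):
    # h: (d_state,) cached recurrent state; A, B, C: (d_state,) diagonal dynamics/readout.
    h = A * h + B * x_t      # state update: cost independent of tokens generated so far
    return h, float(C @ h)   # new state and this token's output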
Logit alignment with HuggingFace reference
| Configuration | Avg top-5 diff | Top-1 agreement |
|---|---|---|
| Without LoRA adapters | 0.64 | — |
| With LoRA adapters merged | 0.23 | 80% |
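A check of this kind fits in a few lines. A sketch under two assumptions: the Alloy model is called directly on a token array and returns raw logits, and "top-5 diff" means the mean absolute difference on the reference model's five highest logits:

import mlx.core as mx
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from alloy.convert import load_pretrained

path = "path/to/Zamba2-1.2B"
ids = AutoTokenizer.from_pretrained(path)("The capital of France is", return_tensors="np").input_ids

ours = np.array(load_pretrained(path)(mx.array(ids))[0, -1])  # assumed: logits out
with torch.no_grad():
    ref = AutoModelForCausalLM.from_pretrained(path)(torch.from_numpy(ids)).logits[0, -1].numpy()

top5 = np.argsort(ref)[-5:]  # reference model's five highest logits
print("avg top-5 diff:", np.abs(ours[top5] - ref[top5]).mean())
print("top-1 agree:", int(ours.argmax()) == int(ref.argmax()))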
Benchmark (quick eval, small sample)
| Benchmark | Score | Random baseline |
|---|---|---|
| MMLU (20 questions) | 80.0% | 25% |
| HellaSwag (5 questions) | 100.0% | 25% |
References
- Mamba: Linear-Time Sequence Modeling with Selective State Spaces — Gu & Dao
- Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality — Dao & Gu (Mamba-2)
- Jamba: A Hybrid Transformer-Mamba Language Model — AI21 Labs
- Zamba2 — Zyphra
- autoresearch-mlx — Karpathy's autonomous research
- MLX — Apple's machine-learning array framework for Apple silicon
License
MIT