# AetherLM
A DeepSpeed-like training library for Google TPUs, built on JAX/Flax.
AetherLM abstracts away the complexity of distributed training on TPU pods, providing a simple `initialize()` API for pretraining and fine-tuning transformer models.
## Features

- **One-call initialization** — `aetherlm.initialize(config)` sets up mesh, model, optimizer, and metrics
- **TPU-native** — built on JAX with automatic TPU topology detection and optimal sharding
- **Multiple training modes** — MLM pretraining, contrastive learning, causal LM
- **Built-in models** — `AetherBERT` (bidirectional) and `AetherCausalLM` (autoregressive with generation)
- **YAML/dict configuration** — DeepSpeed-style config system
- **MTEB evaluation** — integrated embedding evaluation with 56 tasks
- **Checkpoint management** — automatic saving, rotation, and restore with Orbax
## Quick Start

```python
import aetherlm

# Configure and initialize (auto-detects TPU)
engine = aetherlm.initialize(config={
    "model": {"model_type": "bert", "embed_dim": 768, "num_layers": 12},
    "training": {"mode": "mlm", "batch_size": 32, "learning_rate": 1e-4},
    "tpu": {"precision": "bf16"},
})

# Load data
from aetherlm.data import load_mlm_datasets
train_data, val_data = load_mlm_datasets(maxlen=512, mask_prob=0.15, vocab_size=50265)

# Train (handles sharding, logging, checkpointing)
engine.train(train_data, val_data)
```
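The `mask_prob=0.15` above follows the standard BERT masking recipe. AetherLM's own implementation lives in `aetherlm/data/mlm.py`; as a rough, standalone illustration of what that recipe does (not the library's code — the token ids and helper name here are assumptions), a BERT-style 80/10/10 masking step looks like:

```python
import numpy as np

def mask_tokens(token_ids, mask_prob=0.15, mask_id=50264, vocab_size=50265,
                special_ids=(0, 2), rng=None):
    """BERT-style masking sketch (not AetherLM's actual code).

    Positions are selected with probability `mask_prob`; of those, 80%
    become [MASK], 10% a random token, 10% stay unchanged. Labels hold
    the original ids at selected positions and -100 elsewhere (so the
    loss ignores unselected positions).
    """
    if rng is None:
        rng = np.random.default_rng(0)
    token_ids = np.asarray(token_ids)
    labels = np.full_like(token_ids, -100)

    # Never mask special tokens (e.g. BOS/EOS); these ids are assumptions.
    candidates = ~np.isin(token_ids, special_ids)
    selected = candidates & (rng.random(token_ids.shape) < mask_prob)
    labels[selected] = token_ids[selected]

    inputs = token_ids.copy()
    roll = rng.random(token_ids.shape)
    replace_mask = selected & (roll < 0.8)               # 80% -> [MASK]
    replace_rand = selected & (roll >= 0.8) & (roll < 0.9)  # 10% -> random token
    inputs[replace_mask] = mask_id
    inputs[replace_rand] = rng.integers(0, vocab_size, replace_rand.sum())
    # Remaining 10% keep the original token.
    return inputs, labels
```

The `(inputs, labels)` pair is what an MLM loss consumes: predictions are scored only where `labels != -100`.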
## Installation

```shell
# From PyPI
pip install aetherlm

# With eval support (MTEB, sklearn, scipy)
pip install aetherlm[eval]

# From source
pip install -e ".[all]"
```
## Training Modes

### MLM Pretraining (BERT-style)

```shell
aetherlm --mode mlm --config configs/default.yaml
```

### Causal Language Modeling (GPT-style)

```shell
aetherlm --mode causal --config configs/causal_small.yaml
```

### Contrastive Learning (requires checkpoint)

```shell
aetherlm --mode contrastive --checkpoint ./checkpoints/step_10000
```
## Evaluation

```shell
# Quick MTEB eval (3 tasks, ~1 min)
aetherlm --mode eval --checkpoint ./checkpoints/step_10000 --mteb_preset quick

# Full leaderboard (56 tasks)
aetherlm --mode eval --checkpoint ./checkpoints/step_10000 --mteb_preset leaderboard
```
## Configuration

AetherLM uses a dataclass-based config system. Create configs from YAML, JSON, or Python dicts:
```python
from aetherlm import AetherConfig

# From YAML
config = AetherConfig.from_yaml("configs/default.yaml")

# From dict
config = AetherConfig.from_dict({
    "model": {"embed_dim": 768, "num_layers": 12},
    "training": {"mode": "mlm", "batch_size": 32},
})

# Save for reproducibility
config.to_yaml("my_experiment.yaml")
```
See `configs/` for example configurations.
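The dict form shown in the quickstart maps directly onto YAML. A hypothetical `default.yaml` mirroring those same keys might look like the sketch below (only the keys that appear earlier in this README are confirmed; everything else about the file is an assumption):

```yaml
model:
  model_type: bert
  embed_dim: 768
  num_layers: 12

training:
  mode: mlm
  batch_size: 32
  learning_rate: 1.0e-4

tpu:
  precision: bf16
```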
## Project Structure

```
aetherlm/
    __init__.py          # Top-level API: initialize(), AetherConfig, models
    core/
        config.py        # Dataclass configuration system
        engine.py        # Training engine (the heart of the library)
        sharding.py      # Automatic mesh sharding
    models/
        base.py          # Abstract model interface + utilities
        transformer.py   # Transformer blocks, embeddings
        bert.py          # AetherBERT (bidirectional MLM)
        causal.py        # AetherCausalLM (autoregressive + generation)
    optim/
        optimizers.py    # Optimizer factory (AdamW, Adam, SGD)
        schedules.py     # LR schedules (warmup-cosine, linear, constant)
        switching.py     # Plateau detection + optimizer switching
    data/
        pipeline.py      # Tokenizer caching, batch iterators
        mlm.py           # MLM masking and dataset processing
        contrastive.py   # Contrastive pair creation (self-supervised + AllNLI)
        causal.py        # Causal LM next-token prediction format
    losses/
        mlm.py           # Efficient gather-based MLM loss
        contrastive.py   # SimCLR-style contrastive loss
        causal.py        # Causal LM cross-entropy loss
    training/
        steps.py         # JIT-compiled train/eval steps for all modes
    checkpoint/
        manager.py       # Orbax checkpoint save/load/rotation
    metrics/
        tracker.py       # Throughput, loss, ETA, WandB logging
    eval/
        tasks.py         # MTEB task lists and presets
        mteb.py          # MTEB EncoderProtocol wrapper
        runner.py        # Evaluation orchestrators
    tpu/
        topology.py      # TPU detection, mesh creation
        precision.py     # bfloat16/mixed precision config
    cli/
        main.py          # CLI entry point
configs/                 # Example YAML configurations
notebooks/               # Tutorial notebooks
```
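`schedules.py` lists warmup-cosine among the available learning-rate schedules. Its actual signature isn't documented in this README, but the schedule itself is standard: linear warmup to a peak, then cosine decay. A self-contained sketch (pure Python, not AetherLM's API):

```python
import math

def warmup_cosine(step, peak_lr, warmup_steps, total_steps, end_lr=0.0):
    """Standard warmup-cosine schedule sketch (AetherLM's signature may differ).

    Warms up linearly from 0 to peak_lr over warmup_steps, then decays
    along a half-cosine from peak_lr to end_lr by total_steps.
    """
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return end_lr + (peak_lr - end_lr) * cosine
```

For example, with `peak_lr=1e-4`, `warmup_steps=100`, `total_steps=1000`, the rate rises to `1e-4` at step 100 and falls back to 0 by step 1000.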
## Notebooks

| Notebook | Description |
|---|---|
| `01_quickstart.ipynb` | Core features: models, config, `initialize()` |
| `02_mlm_pretraining.ipynb` | Full MLM pretraining pipeline |
| `03_causal_lm.ipynb` | Causal LM training + text generation |
| `04_evaluation.ipynb` | MTEB and custom evaluation |
## License
Apache 2.0