AetherLM

A DeepSpeed-like training library for Google TPUs, built on JAX/Flax.

AetherLM abstracts away the complexity of distributed training on TPU pods, providing a simple initialize() API for pretraining and fine-tuning transformer models.

Features

  • One-call initialization — aetherlm.initialize(config) sets up mesh, model, optimizer, and metrics
  • TPU-native — Built on JAX with automatic TPU topology detection and optimal sharding
  • Multiple training modes — MLM pretraining, contrastive learning, causal LM
  • Built-in models — AetherBERT (bidirectional) and AetherCausalLM (autoregressive with generation)
  • YAML/dict configuration — DeepSpeed-style config system
  • MTEB evaluation — Integrated embedding evaluation with 56 tasks
  • Checkpoint management — Automatic saving, rotation, and restore with Orbax

Quick Start

import aetherlm

# Configure and initialize (auto-detects TPU)
engine = aetherlm.initialize(config={
    "model": {"model_type": "bert", "embed_dim": 768, "num_layers": 12},
    "training": {"mode": "mlm", "batch_size": 32, "learning_rate": 1e-4},
    "tpu": {"precision": "bf16"},
})

# Load data
from aetherlm.data import load_mlm_datasets
train_data, val_data = load_mlm_datasets(maxlen=512, mask_prob=0.15, vocab_size=50265)

# Train (handles sharding, logging, checkpointing)
engine.train(train_data, val_data)

Installation

# From PyPI
pip install aetherlm

# With eval support (MTEB, sklearn, scipy)
pip install aetherlm[eval]

# From source
pip install -e ".[all]"

Training Modes

MLM Pretraining (BERT-style)

aetherlm --mode mlm --config configs/default.yaml
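
MLM pretraining hides a fraction of the input tokens (the mask_prob=0.15 seen in Quick Start) and trains the model to recover them. A minimal, library-independent sketch of the masking step — the token IDs, mask ID, and -100 "ignore" label are illustrative conventions, not AetherLM's actual internals:

```python
import random

def mask_tokens(token_ids, mask_id, mask_prob=0.15, seed=0):
    """Replace roughly mask_prob of tokens with mask_id.

    Returns (masked, labels): labels holds the original ID at masked
    positions and -100 elsewhere, the usual "ignore" marker for MLM losses.
    """
    rng = random.Random(seed)
    masked, labels = [], []
    for tok in token_ids:
        if rng.random() < mask_prob:
            masked.append(mask_id)   # model must reconstruct this token
            labels.append(tok)
        else:
            masked.append(tok)
            labels.append(-100)      # position ignored by the loss
    return masked, labels

masked, labels = mask_tokens(list(range(20)), mask_id=103)
```

The loss is then computed only at positions where labels != -100, which is what a gather-based MLM loss (as listed under losses/mlm.py) exploits for efficiency.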

Causal Language Modeling (GPT-style)

aetherlm --mode causal --config configs/causal_small.yaml
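
Causal LM training predicts each token from the tokens before it, so inputs and targets are the same sequence shifted by one position. A library-independent sketch of that data format:

```python
def next_token_pairs(token_ids):
    """Split a sequence into (inputs, targets) for next-token prediction:
    the model sees tokens [0..n-2] and is trained to emit tokens [1..n-1]."""
    return token_ids[:-1], token_ids[1:]

inputs, targets = next_token_pairs([5, 8, 2, 9])
# inputs = [5, 8, 2], targets = [8, 2, 9]
```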

Contrastive Learning (requires checkpoint)

aetherlm --mode contrastive --checkpoint ./checkpoints/step_10000
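
The SimCLR-style loss listed under losses/contrastive.py is, in essence, a softmax over similarities where the positive pair should dominate the negatives. A simplified single-anchor sketch (AetherLM's actual batched implementation may differ):

```python
import math

def info_nce_loss(sims, temperature=0.1):
    """SimCLR-style loss for one anchor: sims[0] is the similarity to the
    positive pair, sims[1:] are similarities to negatives.
    Returns -log softmax(sims / temperature)[0], computed stably."""
    logits = [s / temperature for s in sims]
    m = max(logits)                                   # subtract max for stability
    denom = sum(math.exp(l - m) for l in logits)
    return -(logits[0] - m - math.log(denom))

# A positive pair much more similar than the negatives yields a small loss
loss = info_nce_loss([0.9, 0.1, -0.2, 0.05])
```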

Evaluation

# Quick MTEB eval (3 tasks, ~1 min)
aetherlm --mode eval --checkpoint ./checkpoints/step_10000 --mteb_preset quick

# Full leaderboard (56 tasks)
aetherlm --mode eval --checkpoint ./checkpoints/step_10000 --mteb_preset leaderboard
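
MTEB evaluates any object that exposes an encode(sentences) method returning one fixed-size vector per sentence; that is the interface the EncoderProtocol wrapper in eval/mteb.py puts around a trained checkpoint. A toy stand-in (hash-based, not a real encoder) illustrating the shape contract:

```python
class ToyEncoder:
    """Minimal object satisfying the shape contract MTEB expects:
    encode(sentences) -> one dim-sized vector per input sentence."""

    def __init__(self, dim=8):
        self.dim = dim

    def encode(self, sentences, **kwargs):
        vecs = []
        for s in sentences:
            h = abs(hash(s))
            # Derive dim pseudo-features from the hash; a real wrapper
            # would run the model's forward pass here instead.
            vecs.append([((h >> i) % 997) / 997.0 for i in range(self.dim)])
        return vecs

embeddings = ToyEncoder().encode(["hello", "world"])
```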

Configuration

AetherLM uses a dataclass-based config system. Create configs from YAML, JSON, or Python dicts:

from aetherlm import AetherConfig

# From YAML
config = AetherConfig.from_yaml("configs/default.yaml")

# From dict
config = AetherConfig.from_dict({
    "model": {"embed_dim": 768, "num_layers": 12},
    "training": {"mode": "mlm", "batch_size": 32},
})

# Save for reproducibility
config.to_yaml("my_experiment.yaml")

See configs/ for example configurations.
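
For reference, a YAML file matching the dict example above might look like this (field names are taken from the snippets in this README; the authoritative schema lives in core/config.py):

```yaml
# my_experiment.yaml -- hypothetical example mirroring the dict config above
model:
  model_type: bert
  embed_dim: 768
  num_layers: 12
training:
  mode: mlm
  batch_size: 32
  learning_rate: 1.0e-4
tpu:
  precision: bf16
```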

Project Structure

aetherlm/
    __init__.py           # Top-level API: initialize(), AetherConfig, models
    core/
        config.py         # Dataclass configuration system
        engine.py         # Training engine (the heart of the library)
        sharding.py       # Automatic mesh sharding
    models/
        base.py           # Abstract model interface + utilities
        transformer.py    # Transformer blocks, embeddings
        bert.py           # AetherBERT (bidirectional MLM)
        causal.py         # AetherCausalLM (autoregressive + generation)
    optim/
        optimizers.py     # Optimizer factory (AdamW, Adam, SGD)
        schedules.py      # LR schedules (warmup-cosine, linear, constant)
        switching.py      # Plateau detection + optimizer switching
    data/
        pipeline.py       # Tokenizer caching, batch iterators
        mlm.py            # MLM masking and dataset processing
        contrastive.py    # Contrastive pair creation (self-supervised + AllNLI)
        causal.py         # Causal LM next-token prediction format
    losses/
        mlm.py            # Efficient gather-based MLM loss
        contrastive.py    # SimCLR-style contrastive loss
        causal.py         # Causal LM cross-entropy loss
    training/
        steps.py          # JIT-compiled train/eval steps for all modes
    checkpoint/
        manager.py        # Orbax checkpoint save/load/rotation
    metrics/
        tracker.py        # Throughput, loss, ETA, WandB logging
    eval/
        tasks.py          # MTEB task lists and presets
        mteb.py           # MTEB EncoderProtocol wrapper
        runner.py         # Evaluation orchestrators
    tpu/
        topology.py       # TPU detection, mesh creation
        precision.py      # bfloat16/mixed precision config
    cli/
        main.py           # CLI entry point
configs/                  # Example YAML configurations
notebooks/                # Tutorial notebooks
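
As one example of what lives in optim/schedules.py, a warmup-cosine schedule ramps the learning rate linearly to its peak and then decays it along a cosine curve. A generic sketch of that shape (not AetherLM's actual code):

```python
import math

def warmup_cosine(step, peak_lr, warmup_steps, total_steps, final_lr=0.0):
    """Linear warmup to peak_lr over warmup_steps, then cosine decay
    to final_lr by total_steps."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return final_lr + 0.5 * (peak_lr - final_lr) * (1 + math.cos(math.pi * progress))
```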

Notebooks

Notebook                  Description
01_quickstart.ipynb       Core features: models, config, initialize
02_mlm_pretraining.ipynb  Full MLM pretraining pipeline
03_causal_lm.ipynb        Causal LM training + text generation
04_evaluation.ipynb       MTEB and custom evaluation

License

Apache 2.0

Download files

Source distribution: aetherlm-0.1.0.tar.gz (38.4 kB, uploaded via twine/6.2.0, CPython/3.13.1)
    SHA256       d99483cdd046f55acc24599883d72a6dacedc84948bc9e57b1a7115737e927c4
    MD5          d8b32b650e5808e8d40b85903cf0ad4a
    BLAKE2b-256  ef598286ad8a9275d197a2c551028ba7a04337bbf75172822655ab93cc5dba8f

Built distribution: aetherlm-0.1.0-py3-none-any.whl (48.8 kB, Python 3, uploaded via twine/6.2.0, CPython/3.13.1)
    SHA256       15b352fc3b70631391e460ebcffc6ff101fb2e9e174a492df05955bde89424ce
    MD5          0d74e543f0e35ad87a4e29d68bf30d03
    BLAKE2b-256  3093df06a92afb6c187cc4bc0b4793b567d4477a80b54396b29fce2fb9c76fbd
