Scalable Training for Foundation Models with Named Tensors and JAX

Levanter

Levanter is developed and released from the marin-community/marin monorepo (lib/levanter) and published to PyPI as marin-levanter. This README documents the package as shipped from that monorepo.

You could not prevent a thunderstorm, but you could use the electricity; you could not direct the wind, but you could trim your sail so as to propel your vessel as you pleased, no matter which way the wind blew.
— Cora L. V. Hatch

Levanter is a framework for training large language models (LLMs) and other foundation models that strives for legibility, scalability, and reproducibility:

  1. Legible: Levanter uses our named tensor library Haliax to write easy-to-follow, composable deep learning code, while still being high performance.
  2. Scalable: Levanter scales to large models and can train on a variety of hardware, including GPUs and TPUs.
  3. Reproducible: Levanter is bitwise deterministic, meaning that the same configuration will always produce the same results, even in the face of preemption and resumption.

We built Levanter with JAX, Equinox, and Haliax.

Documentation

Levanter's documentation is available at levanter.readthedocs.io. Haliax's documentation is available at haliax.readthedocs.io.

Features

  • Distributed Training: We support distributed training on TPUs and GPUs, including FSDP and tensor parallelism.
  • Compatibility: Levanter supports importing and exporting models to/from the Hugging Face ecosystem, including tokenizers, datasets, and models via SafeTensors.
  • Performance: Levanter's performance rivals commercially-backed frameworks like MosaicML's Composer or Google's MaxText.
  • Resilience: Levanter supports fast, distributed checkpointing and fast resume from checkpoints with no data seek, making Levanter robust to preemption and hardware failure.
  • Cached On-Demand Data Preprocessing: We preprocess corpora online, but we cache the results of preprocessing so that resumes are much faster and so that subsequent runs are even faster. As soon as the first part of the cache is complete, Levanter will start training.
  • Logging: Levanter logs a rich and detailed set of metrics covering loss and performance. Levanter also supports a few different logging backends, including WandB and TensorBoard. (Adding a new logging backend is easy!) Levanter even exposes the ability to log inside of JAX jit-ted functions.
  • Reproducibility: On TPU, Levanter is bitwise deterministic, meaning that the same configuration will always produce the same results, even in the face of preemption and resumption.
  • Distributed Checkpointing: Distributed checkpointing is supported via Google's TensorStore library. Training can even be resumed on a different number of hosts, though this breaks reproducibility for now.
  • Optimization: We support Optax for optimization with AdamW, as well as newer optimizers like Muon, SOAP, and more.
  • Flexible: Levanter supports tuning data mixtures without having to retokenize or shuffle data.
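
The resilience and reproducibility points above rest on one idea: the data a run sees at a given step should be a pure function of the configuration and the step number. The following is a minimal plain-Python sketch of that idea, not Levanter's implementation; `batch_indices`, the `seed` argument, and the per-epoch permutation scheme are hypothetical and chosen only for illustration.

```python
import random


def batch_indices(step: int, batch_size: int, dataset_size: int, seed: int) -> list[int]:
    """Return the example indices for `step` without replaying earlier steps.

    Hypothetical scheme: each epoch gets its own permutation derived from
    (seed, epoch), so any step can be computed directly from its number.
    """
    steps_per_epoch = dataset_size // batch_size  # drop the ragged tail
    epoch, offset = divmod(step, steps_per_epoch)
    order = list(range(dataset_size))
    random.Random(seed * 1_000_003 + epoch).shuffle(order)
    start = offset * batch_size
    return order[start:start + batch_size]


# A run interrupted after step 6 and resumed at step 7 sees the same batches
# as a run that was never interrupted.
uninterrupted = [batch_indices(s, 4, 32, seed=0) for s in range(10)]
resumed = [batch_indices(s, 4, 32, seed=0) for s in range(7, 10)]
print(uninterrupted[7:] == resumed)  # prints True
```

This toy version reshuffles the whole index list on every call, which is fine for a demonstration but would be too slow for a real corpus; a production system would use a permutation that can be evaluated per index.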

Levanter was created by the research engineering team at Stanford's Center for Research on Foundation Models (CRFM). You can also find us in the #levanter channel on the unofficial JAX LLM Discord.

Getting Started

Here is a small set of examples to get you started. For more information about the various configuration options, please see the Getting Started guide or the In-Depth Configuration Guide. You can also use --help or poke around other configs to see all the options available to you.

Installing Levanter

After installing JAX with the appropriate configuration for your platform, install Levanter from PyPI:

pip install marin-levanter
wandb login  # optional, we use wandb for logging

For development, clone the marin monorepo and use uv sync to install Levanter alongside its sibling packages (Haliax, Iris, etc.) in editable form:

git clone https://github.com/marin-community/marin.git
cd marin
uv sync

Please refer to the Installation Guide for more information on how to install Levanter.

If you're using a TPU, more complete documentation for setting that up is available here. GPU support is still in progress; documentation is available here.

Training a GPT2-nano

As a kind of hello world, here's how you can train a GPT-2 "nano"-sized model on a small dataset.

python -m levanter.main.train_lm --config_path config/gpt2_nano.yaml

# alternatively, if you installed from PyPI rather than from an editable checkout and are in a different directory
python -m levanter.main.train_lm --config_path gpt2_nano

This will train a GPT2-nano model on the WikiText-103 dataset.

Training a Llama-small on your own data

You can also change the dataset by changing the data section of the config file. If your dataset is a Hugging Face dataset, you can use the data.id field to specify it:

python -m levanter.main.train_lm --config_path config/llama_small_fast.yaml --data.id openwebtext

# optionally, you may specify a tokenizer and/or a cache directory, which may be local or on GCS
python -m levanter.main.train_lm --config_path config/llama_small_fast.yaml --data.id openwebtext --data.tokenizer "NousResearch/Llama-2-7b-hf" --data.cache_dir "gs://path/to/cache/dir"
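
The dotted flags in the commands above (`--data.id`, `--data.tokenizer`, `--data.cache_dir`) override fields nested under `data` in the YAML file. The sketch below illustrates the general idea of merging dotted overrides into a nested config using plain dictionaries. It is not how Levanter parses arguments internally, and `apply_overrides` is a hypothetical helper rather than part of its API.

```python
import copy


def apply_overrides(config: dict, overrides: dict[str, object]) -> dict:
    """Return a copy of `config` with dotted keys such as "data.id" applied."""
    result = copy.deepcopy(config)  # leave the caller's config untouched
    for dotted_key, value in overrides.items():
        *parents, leaf = dotted_key.split(".")
        node = result
        for name in parents:
            node = node.setdefault(name, {})  # create missing sections
        node[leaf] = value
    return result


base = {
    "data": {"cache_dir": "gs://path/to/cache/dir"},
    "trainer": {"train_batch_size": 512},
}
merged = apply_overrides(
    base,
    {"data.id": "openwebtext", "data.tokenizer": "NousResearch/Llama-2-7b-hf"},
)
print(merged["data"])
```

Values given on the command line take precedence over the file, and sections that are not mentioned (here, `trainer`) pass through unchanged.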

If instead your data is a list of URLs, you can use the data.train_urls and data.validation_urls fields to specify them. Data URLs can be local files, GCS files, http(s) URLs, or anything else that fsspec supports. Levanter (really, fsspec) will automatically decompress .gz and .zstd files, and probably other formats too.

python -m levanter.main.train_lm --config_path config/llama_small_fast.yaml --data.train_urls '["https://path/to/train/data_*.jsonl.gz"]' --data.validation_urls '["https://path/to/val/data_*.jsonl.gz"]'
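
For reference, a `.jsonl.gz` shard such as the ones named above is a gzip-compressed file with one JSON object per line. The sketch below writes and reads such a file using only the standard library. The `text` field name and the `read_jsonl_gz` helper are assumptions made for illustration, so check the data configuration documentation for the key Levanter actually expects.

```python
import gzip
import json
import os
import tempfile


def read_jsonl_gz(path: str, text_key: str = "text"):
    """Yield the `text_key` field of each JSON line in a gzip-compressed file."""
    with gzip.open(path, "rt", encoding="utf-8") as f:
        for line in f:
            if line.strip():  # skip blank lines
                yield json.loads(line)[text_key]


# Write a tiny two-document shard, then read it back.
docs = [{"text": "first document"}, {"text": "second document"}]
with tempfile.TemporaryDirectory() as tmp:
    shard = os.path.join(tmp, "data_0.jsonl.gz")
    with gzip.open(shard, "wt", encoding="utf-8") as f:
        for doc in docs:
            f.write(json.dumps(doc) + "\n")
    texts = list(read_jsonl_gz(shard))

print(texts)
```

Running a check like this on one local shard before launching a job is a cheap way to catch a wrong field name or a corrupt archive.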

Customizing a Config File

You can modify the config file to change the model, the dataset, the training parameters, and more. Here's the llama_small_fast.yaml file:

data:
  train_urls:
      - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz"
  validation_urls:
      - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz"
  cache_dir: "gs://pubmed-mosaic/tokenized/openwebtext/"
model:
  type: llama
  hidden_dim: 768
  intermediate_dim: 2048
  num_heads: 12
  num_kv_heads: 12
  num_layers: 12
  seq_len: 1024
  gradient_checkpointing: true
trainer:
  tracker:
    type: wandb
    project: "levanter"
    tags: [ "openwebtext", "llama" ]

  mp: p=f32,c=bfloat16
  mesh:
    axes: {data: -1, replica: 1, model: 1}   # inherited defaults; override if you need TP
  per_device_parallelism: 4

  train_batch_size: 512
optimizer:
  learning_rate: 6E-4
  weight_decay: 0.1
  min_lr_ratio: 0.1
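
To make the batch settings in this config concrete, the following sketch works out the per-step quantities they imply. It assumes that per_device_parallelism is the number of examples each device processes in one pass and that a larger train_batch_size is reached through gradient accumulation; `batch_plan` is a hypothetical helper, and the device count of 8 is an example value that does not appear in the config.

```python
def batch_plan(train_batch_size: int, seq_len: int,
               per_device_parallelism: int, num_devices: int) -> dict:
    """Derive per-step quantities from the config values above.

    Assumes pure data parallelism (model axis of size 1) and that a batch
    too large for one pass is split into equal accumulation steps.
    """
    examples_per_pass = per_device_parallelism * num_devices
    if train_batch_size % examples_per_pass != 0:
        raise ValueError(
            "train_batch_size must be divisible by "
            "per_device_parallelism * num_devices"
        )
    return {
        "tokens_per_step": train_batch_size * seq_len,
        "examples_per_pass": examples_per_pass,
        "accumulation_steps": train_batch_size // examples_per_pass,
    }


# num_devices=8 is a hypothetical machine size, not part of the config.
plan = batch_plan(train_batch_size=512, seq_len=1024,
                  per_device_parallelism=4, num_devices=8)
print(plan)
# {'tokens_per_step': 524288, 'examples_per_pass': 32, 'accumulation_steps': 16}
```

Under these assumptions each optimizer step consumes 524,288 tokens regardless of hardware, while the number of accumulation passes shrinks as devices are added.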

Other Architectures

Currently, we support the following architectures:

We plan to add more in the future.

For speech, we currently only support Whisper.

Continued Pretraining with Llama

Here's an example of how to continue pretraining a Llama 1 or Llama 2 model on the OpenWebText dataset:

python -m levanter.main.train_lm --config_path config/llama2_7b_continued.yaml

Distributed and Cloud Training

Training on a TPU Cloud VM

Please see the TPU Getting Started guide for more information on how to set up a TPU Cloud VM and run Levanter there.

Training with CUDA

Please see the CUDA Getting Started guide for more information on how to set up a CUDA environment and run Levanter there.

Contributing

We welcome contributions! Please see CONTRIBUTING.md for more information. Issues and pull requests are tracked at marin-community/marin.

License

Levanter is licensed under the Apache License, Version 2.0. See LICENSE for the full license text.
