eformer (EasyDel Former) is a utility library designed to simplify and enhance development in JAX.
eformer 🔮 (EasyDel Former)
EasyDel Former is a batteries-included JAX toolkit for building, quantizing, scaling, and deploying modern transformer-style workloads on GPUs and TPUs.
Table of Contents
- Why eformer?
- Feature Highlights
- Module Map
- Installation
- Quickstart
- Examples & Guides
- Documentation
- Testing & Quality
- Contributing
- License
Why eformer?
eformer packages the infrastructure the EasyDel project uses to run large JAX models in production:
- Single import for system glue – argument parsing, logging, filesystem helpers, PyTree utilities, sharding, and TensorStore checkpoints live in one coherent namespace.
- Hardware-aware building blocks – Ray + TPU/GPU executors, mesh utilities, quantized kernels, and loss scaling are battle-tested in multi-slice pods.
- Productivity without boilerplate – dataclass-driven CLIs, optimizer factories, progress loggers, and serialization APIs keep research prototypes tidy.
- Deep integration with JAX – everything is PyTree-friendly, jax.jit/vmap compatible, and aware of sharding semantics so you can stay inside pure JAX programs.
Feature Highlights
Precision & Quantization
- eformer.mpric exposes Policy, PrecisionHandler, and dynamic LossScaler utilities so you can express policies like p=f32,c=f8_e4m3,o=f32 and automatically wrap training/inference steps with casting and loss-scaling logic.
- A unified quantization interface (QuantizationConfig, QuantizationType, quantize, straight_through) supports NF4, INT8, binary, and ternary formats with actual bit packing, TPU-optimized NF4 kernels via Pallas, and straight-through estimator (STE) support for quantization-aware training (QAT).
- eformer.jaximus supplies the implicit-array runtime (ImplicitArray, register, ste, and the implicit decorator) that lets Array8B, ArrayNF4, and 1-bit tensors participate in JAX primitives without being materialized unless needed.
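A policy string such as p=f32,c=f8_e4m3,o=f32 names three dtypes: one for stored parameters, one for computation, and one for outputs. The sketch below shows one way such a string could be decoded in plain Python. ParsedPolicy, parse_policy, and the alias table are hypothetical illustrations rather than mpric's parser; in particular, mapping f8_e4m3 to the name float8_e4m3fn is an assumption.

```python
from dataclasses import dataclass

# Hypothetical alias table for illustration; mpric's real dtype registry may differ.
DTYPE_ALIASES = {
    "f32": "float32",
    "f16": "float16",
    "bf16": "bfloat16",
    "f8_e4m3": "float8_e4m3fn",
    "f8_e5m2": "float8_e5m2",
}


@dataclass(frozen=True)
class ParsedPolicy:
    param_dtype: str    # "p": dtype used to store parameters
    compute_dtype: str  # "c": dtype used for forward/backward math
    output_dtype: str   # "o": dtype of returned outputs


def parse_policy(spec: str) -> ParsedPolicy:
    """Parse a 'p=...,c=...,o=...' string into full dtype names."""
    entries = {}
    for item in spec.split(","):
        key, sep, value = item.strip().partition("=")
        if not sep or key not in ("p", "c", "o"):
            raise ValueError(f"malformed policy entry: {item!r}")
        entries[key] = DTYPE_ALIASES.get(value, value)
    missing = {"p", "c", "o"} - set(entries)
    if missing:
        raise ValueError(f"policy is missing keys: {sorted(missing)}")
    return ParsedPolicy(entries["p"], entries["c"], entries["o"])


policy = parse_policy("p=f32,c=f8_e4m3,o=f32")
print(policy.compute_dtype)  # float8_e4m3fn
```

Rejecting unknown or missing keys up front keeps a typo such as q=bf16 from silently falling back to a default dtype.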
Distributed Scaling & Executors
- eformer.escale provides semantic sharding via PartitionAxis, PartitionManager, auto_namedsharding, and helpers that convert per-layer rules into PartitionSpecs respecting DP/FSDP/TP/EP/SP axes.
- Mesh tooling (create_mesh, MeshPartitionHelper) inspects PyTree shapes and suggests sharding plans, while constraint utilities (with_sharding_constraint, get_corrected_named_sharding) fix up specs for real device meshes.
- eformer.executor builds on Ray to launch pods or multi-slice TPU jobs with automatic retries (RayExecutor.execute_resumable, execute_multislice_resumable), Docker orchestration, and SLURM-friendly cluster discovery (eSlurmCluster, auto_ray_cluster).
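Semantic sharding boils down to a lookup from logical tensor axes to mesh axes. The sketch below illustrates that idea with plain jax.sharding.PartitionSpec. The rule table and the spec_for helper are hypothetical stand-ins, not the PartitionManager API, and assume a mesh whose axes are named dp and tp.

```python
from jax.sharding import PartitionSpec

# Hypothetical mapping from semantic tensor axes to mesh axis names.
# None means "replicate along this dimension".
SEMANTIC_RULES = {
    "batch": "dp",
    "sequence": None,
    "embed": "tp",
    "vocab": "tp",
}


def spec_for(*semantic_axes: str) -> PartitionSpec:
    """Translate semantic axis names into a PartitionSpec, one entry per dimension."""
    try:
        return PartitionSpec(*(SEMANTIC_RULES[name] for name in semantic_axes))
    except KeyError as err:
        raise ValueError(f"no sharding rule for axis {err.args[0]!r}") from None


# Hidden states shaped (batch, sequence, embed): batch over dp, embed over tp.
hidden_spec = spec_for("batch", "sequence", "embed")
print(hidden_spec)
```

Keying rules by meaning rather than by tensor lets the same model code run on a differently shaped mesh by swapping only the rule table.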
PyTree, Serialization & Storage
- eformer.pytree ships more than 50 helpers for diffing, stacking, filtering, flattening, and serializing PyTrees, plus MsgPack-based to_bytes/from_bytes and type registration hooks.
- High-level checkpointing (serialization.Checkpointer, AsyncCheckpointManager, TensorStore backends) supports time/step policies, async cleanup, and sharded array saves without all-gathers.
- eformer.paths.ePath abstracts local paths and Google Cloud Storage behind identical APIs, including JAX array saves/loads and recursive globbing.
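Stacking is one of the more common PyTree operations, for example turning a list of per-layer parameter dicts into a single stacked tree. A minimal sketch using only jax.tree_util follows; stack_trees and unstack_tree are hypothetical names, and eformer.pytree's own helpers may differ in name and signature.

```python
import jax
import jax.numpy as jnp


def stack_trees(trees):
    """Stack structurally identical PyTrees leaf-by-leaf along a new axis 0."""
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *trees)


def unstack_tree(tree):
    """Inverse of stack_trees: split axis 0 back into a list of PyTrees."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    n = leaves[0].shape[0]
    return [jax.tree_util.tree_unflatten(treedef, [leaf[i] for leaf in leaves]) for i in range(n)]


# Four "layers" with identical structure but different weight values.
layers = [{"w": jnp.full((2, 3), float(i)), "b": jnp.zeros((3,))} for i in range(4)]
stacked = stack_trees(layers)
print(stacked["w"].shape)  # (4, 2, 3)
```

A stacked tree like this is the layout jax.lax.scan expects when a model loops over identical layers.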
Optimizers & Training Ergonomics
- OptimizerFactory + _builders turn concise config dataclasses (AdamW, Adafactor, Muon, Lion, RMSProp, WhiteKron, Mars, Soap, Kron) into Optax transforms with scheduler composition.
- SchedulerFactory generates cosine/linear/warmup schedules or plugs in custom callables for experiments.
- aparser.Argu + DataClassArgumentParser transform dataclasses into CLIs with YAML/JSON loading, alias handling, and bool toggles.
- loggings.get_logger offers colorized, process-aware loggers and progress tracking, while common_types centralizes semantic axis constants (BATCH, VOCAB, DP, TP, etc.) to keep sharding specs consistent.
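A warmup-then-cosine schedule, one of the shapes SchedulerFactory is described as generating, is simply a function from step to learning rate. The pure-Python sketch below shows the arithmetic only; warmup_cosine and its parameter names are hypothetical, and in practice you would use Optax's built-in schedules or the factory itself.

```python
import math


def warmup_cosine(step: int, peak_lr: float, warmup_steps: int, total_steps: int, end_lr: float = 0.0) -> float:
    """Linear warmup from 0 to peak_lr, then cosine decay from peak_lr to end_lr."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1] so the rate holds at end_lr.
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return end_lr + (peak_lr - end_lr) * cosine


for step in (0, 50, 100, 550, 1000):
    print(step, warmup_cosine(step, peak_lr=3e-4, warmup_steps=100, total_steps=1000))
```

The rate peaks exactly at the end of warmup and reaches half the peak at the midpoint of the decay phase.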
Module Map
| Module | Purpose | Key entry points |
|---|---|---|
| eformer.aparser | Dataclass-first argument parsing & config loading | Argu, DataClassArgumentParser.parse_args_into_dataclasses, parse_yaml_file |
| eformer.escale | Mesh + sharding orchestration across DP/FSDP/TP/EP/SP | PartitionAxis, PartitionManager, auto_partition_spec, MeshPartitionHelper |
| eformer.executor | Ray-powered TPU/GPU executors, Docker helpers, SLURM glue | RayExecutor, execute_multislice_resumable, auto_ray_cluster, TpuAcceleratorConfig |
| eformer.jaximus | Implicit arrays and custom PyTree runtime for quantized tensors | ImplicitArray, register, implicit, ste |
| eformer.mpric | Mixed precision policies, dtype registries, dynamic loss scaling | Policy, PrecisionHandler, LossScaleConfig, DynamicLossScale |
| eformer.ops.quantization | NF4/INT8/1-bit quantization kernels and STE wrappers | QuantizationConfig, QuantizationType, ArrayNF4, Array8B, quantize, straight_through |
| eformer.optimizers | Configurable optimizer factory & scheduler utilities | OptimizerFactory, SchedulerFactory, optax_add_scheduled_weight_decay |
| eformer.pytree | Extensive PyTree manipulation and MsgPack serialization | tree_* helpers, PyTree, to_bytes, save_to_file |
| eformer.serialization | TensorStore checkpointing and async save managers | Checkpointer, CheckpointInterval, AsyncCheckpointManager, fsspec_utils |
| eformer.paths | Unified local/GCS path abstraction with ML utilities | ePath, LocalPath, GCSPath, save_jax_array, load_jax_array |
| eformer.loggings | Color logs, once-only warnings, progress meters | get_logger, LazyLogger, ProgressLogger |
| eformer.common_types | Shared axis constants & sharding-friendly aliases | BATCH, EMBED, DP, TP, PartitionAxis, DynamicShardingAxes |
Installation
eformer targets Python 3.11–3.13 with jax>=0.8.0. Install the TPU/GPU-specific JAX build that matches your platform before using hardware accelerators.
PyPI release
```bash
pip install eformer
```
From source (development)
```bash
git clone https://github.com/erfanzar/eformer.git
cd eformer
pip install -e '.[dev]'
# optional: keep dependencies in sync with uv
uv sync --dev
```
For documentation builds:
```bash
pip install -r docs/requirements.txt
make -C docs html
```
Quickstart
1. Dataclass-driven configuration
```python
from dataclasses import dataclass

from eformer.aparser import Argu, DataClassArgumentParser


@dataclass
class RuntimeConfig:
    steps: int = Argu(help="Number of training steps", default=10_000)
    mesh: str = Argu(help="Mesh spec such as 'dp:2,tp:4'", default="dp:1,tp:1")
    policy: str = Argu(help="Precision policy string", default="p=f32,c=f8_e4m3,o=f32")


parser = DataClassArgumentParser(RuntimeConfig, description="Train a transformer with eformer.")
config, = parser.parse_args_into_dataclasses()

# Load overrides from a YAML file if desired
config, = parser.parse_yaml_file("configs/train.yaml")
print(config)
```
Argu stores CLI metadata (aliases/help/defaults), and the parser can read dictionaries/JSON/YAML while validating against your dataclass schema.
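The underlying pattern, dataclass fields that carry their own CLI metadata, can be reproduced with the standard library alone. In this sketch the metadata lives in dataclasses.field(metadata=...) instead of eformer's Argu, and parse_into is a hypothetical helper; aliases, bool toggles, and YAML/JSON loading, which DataClassArgumentParser handles, are deliberately left out.

```python
import argparse
from dataclasses import dataclass, field, fields


@dataclass
class RuntimeConfig:
    steps: int = field(default=10_000, metadata={"help": "Number of training steps"})
    mesh: str = field(default="dp:1,tp:1", metadata={"help": "Mesh spec such as 'dp:2,tp:4'"})


def parse_into(cls, argv=None):
    """Build an argparse parser from dataclass fields and return an instance of cls."""
    parser = argparse.ArgumentParser(description=cls.__name__)
    for f in fields(cls):
        parser.add_argument(
            f"--{f.name}",
            # f.type is a real class here because annotations are not postponed (no __future__ import).
            type=f.type,
            default=f.default,
            help=f.metadata.get("help"),
        )
    return cls(**vars(parser.parse_args(argv)))


config = parse_into(RuntimeConfig, ["--steps", "500"])
print(config)  # RuntimeConfig(steps=500, mesh='dp:1,tp:1')
```

Because the dataclass remains the single source of truth, defaults and help text cannot drift apart from the parser definition.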
2. Mixed-precision training with mpric
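```python
import jax
import jax.numpy as jnp
from eformer.mpric import PrecisionHandler

# Placeholder model and loss so the snippet is self-contained; substitute your own.
def model_apply(params, inputs):
    return inputs @ params["w"]

def cross_entropy(logits, labels):
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.take_along_axis(log_probs, labels[:, None], axis=-1)[:, 0]

handler = PrecisionHandler(policy="p=f32,c=f8_e4m3,o=f32", use_dynamic_scale=True)

@jax.jit
def train_step(params, batch):
    def loss_fn(p):
        logits = model_apply(p, batch["inputs"])
        labels = batch["labels"]
        return jnp.mean(cross_entropy(logits, labels))
    loss, grads = jax.value_and_grad(loss_fn)(params)
    return loss, grads

train_step = handler.training_step_wrapper(train_step)
params = {"w": jnp.zeros((16, 4))}
batch = {"inputs": jnp.ones((8, 16)), "labels": jnp.zeros((8,), dtype=jnp.int32)}
loss, grads, grads_finite = train_step(params, batch)
```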
PrecisionHandler jit-wraps casting, loss scaling, underflow detection, and gradient unscaling so the wrapped function stays focused on model math.
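Dynamic loss scaling is a small state machine: multiply the loss by a scale, unscale the gradients, shrink the scale whenever a non-finite gradient appears, and grow it again after a run of clean steps. The sketch below shows that logic in plain JAX. LossScaleState, scaled_grad, and the GROWTH_INTERVAL/BACKOFF constants are assumptions for illustration, not the internals of mpric's DynamicLossScale.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

GROWTH_INTERVAL = 2000  # assumed: clean steps required before the scale doubles
BACKOFF = 0.5           # assumed: factor applied to the scale after an overflow


class LossScaleState(NamedTuple):
    scale: jax.Array       # current loss multiplier (float32 scalar)
    good_steps: jax.Array  # consecutive steps with finite gradients (int32 scalar)


def scaled_grad(loss_fn, params, state: LossScaleState):
    """Return (loss, unscaled grads, grads_finite, next_state)."""
    scaled_loss, grads = jax.value_and_grad(lambda p: loss_fn(p) * state.scale)(params)
    # Undo the scaling so the optimizer sees true gradient magnitudes.
    grads = jax.tree_util.tree_map(lambda g: g / state.scale, grads)
    leaves_ok = [jnp.all(jnp.isfinite(g)) for g in jax.tree_util.tree_leaves(grads)]
    grads_finite = jnp.all(jnp.stack(leaves_ok))
    grow = state.good_steps + 1 >= GROWTH_INTERVAL
    # Finite: keep the scale, or double it after GROWTH_INTERVAL clean steps.
    # Non-finite: back off so the next step is less likely to overflow.
    next_scale = jnp.where(
        grads_finite,
        jnp.where(grow, state.scale * 2.0, state.scale),
        state.scale * BACKOFF,
    )
    # Reset the counter after an overflow or after growing the scale.
    next_good = jnp.where(grads_finite & ~grow, state.good_steps + 1, 0)
    return scaled_loss / state.scale, grads, grads_finite, LossScaleState(next_scale, next_good)


params = {"w": jnp.array([1.0, -2.0])}
state = LossScaleState(jnp.float32(2.0**15), jnp.int32(0))
loss, grads, grads_finite, state = scaled_grad(lambda p: jnp.sum(p["w"] ** 2), params, state)
```

A training loop would apply the optimizer update only when grads_finite is true and carry the returned state into the next step; since the state is made of arrays, scaled_grad can also run under jax.jit.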
3. Work with quantized weights (NF4/INT8/Binary)
```python
import jax
import jax.numpy as jnp
from eformer.jaximus import implicit
from eformer.ops.quantization import (
    QuantizationConfig,
    QuantizationType,
    quantize,
    straight_through,
)

# Example data so the snippet is self-contained; replace with your own weights and batch.
key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
weight_fp32 = jax.random.normal(key_w, (256, 128), dtype=jnp.float32)
inputs = jax.random.normal(key_x, (8, 256), dtype=jnp.float32)
targets = jnp.zeros((8, 128), dtype=jnp.float32)

@implicit
def nf4_linear(x, w):
    return x @ w  # dot_general dispatches to implicit handlers when possible

config = QuantizationConfig(dtype=QuantizationType.NF4, block_size=64)
nf4_weights = quantize(weight_fp32, config=config)

# Inference uses compressed tensors directly
logits = nf4_linear(inputs, nf4_weights)

# Training keeps float32 master weights but injects STE quantization on the fly
def loss_fn(master_weight):
    q_weight = straight_through(master_weight, config=config)
    preds = nf4_linear(inputs, q_weight)
    return jnp.mean((preds - targets) ** 2)

loss, grads = jax.value_and_grad(loss_fn)(weight_fp32)
```
quantize returns implicit arrays (NF4, INT8, Binary), and the implicit decorator routes JAX primitives (dot, pow, matmul, etc.) through registered handlers that load custom Triton/Pallas kernels when available.
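The two ideas behind this API, blockwise quantization and the straight-through estimator, can be sketched in a few lines of plain JAX. The block below uses symmetric absmax INT8 per block as a stand-in (NF4 instead maps each normalized block onto a fixed 16-level codebook). quantize_int8_blockwise, dequantize, and ste_quantize are hypothetical names, not eformer functions, and there is no bit packing or custom kernel here.

```python
import jax
import jax.numpy as jnp


def quantize_int8_blockwise(w: jax.Array, block_size: int = 64):
    """Symmetric absmax quantization with one float scale per block of block_size values."""
    blocks = w.reshape(-1, block_size)  # assumes w.size is divisible by block_size
    scales = jnp.max(jnp.abs(blocks), axis=1, keepdims=True) / 127.0
    scales = jnp.where(scales == 0, 1.0, scales)  # avoid division by zero for all-zero blocks
    q = jnp.clip(jnp.round(blocks / scales), -127, 127).astype(jnp.int8)
    return q, scales


def dequantize(q: jax.Array, scales: jax.Array, shape) -> jax.Array:
    """Map int8 codes back to float32 and restore the original shape."""
    return (q.astype(jnp.float32) * scales).reshape(shape)


def ste_quantize(w: jax.Array, block_size: int = 64) -> jax.Array:
    """Forward pass sees quantized weights; backward pass treats quantization as identity."""
    q, scales = quantize_int8_blockwise(w, block_size)
    w_q = dequantize(q, scales, w.shape)
    # stop_gradient hides the rounding step from autodiff, so d(output)/dw == 1.
    return w + jax.lax.stop_gradient(w_q - w)


w = jax.random.normal(jax.random.PRNGKey(0), (128, 64))
q, s = quantize_int8_blockwise(w)
w_hat = dequantize(q, s, w.shape)
```

Without the stop_gradient trick, the gradient of round() is zero almost everywhere, which is why quantization-aware training relies on an STE.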
4. Partition tensors and launch Ray jobs
```python
import jax
from eformer.common_types import BATCH, EMBED
from eformer.escale import MeshPartitionHelper, PartitionAxis, PartitionManager, create_mesh
from eformer.executor.ray import execute_multislice_resumable, TpuAcceleratorConfig

# train_state, hidden_states, and train_slice_remote are assumed to be defined elsewhere.
# A (2, 2) mesh requires four available devices.
mesh = create_mesh(axis_dims=(2, 2), axis_names=("dp", "tp"))
helper = MeshPartitionHelper(mesh)
manager = PartitionManager(paxis=PartitionAxis(batch_axis="dp", hidden_state_axis="tp"))

with mesh:
    sharded_state = helper.auto_shard_pytree(train_state)
    hidden = manager.shard(hidden_states, axes=(BATCH, EMBED))

job_status = execute_multislice_resumable(
    remote_fn=train_slice_remote,  # decorated with @ray.remote
    accelerator_config=TpuAcceleratorConfig(type="v4-8", pod_count=2),
    max_retries_preemption=5,
    max_retries_failure=2,
)
```
MeshPartitionHelper inspects trees to produce sensible PartitionSpecs; PartitionManager gives semantic sharding (batch/hidden/etc.), and RayExecutor manages multi-slice TPU or GPU execution with resumable jobs.
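The two retry budgets passed to execute_multislice_resumable reflect a common policy: preemptions are expected on preemptible TPU capacity and get a generous budget, while genuine failures get a small one. A minimal pure-Python sketch of that policy follows. Preempted, run_with_budgets, and the resume-from-checkpoint convention are hypothetical and do not reflect Ray or eformer internals.

```python
class Preempted(Exception):
    """Hypothetical signal that the accelerator slice was reclaimed."""


def run_with_budgets(job, max_retries_preemption: int = 5, max_retries_failure: int = 2):
    """Call job(attempt) until it succeeds, charging each error to the matching budget."""
    preemptions = failures = 0
    attempt = 0
    while True:
        try:
            # A resumable job is expected to restart from its latest checkpoint.
            return job(attempt)
        except Preempted:
            preemptions += 1
            if preemptions > max_retries_preemption:
                raise
        except Exception:
            failures += 1
            if failures > max_retries_failure:
                raise
        attempt += 1
```

Keeping the budgets separate lets a run on preemptible capacity survive many reclaim events without retrying a deterministic bug indefinitely.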
Examples & Guides
- examples/quantization_training.py – end-to-end training loop demonstrating NF4/INT8/Binary quantization with the unified API.
- env.py – short script showing NF4 straight-through training and inference using implicit arrays.
- QUANTIZATION.txt – quick-reference sheet for supported quantization modes.
- docs/pytree_utils.md – catalog of every PyTree helper with explanations.
- docs/api_docs/*.rst – per-module API descriptions used by Sphinx.
Run the example locally:
```bash
python examples/quantization_training.py
```
Documentation
Hosted docs: https://eformer.readthedocs.org
Build the Sphinx site locally:
```bash
pip install -r docs/requirements.txt
make -C docs html
# open docs/_build/html/index.html
```
docs/index.rst is the landing page, and the api_docs/ folder mirrors the Python package layout so you can quickly locate functions/classes.
Testing & Quality
Unit tests cover key areas such as PyTree utilities, optimizer factory logic, and quantization kernels (tests/test_*.py). To run them:
```bash
pip install -e '.[dev]'
pytest
```
The repository also contains formatter/linter configurations:
```bash
ruff check .
black --check .
```
Feel free to wire these commands into pre-commit hooks or your CI. If you prefer uv's virtual environments, uv run pytest works out of the box.
Contributing
Contributions are welcome! Please read CONTRIBUTING.md and follow the Apache Code of Conduct. If you plan to work on distributed/TPU features, include repro steps or environment notes in the PR so we can validate them.
- Report bugs / feature requests via GitHub issues.
- Keep PRs focused, include tests where possible, and respect existing formatting rules (Black line length 121, Ruff config in pyproject.toml).
- See CHANGES.txt for release notes and QUANTIZATION.txt for design background.
License
Licensed under the Apache License 2.0. Portions of the executor/cluster utilities build upon the excellent work in the Stanford CRFM Levanter project; see file headers for details.
Download files
- Source distribution: eformer-0.0.90.tar.gz
- Built distribution: eformer-0.0.90-py3-none-any.whl
File details
Details for the file eformer-0.0.90.tar.gz.
File metadata
- Download URL: eformer-0.0.90.tar.gz
- Size: 271.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.9.21 (Ubuntu 22.04)
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | e2083478fbd7e3d7ac72696a0ae57ab3abe3707fd87414cb74fd1671ff55ad91 |
| MD5 | 3f5db78d9d7f123e63bc7f847c78e87d |
| BLAKE2b-256 | 6402d01c4240103207fb1efafffc2b7a6265085930a8d2595d0a1906e3f0728c |
File details
Details for the file eformer-0.0.90-py3-none-any.whl.
File metadata
- Download URL: eformer-0.0.90-py3-none-any.whl
- Size: 332.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.9.21 (Ubuntu 22.04)
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 3b57717a9ebc763844fd09927c00997926fbea84dba07c3a74f499e46b140cd0 |
| MD5 | 27622d66f42b3bce707110d13eb693ce |
| BLAKE2b-256 | a24c3c0290824f01a11e52ef82de62f45f07ac8b26206ceccb74c2afefc08273 |