eformer (EasyDel Former)
eformer (EasyDel Former) is a utility library designed to simplify and enhance the development of machine learning models using JAX. It provides a collection of tools for sharding, custom PyTrees, quantization, mixed precision training, and optimized operations, making it easier to build and scale models efficiently.
Features
- Mixed Precision Training (`mpric`): Advanced mixed precision utilities supporting float8, float16, and bfloat16 with dynamic loss scaling.
- Sharding Utilities (`escale`): Tools for efficient sharding and distributed computation in JAX.
- Custom PyTrees (`jaximus`): Enhanced utilities for creating custom PyTrees and `ArrayValue` objects, adapted from Equinox.
- Custom Calling (`callib`): A tool for custom function calls and direct integration with Triton kernels in JAX.
- Optimizer Factory: A flexible factory for creating and configuring optimizers such as AdamW, Adafactor, Lion, and RMSProp.
- Custom Operations and Kernels:
  - Flash Attention 2 for GPUs/TPUs (via Triton and Pallas).
  - 8-bit and NF4 quantization for memory-efficient models.
  - More to be added.
- Quantization Support: Tools for 8-bit and NF4 quantization, enabling memory-efficient model deployment.
Installation
You can install eformer via pip:
```bash
pip install eformer
```
Quick Start
Mixed Precision Handler with `mpric`
```python
from eformer.mpric import PrecisionHandler

# Create a handler with float8 compute precision
handler = PrecisionHandler(
    policy="p=f32,c=f8_e4m3,o=f32",  # params in f32, compute in float8 (e4m3), output in f32
    use_dynamic_scale=True,  # enable dynamic loss scaling
)
```
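The policy string packs three dtypes into one field: `p` for parameters, `c` for computation, and `o` for outputs. As an illustration only, a parser for that format could look like the sketch below. The function name `parse_policy` and the alias table are hypothetical, and the full set of abbreviations eformer accepts (beyond `f32` and `f8_e4m3`, which appear above) is an assumption.

```python
# Hypothetical sketch: split a policy string such as "p=f32,c=f8_e4m3,o=f32"
# into parameter / compute / output dtype names. The alias table is an
# assumption for illustration, not eformer's actual lookup table.
_DTYPE_ALIASES = {
    "f32": "float32",
    "f16": "float16",
    "bf16": "bfloat16",
    "f8_e4m3": "float8_e4m3fn",
    "f8_e5m2": "float8_e5m2",
}
_ROLES = {"p": "param", "c": "compute", "o": "output"}


def parse_policy(policy: str) -> dict:
    """Return a mapping like {"param": "float32", ...} for a policy string."""
    result = {}
    for field in policy.split(","):
        key, _, alias = field.strip().partition("=")
        if key not in _ROLES or alias not in _DTYPE_ALIASES:
            raise ValueError(f"unrecognized policy field: {field!r}")
        result[_ROLES[key]] = _DTYPE_ALIASES[alias]
    return result


print(parse_policy("p=f32,c=f8_e4m3,o=f32"))
# {'param': 'float32', 'compute': 'float8_e4m3fn', 'output': 'float32'}
```

Keeping parameters and outputs in float32 while computing in a narrow format is the usual mixed-precision split: the master weights stay accurate, and only the heavy matrix math runs at low precision.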
Customizing Arrays with `ArrayValue`
```python
import jax
from eformer.jaximus import ArrayValue, implicit
from eformer.ops.quantization.quantization_functions import (
    dequantize_row_q8_0,
    quantize_row_q8_0,
)

array = jax.random.normal(jax.random.key(0), (256, 64), "f2")  # float16 input


class Array8B(ArrayValue):
    """An array stored as 8-bit weights plus scales."""

    scale: jax.Array
    weight: jax.Array

    def __init__(self, array: jax.Array):
        # Quantize once, at construction time.
        self.weight, self.scale = quantize_row_q8_0(array)

    def materialize(self):
        # Return the dense (dequantized) array.
        return dequantize_row_q8_0(self.weight, self.scale)


qarray = Array8B(array)


@jax.jit
@implicit
def sqrt(x):
    # `implicit` lets this plain jax.numpy function accept ArrayValue inputs.
    return jax.numpy.sqrt(x)


print(sqrt(qarray))
print(qarray)
```
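For intuition about what `Array8B` stores, here is a pure-Python sketch of GGML-style Q8_0 quantization: values are split into blocks of 32, each block gets one scale equal to its largest magnitude divided by 127, and each value is kept as a rounded int8 code. It assumes `quantize_row_q8_0` follows this scheme (the name suggests it, but this is not eformer's code), and the helper names are hypothetical.

```python
# Conceptual sketch of GGML-style Q8_0 block quantization. Assumption: eformer's
# quantize_row_q8_0 follows the same idea; this is not its implementation.
BLOCK = 32  # Q8_0 groups values into blocks of 32


def quantize_q8_0(values):
    """Quantize floats to int8 codes with one absmax-derived scale per block."""
    codes, scales = [], []
    for start in range(0, len(values), BLOCK):
        block = values[start:start + BLOCK]
        amax = max(abs(v) for v in block)
        scale = amax / 127.0  # the largest magnitude maps to +/-127
        scales.append(scale)
        inv = 1.0 / scale if scale else 0.0  # all-zero block -> all-zero codes
        codes.extend(int(round(v * inv)) for v in block)
    return codes, scales


def dequantize_q8_0(codes, scales):
    """Reconstruct approximate floats as code * (scale of the code's block)."""
    return [c * scales[i // BLOCK] for i, c in enumerate(codes)]


data = [((i * 37) % 101 - 50) / 10.0 for i in range(64)]
codes, scales = quantize_q8_0(data)
restored = dequantize_q8_0(codes, scales)
print(max(abs(a - b) for a, b in zip(data, restored)))  # worst-case round-trip error
```

Each reconstructed value is off by at most half of its block's scale, which is the trade-off for storing one byte per weight plus one scale per block.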
Optimizer Factory
```python
from eformer.optimizers import OptimizerFactory, SchedulerConfig, AdamWConfig

# Create an AdamW optimizer with a cosine learning-rate schedule
scheduler_config = SchedulerConfig(scheduler_type="cosine", learning_rate=1e-3, steps=1000)
optimizer, scheduler = OptimizerFactory.create("adamw", scheduler_config, AdamWConfig())
```
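With `scheduler_type="cosine"`, `learning_rate=1e-3`, and `steps=1000`, the learning rate presumably follows a cosine decay over 1000 steps. The sketch below shows the plain cosine-decay curve, assuming no warmup and a final value of 0; whether eformer's `SchedulerConfig` adds warmup or a minimum learning rate is not covered here, and `cosine_decay` is a hypothetical helper. The same curve is produced by `optax.cosine_decay_schedule(init_value=1e-3, decay_steps=1000)` with its default `alpha=0.0`.

```python
import math


def cosine_decay(step: int, learning_rate: float, steps: int) -> float:
    """Cosine decay from learning_rate at step 0 down to 0 at `steps`.

    Steps outside [0, steps] are clamped, so the rate stays at 0 afterwards.
    """
    progress = min(max(step, 0), steps) / steps
    return learning_rate * 0.5 * (1.0 + math.cos(math.pi * progress))


for step in (0, 250, 500, 1000):
    print(step, cosine_decay(step, learning_rate=1e-3, steps=1000))
```

The rate starts at 1e-3, passes 5e-4 at the halfway point, and reaches 0 at step 1000.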
Quantization
```python
import jax
from eformer.quantization import Array8B, ArrayNF4

# Quantize an array to 8-bit
qarray = Array8B(jax.random.normal(jax.random.key(0), (256, 64), "f2"))

# Quantize an array to NF4
n4array = ArrayNF4(jax.random.normal(jax.random.key(0), (256, 64), "f2"), 64)
```
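NF4 (4-bit NormalFloat, introduced with QLoRA) maps each normalized weight to the nearest of 16 fixed levels derived from the quantiles of a normal distribution. The sketch below assumes `ArrayNF4` uses blockwise absmax scaling and that the `64` in the call above is the block size; both are assumptions, and the helpers are hypothetical. The levels are the bitsandbytes NF4 table rounded to four decimals.

```python
# NF4 codebook: 16 levels in [-1, 1] (bitsandbytes table, rounded to 4 decimals).
NF4_LEVELS = [
    -1.0, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911, 0.0,
    0.0796, 0.1609, 0.2461, 0.3379, 0.4407, 0.5626, 0.7230, 1.0,
]


def quantize_nf4(values, block_size=64):
    """Return 4-bit codes (0..15) and one absmax scale per block (sketch)."""
    codes, scales = [], []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        absmax = max(abs(v) for v in block) or 1.0  # avoid dividing by zero
        scales.append(absmax)
        for v in block:
            x = v / absmax  # normalize into [-1, 1]
            # Pick the index of the nearest codebook level.
            codes.append(min(range(16), key=lambda k: abs(NF4_LEVELS[k] - x)))
    return codes, scales


def dequantize_nf4(codes, scales, block_size=64):
    """Map each code back to its level and rescale by the block's absmax."""
    return [NF4_LEVELS[c] * scales[i // block_size] for i, c in enumerate(codes)]


data = [((i * 53) % 97 - 48) / 12.0 for i in range(128)]
codes, scales = quantize_nf4(data)
restored = dequantize_nf4(codes, scales)
```

Two 4-bit codes fit in one byte, so storage drops to roughly half a byte per weight plus one scale per block.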
Advanced Mixed Precision Configuration
```python
import jax.numpy as jnp
from eformer.mpric import LossScaleConfig, Policy, PrecisionHandler

# Create a custom precision policy: f32 params, bf16 compute, f32 outputs
policy = Policy(
    param_dtype=jnp.float32,
    compute_dtype=jnp.bfloat16,
    output_dtype=jnp.float32,
)

# Configure loss scaling
loss_config = LossScaleConfig(
    initial_scale=2**15,
    growth_interval=2000,
    scale_factor=2,
    min_scale=1.0,
)

# Create handler with custom configuration
handler = PrecisionHandler(
    policy=policy,
    use_dynamic_scale=True,
    loss_scale_config=loss_config,
)
```
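The `LossScaleConfig` fields correspond to the usual dynamic loss-scaling loop: multiply the loss by the current scale before differentiation, divide the gradients by it afterwards, and adjust the scale depending on whether the gradients overflowed. The class below is a hypothetical sketch of that update rule (the same idea as PyTorch's `GradScaler`), assuming `scale_factor` is used both to grow and to shrink the scale; eformer's exact rule may differ.

```python
from dataclasses import dataclass


@dataclass
class DynamicLossScale:
    """Hypothetical loss-scale state; field names mirror LossScaleConfig above."""

    scale: float = 2.0**15
    growth_interval: int = 2000
    scale_factor: float = 2.0
    min_scale: float = 1.0
    good_steps: int = 0  # consecutive steps with finite gradients

    def update(self, grads_finite: bool) -> None:
        if grads_finite:
            self.good_steps += 1
            if self.good_steps >= self.growth_interval:
                # A long run of clean steps: try a larger scale.
                self.scale *= self.scale_factor
                self.good_steps = 0
        else:
            # Overflow: shrink the scale (never below min_scale), restart the count.
            self.scale = max(self.scale / self.scale_factor, self.min_scale)
            self.good_steps = 0


state = DynamicLossScale()
# Per training step (sketch): scaled_loss = loss * state.scale; differentiate;
# divide the gradients by state.scale; skip the optimizer update if any gradient
# is inf/nan; then call state.update(grads_finite).
```

Growing slowly and shrinking immediately keeps the scale as large as possible without letting float16 or float8 gradients overflow for long.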
Contributing
We welcome contributions! Please read our Contributing Guidelines to get started.
License
This project is licensed under the Apache License 2.0. See the LICENSE file for details.
Download files
File details
Details for the file eformer-0.0.6.tar.gz.
File metadata
- Download URL: eformer-0.0.6.tar.gz
- Upload date:
- Size: 60.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/2.0.1 CPython/3.10.12 Linux/6.8.0-52-generic
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | a14c132e869e11350f1057d5e71bcd78c92fa144293f961859360a61b721e5fc |
| MD5 | 57e84b0426f1b2fdd4b630eb40091198 |
| BLAKE2b-256 | c7dd442d830a01c9dc94412170c8855630d59519786d2eea16ded53f02a84b30 |
File details
Details for the file eformer-0.0.6-py3-none-any.whl.
File metadata
- Download URL: eformer-0.0.6-py3-none-any.whl
- Upload date:
- Size: 87.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/2.0.1 CPython/3.10.12 Linux/6.8.0-52-generic
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | e48a50ff26d38244f296da3641863ecc414a2fda2de1f102f37aeed3523e8f95 |
| MD5 | 1d4adbcc3ef39ba05fe95a03b07c4697 |
| BLAKE2b-256 | cbfa6f961da9407295b0a16e9548f36e76b6e72db44686d09a390e07e812cbb7 |