
eformer (EasyDel Former) is a utility library designed to simplify and enhance development in JAX.


eformer (EasyDel Former)


eformer (EasyDel Former) is a utility library designed to simplify and enhance the development of machine learning models using JAX. It provides a collection of tools for sharding, custom PyTrees, quantization, mixed precision training, and optimized operations, making it easier to build and scale models efficiently.

Features

  • Mixed Precision Training (mpric): Advanced mixed precision utilities supporting float8, float16, and bfloat16 with dynamic loss scaling.
  • Sharding Utilities (escale): Tools for efficient sharding and distributed computation in JAX.
  • Custom PyTrees (jaximus): Enhanced utilities for creating custom PyTrees and ArrayValue objects, adapted from Equinox.
  • Custom Calling (callib): A tool for custom function calls and direct integration with Triton kernels in JAX.
  • Optimizer Factory: A flexible factory for creating and configuring optimizers like AdamW, Adafactor, Lion, and RMSProp.
  • Custom Operations and Kernels:
    • Flash Attention 2 for GPUs/TPUs (via Triton and Pallas).
    • 8-bit and NF4 quantization for memory-efficient models.
    • More operations to be added.
  • Quantization Support: Tools for 8-bit and NF4 quantization, enabling memory-efficient model deployment.
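The README does not show escale's own API, so the sketch below uses plain `jax.sharding` primitives to illustrate the kind of device mesh and partitioning that sharding helpers such as escale typically manage. The mesh axis names `"dp"` and `"tp"` and the helper `row_sums` are illustrative choices, not eformer names.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange every visible device into a 2-D mesh: a "dp" (data) axis and a
# "tp" (tensor) axis of size 1, so this also runs on a single CPU device.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("dp", "tp"))

# Batch size is a multiple of the device count so the "dp" split is even.
batch = jnp.ones((devices.shape[0] * 4, 16), dtype=jnp.float32)

# Shard rows across the "dp" axis and replicate the columns.
sharding = NamedSharding(mesh, P("dp", None))
batch = jax.device_put(batch, sharding)


@jax.jit
def row_sums(x):
    # jit follows the input sharding; XLA inserts collectives only if needed.
    return x.sum(axis=1)


out = row_sums(batch)
print(out.shape, batch.sharding.spec)
```

On a multi-device host the same code splits the batch rows across devices without any change to `row_sums`.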

Installation

You can install eformer via pip:

pip install eformer

Quick Start

Mixed Precision Handler with mpric

from eformer.mpric import PrecisionHandler

# Create a handler with float8 compute precision
handler = PrecisionHandler(
    policy="p=f32,c=f8_e4m3,o=f32",  # params in f32, compute in float8, output in f32
    use_dynamic_scale=True
)
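The policy string packs three dtypes: parameter storage (`p`), compute (`c`), and output (`o`). As a minimal sketch of what such a policy does, assuming the short dtype names map to JAX dtypes as below, the hypothetical helpers `parse_policy` and `apply_policy` (not part of eformer) parse the string and wrap a function so it computes in the compute dtype and returns in the output dtype. The demo uses `bf16` for compute because float8 matmuls are not supported on every backend.

```python
import jax
import jax.numpy as jnp

# Hypothetical mapping from short names to JAX dtypes; eformer's own parser
# may accept more (or different) names.
_DTYPES = {
    "f32": jnp.float32,
    "f16": jnp.float16,
    "bf16": jnp.bfloat16,
    "f8_e4m3": jnp.float8_e4m3fn,
    "f8_e5m2": jnp.float8_e5m2,
}


def parse_policy(spec: str) -> dict:
    """Parse 'p=f32,c=bf16,o=f32' into {'p': dtype, 'c': dtype, 'o': dtype}."""
    out = {}
    for item in spec.split(","):
        key, name = item.strip().split("=")
        out[key] = _DTYPES[name]
    return out


def apply_policy(fn, params, x, policy):
    # Params are stored in policy["p"]; cast params and inputs to the compute
    # dtype, run the function, then cast the result to the output dtype.
    cast = lambda t, d: jax.tree_util.tree_map(lambda a: a.astype(d), t)
    y = fn(cast(params, policy["c"]), cast(x, policy["c"]))
    return cast(y, policy["o"])


policy = parse_policy("p=f32,c=bf16,o=f32")
params = {"w": jnp.ones((4, 3), policy["p"])}
x = jnp.ones((2, 4), jnp.float32)
y = apply_policy(lambda p, v: v @ p["w"], params, x, policy)
print(y.dtype)  # float32
```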

Customizing Arrays With ArrayValue

import jax

from eformer.jaximus import ArrayValue, implicit
from eformer.ops.quantization.quantization_functions import (
    dequantize_row_q8_0,
    quantize_row_q8_0,
)

array = jax.random.normal(jax.random.key(0), (256, 64), "f2")  # float16 input


class Array8B(ArrayValue):
    # Quantized storage produced by quantize_row_q8_0
    scale: jax.Array
    weight: jax.Array

    def __init__(self, array: jax.Array):
        self.weight, self.scale = quantize_row_q8_0(array)

    def materialize(self):
        # Rebuild a dense array when an operation needs one
        return dequantize_row_q8_0(self.weight, self.scale)


qarray = Array8B(array)


@jax.jit
@implicit  # lets this function accept ArrayValue inputs such as Array8B
def sqrt(x):
    return jax.numpy.sqrt(x)


print(sqrt(qarray))
print(qarray)
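The README does not spell out what `quantize_row_q8_0` computes. Assuming it follows the GGML "Q8_0" scheme its name suggests (blocks of 32 values, one absmax scale per block, values rounded to int8), the round trip can be sketched in plain JAX. `quantize_q8_0`, `dequantize_q8_0`, and `BLOCK` are hypothetical names for illustration, not eformer functions.

```python
import jax
import jax.numpy as jnp

BLOCK = 32  # assumed block size, following the GGML Q8_0 layout


def quantize_q8_0(x):
    """Blockwise absmax int8 quantization along the last axis (a sketch)."""
    *lead, n = x.shape
    blocks = x.astype(jnp.float32).reshape(*lead, n // BLOCK, BLOCK)
    amax = jnp.max(jnp.abs(blocks), axis=-1, keepdims=True)
    scale = amax / 127.0
    # Avoid dividing by zero for all-zero blocks.
    safe = jnp.where(scale == 0, 1.0, scale)
    q = jnp.clip(jnp.round(blocks / safe), -127, 127).astype(jnp.int8)
    return q, scale


def dequantize_q8_0(q, scale, dtype=jnp.float16):
    *lead, nb, b = q.shape
    return (q.astype(jnp.float32) * scale).reshape(*lead, nb * b).astype(dtype)


x = jax.random.normal(jax.random.key(0), (256, 64), jnp.float16)
q, scale = quantize_q8_0(x)
x_hat = dequantize_q8_0(q, scale)
# Per-element error is bounded by half a quantization step (scale / 2),
# plus float16 rounding on the way back.
print(q.dtype, q.shape, float(jnp.max(jnp.abs(x_hat - x))))
```

Storing one float scale per 32 int8 values is what makes the container roughly half the size of the float16 original.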

Optimizer Factory

from eformer.optimizers import OptimizerFactory, SchedulerConfig, AdamWConfig

# Create an AdamW optimizer with a cosine scheduler
scheduler_config = SchedulerConfig(scheduler_type="cosine", learning_rate=1e-3, steps=1000)
optimizer, scheduler = OptimizerFactory.create("adamw", scheduler_config, AdamWConfig())

Quantization

import jax

from eformer.quantization import Array8B, ArrayNF4

# Quantize an array to 8-bit
qarray = Array8B(jax.random.normal(jax.random.key(0), (256, 64), "f2"))

# Quantize an array to NF4
n4array = ArrayNF4(jax.random.normal(jax.random.key(0), (256, 64), "f2"), 64)
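NF4 ("4-bit NormalFloat", from QLoRA) snaps each value to one of 16 fixed levels chosen for normally distributed weights, after scaling each block by its absmax. The sketch below assumes `ArrayNF4` uses the standard QLoRA/bitsandbytes code book and that its second argument (64) is the block size; `quantize_nf4` and `dequantize_nf4` are hypothetical helpers, not eformer functions.

```python
import jax
import jax.numpy as jnp

# The 16 NF4 code values used by QLoRA / bitsandbytes, normalized to [-1, 1].
NF4 = jnp.array([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0,
], dtype=jnp.float32)


def quantize_nf4(x, block_size=64):
    """Blockwise absmax scaling, then snap each value to the nearest NF4 code."""
    blocks = x.astype(jnp.float32).reshape(-1, block_size)
    absmax = jnp.max(jnp.abs(blocks), axis=1, keepdims=True)
    normed = blocks / jnp.where(absmax == 0, 1.0, absmax)
    # Index of the nearest code (0..15); a real kernel would pack two per byte.
    idx = jnp.argmin(jnp.abs(normed[..., None] - NF4), axis=-1).astype(jnp.uint8)
    return idx, absmax


def dequantize_nf4(idx, absmax, shape, dtype=jnp.float16):
    return (NF4[idx] * absmax).reshape(shape).astype(dtype)


x = jax.random.normal(jax.random.key(0), (256, 64), jnp.float16)
idx, absmax = quantize_nf4(x, block_size=64)
x_hat = dequantize_nf4(idx, absmax, x.shape)
print(idx.shape, int(idx.max()), float(jnp.mean(jnp.abs(x_hat - x))))
```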

Advanced Mixed Precision Configuration

import jax.numpy as jnp

from eformer.mpric import LossScaleConfig, Policy, PrecisionHandler

# Create a custom precision policy
policy = Policy(
    param_dtype=jnp.float32,
    compute_dtype=jnp.bfloat16,
    output_dtype=jnp.float32
)

# Configure loss scaling
loss_config = LossScaleConfig(
    initial_scale=2**15,
    growth_interval=2000,
    scale_factor=2,
    min_scale=1.0
)

# Create handler with custom configuration
handler = PrecisionHandler(
    policy=policy,
    use_dynamic_scale=True,
    loss_scale_config=loss_config
)
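The `LossScaleConfig` fields suggest the usual dynamic loss scaling rule: halve the scale when gradients overflow, grow it by `scale_factor` after `growth_interval` consecutive finite steps, and never drop below `min_scale`. Assuming eformer's handler follows that common scheme, the update can be sketched as below; `ScaleState`, `update_scale`, and `all_finite` are hypothetical names that only mirror the config fields.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ScaleState(NamedTuple):
    scale: jnp.ndarray       # current loss scale
    good_steps: jnp.ndarray  # consecutive steps with finite gradients


def update_scale(state, grads_finite, growth_interval=2000, scale_factor=2.0, min_scale=1.0):
    """Shrink on overflow; grow after `growth_interval` clean steps."""
    grown = state.good_steps + 1 >= growth_interval
    new_scale = jnp.where(
        grads_finite,
        jnp.where(grown, state.scale * scale_factor, state.scale),
        jnp.maximum(state.scale / scale_factor, min_scale),
    )
    # Reset the counter on overflow or right after growing.
    new_good = jnp.where(grads_finite & ~grown, state.good_steps + 1, 0)
    return ScaleState(new_scale, new_good)


def all_finite(tree):
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.all(jnp.array([jnp.all(jnp.isfinite(g)) for g in leaves]))


state = ScaleState(jnp.float32(2.0**15), jnp.int32(0))
# In training, gradients are divided by `state.scale` before the optimizer
# step; here we only feed in one overflowing step and one clean step.
state = update_scale(state, all_finite({"w": jnp.array([jnp.inf])}))
print(state.scale)  # 16384.0
state = update_scale(state, all_finite({"w": jnp.array([1.0])}))
print(state.scale, state.good_steps)  # 16384.0 1
```

Skipping the optimizer update on overflowing steps (not shown) is what keeps a too-large scale from corrupting the parameters.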

Contributing

We welcome contributions! Please read our Contributing Guidelines to get started.

License

This project is licensed under the Apache License 2.0. See the LICENSE file for details.


Download files


Source Distribution

eformer-0.0.18.1.tar.gz (80.8 kB)


Built Distribution


eformer-0.0.18.1-py3-none-any.whl (110.0 kB)


File details

Details for the file eformer-0.0.18.1.tar.gz.

File metadata

  • Download URL: eformer-0.0.18.1.tar.gz
  • Upload date:
  • Size: 80.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.4 CPython/3.10.12 Linux/5.15.167.4-microsoft-standard-WSL2

File hashes

Hashes for eformer-0.0.18.1.tar.gz
Algorithm Hash digest
SHA256 ffc4107d3b99e3a8383ab3b1793f0d4359c60ed0c36f09e7216ef65156776fad
MD5 30f1a2d2f7052f8167bb47f0ee7ca50c
BLAKE2b-256 6ae4c81f7449f1a141ee1aa7529eb5015f8d616bfedb13a07e480ddc255188db


File details

Details for the file eformer-0.0.18.1-py3-none-any.whl.

File metadata

  • Download URL: eformer-0.0.18.1-py3-none-any.whl
  • Upload date:
  • Size: 110.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.4 CPython/3.10.12 Linux/5.15.167.4-microsoft-standard-WSL2

File hashes

Hashes for eformer-0.0.18.1-py3-none-any.whl
Algorithm Hash digest
SHA256 91d361f895a9b65e0774433305b22f4d80831b35130329de5f53083de0163a37
MD5 35002ed021c4e3dd4333f9623bd18852
BLAKE2b-256 93238fe651f408654f76a6d09db19649c77be09d54b0b3b8b70e12c3dcf88a05

