
eformer (EasyDel Former) is a utility library designed to simplify and enhance development in JAX.


eformer (EasyDel Former)


eformer (EasyDel Former) is a utility library designed to simplify and enhance the development of machine learning models using JAX. It provides a comprehensive collection of tools for distributed computing, custom data structures, numerical optimization, and high-performance operations. eformer aims to make it easier to build, scale, and optimize models while taking full advantage of JAX's high-performance computing capabilities.

Project Structure Overview

The library is organized into several core modules:

  • aparser: Advanced argument parsing utilities with dataclass integration
  • callib: Custom function calling and Triton kernel integration
  • common_types: Shared type definitions and sharding constants
  • escale: Distributed sharding and parallelism utilities
  • executor: Execution management and hardware-specific optimizations
  • jaximus: Custom PyTree implementations and structured array utilities
  • mpric: Mixed precision training and dynamic scaling infrastructure
  • ops: Optimized operations including Flash Attention and quantization
  • optimizers: Flexible optimizer configuration and factory patterns
  • pytree: Enhanced tree manipulation and transformation utilities

Key Features

1. Mixed Precision Training (mpric)

Advanced mixed precision utilities supporting float8, float16, and bfloat16 with dynamic loss scaling, enabling faster training and reduced memory footprint.
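Dynamic loss scaling is the part of mixed precision training that most benefits from a concrete picture. The sketch below is written in plain JAX and is not mpric's API: the helper names `scaled_grad` and `update_scale` are hypothetical, and the growth interval of 2000 steps and factor of 2 are assumed, commonly used defaults. The loss is multiplied by a scale before differentiation so that small gradients do not underflow in float16, the gradients are divided by the same scale afterwards, and the scale shrinks on overflow and grows after a run of finite steps.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Toy regression loss whose forward (and therefore backward) matmul runs in float16.
    pred = (x.astype(jnp.float16) @ params.astype(jnp.float16)).astype(jnp.float32)
    return jnp.mean((pred - y) ** 2)


def scaled_grad(params, x, y, scale):
    # Scale the loss up so small float16 gradients survive the backward pass,
    # then divide the gradients by the same factor to restore their true magnitude.
    grads = jax.grad(lambda p: loss_fn(p, x, y) * scale)(params)
    return grads / scale


def update_scale(scale, grads, good_steps, growth_interval=2000, factor=2.0):
    # Assumed policy: halve the scale and reset the counter when any gradient is
    # non-finite; multiply it by `factor` after `growth_interval` finite steps in a row.
    finite = jnp.all(jnp.isfinite(grads))
    new_good = jnp.where(finite, good_steps + 1, 0)
    grow = new_good >= growth_interval
    new_scale = jnp.where(finite, jnp.where(grow, scale * factor, scale), scale / factor)
    new_good = jnp.where(grow, 0, new_good)
    return new_scale, new_good, finite
```

In a training loop, the optimizer step would be skipped whenever `finite` is false, so an overflowing step only lowers the scale instead of corrupting the parameters.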

2. Distributed Sharding (escale)

Tools for efficient sharding and distributed computation in JAX, allowing you to scale your models across multiple devices with various sharding strategies:

  • Data Parallelism (DP)
  • Fully Sharded Data Parallel (FSDP)
  • Tensor Parallelism (TP)
  • Expert Parallelism (EP)
  • Sequence Parallelism (SP)
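All of these strategies come down to deciding which array dimension maps onto which axis of a device mesh. The following sketch uses JAX's own `jax.sharding` primitives rather than escale's API, so it only illustrates the underlying idea; the axis names `"dp"` and `"tp"` are illustrative choices. It places every device on the data-parallel axis, so on a single CPU both axes have size 1 and the code still runs; reshaping the mesh differently would give the tensor-parallel axis real width.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 2-D mesh: all devices on the data-parallel axis "dp", size 1 on "tp".
# (Assumes the batch size below is divisible by the number of devices.)
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("dp", "tp"))

# Data parallelism: split the batch dimension over "dp", replicate the features.
batch = jax.device_put(jnp.ones((8, 16)), NamedSharding(mesh, P("dp", None)))

# Tensor parallelism: split the weight's output dimension over "tp".
weight = jax.device_put(jnp.ones((16, 32)), NamedSharding(mesh, P(None, "tp")))


@jax.jit
def forward(x, w):
    # jit propagates the input shardings; XLA inserts any collectives it needs.
    return x @ w


out = forward(batch, weight)
```

FSDP would additionally shard the parameters along the data-parallel axis, and expert or sequence parallelism would map the expert or sequence dimension onto its own mesh axis in the same way.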

3. Custom PyTrees (jaximus)

Enhanced utilities for creating custom PyTrees and ArrayValue objects, adapted from Equinox, providing flexible data structures for your models.
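To show what "custom PyTree" means at the JAX level, here is a minimal sketch built on `jax.tree_util.register_pytree_node_class` rather than on jaximus itself. The `ScaledArray` class is hypothetical: it stores a value and a scale as two leaves, which lets `jax.jit` accept and return it like any built-in container.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ScaledArray:
    """Hypothetical container: a value plus a scale, stored as two PyTree leaves."""

    def __init__(self, value, scale):
        self.value = value
        self.scale = scale

    def tree_flatten(self):
        # Children are traced by JAX; the aux data (None here) is treated as static.
        return (self.value, self.scale), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    def materialize(self):
        # Reconstruct the dense array this object represents.
        return self.value * self.scale


@jax.jit
def double_value(x: ScaledArray) -> ScaledArray:
    # jit can consume and produce the object because JAX knows how to flatten it.
    return ScaledArray(x.value * 2, x.scale)
```

Libraries such as Equinox automate the flatten/unflatten boilerplate from field annotations, which is the style the `ArrayValue` example in Getting Started follows.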

4. Triton Integration (callib)

Custom function calling utilities with direct integration of Triton kernels in JAX, allowing you to optimize performance-critical operations.
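callib's own calling convention is not documented here, so the sketch below uses JAX's Pallas kernel language as a stand-in for illustration; Pallas kernels have historically been lowered through Triton on GPUs. It shows the general shape of a custom kernel callable from JAX: a kernel body operating on memory references, wrapped so it can be jitted like a normal function. `interpret=True` runs the kernel on CPU for testing and would be dropped on real accelerator hardware.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_kernel(x_ref, y_ref, o_ref):
    # Each *_ref points at a block of memory: read the inputs, write the output.
    o_ref[...] = x_ref[...] + y_ref[...]


@jax.jit
def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # CPU interpreter mode; remove when targeting a GPU/TPU
    )(x, y)
```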

5. Optimizer Factory (optimizers)

A flexible factory for creating and configuring optimizers like AdamW, Adafactor, Lion, and RMSProp, making it easy to experiment with different optimization strategies.

6. Optimized Operations (ops)

  • Flash Attention 2 implementation for GPUs/TPUs (via Triton and Pallas) for faster attention computations
  • 8-bit and NF4 quantization for efficient model deployment
  • Additional optimized operations under active development
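As an illustration of 8-bit quantization, the sketch below implements block-wise symmetric int8 quantization in plain JAX, in the spirit of the GGML-style Q8_0 format that the `quantize_row_q8_0` name in the Getting Started example suggests. It is not eformer's implementation: the function names are hypothetical, and it assumes the last dimension is divisible by the block size of 32. Each block of 32 values shares one scale equal to its absolute maximum divided by 127, so every value rounds to an int8 in [-127, 127].

```python
import jax.numpy as jnp


def quantize_q8_blocks(x, block_size=32):
    # Split the last axis into blocks; each block gets its own float scale.
    blocks = x.reshape(*x.shape[:-1], -1, block_size).astype(jnp.float32)
    absmax = jnp.max(jnp.abs(blocks), axis=-1, keepdims=True)
    scale = absmax / 127.0
    # Avoid dividing by zero for all-zero blocks (their scale stays 0).
    safe_scale = jnp.where(scale == 0, 1.0, scale)
    q = jnp.clip(jnp.round(blocks / safe_scale), -127, 127).astype(jnp.int8)
    return q, scale


def dequantize_q8_blocks(q, scale, shape):
    # Multiply each int8 value by its block's scale and restore the original shape.
    return (q.astype(jnp.float32) * scale).reshape(shape)
```

Because each value is rounded to the nearest multiple of its block's scale, the reconstruction error is at most half a scale step per element.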

API Documentation

For detailed API references and usage examples, see:

Installation

You can install eformer via pip:

pip install eformer

Getting Started

Mixed Precision Handler with mpric

from eformer.mpric import PrecisionHandler

# Create a handler with float8 compute precision
handler = PrecisionHandler(
    policy="p=f32,c=f8_e4m3,o=f32",  # params in f32, compute in float8, output in f32
    use_dynamic_scale=True
)
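The policy string packs three dtypes into one argument: `p` for parameters, `c` for computation, and `o` for outputs. The sketch below shows one way such a string could be parsed and applied; it is an illustration only, not how `PrecisionHandler` works internally. The abbreviation table (including mapping `f8_e4m3` to `jnp.float8_e4m3fn`) and the helpers `parse_policy` and `apply_policy` are assumptions.

```python
import jax.numpy as jnp

# Assumed mapping from policy abbreviations to JAX dtypes (illustrative only).
_DTYPES = {
    "f32": jnp.float32,
    "f16": jnp.float16,
    "bf16": jnp.bfloat16,
    "f8_e4m3": jnp.float8_e4m3fn,
    "f8_e5m2": jnp.float8_e5m2,
}


def parse_policy(policy: str) -> dict:
    # "p=f32,c=bf16,o=f32" -> {"p": float32, "c": bfloat16, "o": float32}
    out = {}
    for item in policy.split(","):
        key, value = item.split("=")
        out[key.strip()] = _DTYPES[value.strip()]
    return out


def apply_policy(fn, policy: str):
    dt = parse_policy(policy)

    def wrapped(params, x):
        # Cast inputs to the compute dtype, run `fn`, then cast to the output dtype.
        # dt["p"] would govern how parameters are stored; it is not used here.
        y = fn(params.astype(dt["c"]), x.astype(dt["c"]))
        return y.astype(dt["o"])

    return wrapped
```

For example, `apply_policy(lambda w, x: x @ w, "p=f32,c=bf16,o=f32")` returns a matmul that computes in bfloat16 but hands back float32 results.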

Custom PyTree Implementation

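import jax
import jax.numpy as jnp
from eformer.jaximus import ArrayValue, implicit
from eformer.ops.quantization.quantization_functions import dequantize_row_q8_0, quantize_row_q8_0

class Array8B(ArrayValue):
    # Quantized storage: 8-bit weights plus their scales
    scale: jax.Array
    weight: jax.Array

    def __init__(self, array: jax.Array):
        # Quantize the dense input once, at construction time
        self.weight, self.scale = quantize_row_q8_0(array)

    def materialize(self):
        # Rebuild a dense (approximate) array from the quantized representation
        return dequantize_row_q8_0(self.weight, self.scale)

array = jax.random.normal(jax.random.key(0), (256, 64), dtype=jnp.float16)
qarray = Array8B(array)
restored = qarray.materialize()  # dequantized approximation of `array`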

Contributing

We welcome contributions! Please read our Contributing Guidelines to get started.

License

This project is licensed under the Apache License 2.0. See the LICENSE file for details.


Download files


Source Distribution

eformer-0.0.61.tar.gz (190.5 kB)

Uploaded Source

Built Distribution


eformer-0.0.61-py3-none-any.whl (242.9 kB)

Uploaded Python 3

File details

Details for the file eformer-0.0.61.tar.gz.

File metadata

  • Download URL: eformer-0.0.61.tar.gz
  • Upload date:
  • Size: 190.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.8.0

File hashes

Hashes for eformer-0.0.61.tar.gz
Algorithm Hash digest
SHA256 256d9208b0b481a9fe7ddefe9901d6fba238a44407fdabb4df6864ee4525deee
MD5 6718b57f667b6bca8e6f2fdb1a586054
BLAKE2b-256 b523a77d3baca5c6075b10c50fbefe08a215052d6882c56f4d3a4d4def267843


File details

Details for the file eformer-0.0.61-py3-none-any.whl.

File metadata

  • Download URL: eformer-0.0.61-py3-none-any.whl
  • Upload date:
  • Size: 242.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.8.0

File hashes

Hashes for eformer-0.0.61-py3-none-any.whl
Algorithm Hash digest
SHA256 183db4d1f070a6fd8ff864cb60a6cfc05e66267e0ba3bc1e44da12ee6b8773a0
MD5 8ce625943ae2296e293d4f3f9ee497b3
BLAKE2b-256 f67df5f5f10249fca419054ba49be54b86b5a3fbc104c5e9b4bb639e29153586

