
Embark on a journey of paralleled/unparalleled computational prowess with eformer in JAX

Project description

eformer (EasyDel Former)


eformer (EasyDel Former) is a utility library designed to simplify and enhance the development of machine learning models using JAX. It provides a collection of tools for sharding, custom PyTrees, quantization, and optimized operations, making it easier to build and scale models efficiently.

Features

  • Sharding Utilities (escale): Tools for efficient sharding and distributed computation in JAX.
  • Custom PyTrees (jaximus): Enhanced utilities for creating custom PyTrees and ArrayValue objects, adapted from Equinox.
  • Custom Calling (callib): A tool for custom function calls and direct integration with Triton kernels in JAX.
  • Optimizer Factory: A flexible factory for creating and configuring optimizers like AdamW, Adafactor, Lion, and RMSProp.
  • Custom Operations and Kernels:
    • Flash Attention 2 for GPUs/TPUs (via Triton and Pallas).
    • 8-bit and NF4 quantization kernels for memory-efficient models.
  • Quantization Support: Tools for 8-bit and NF4 quantization, enabling memory-efficient model deployment.
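Flash Attention produces exactly the same result as ordinary scaled dot-product attention. The difference is that keys and values are processed in tiles with a running ("online") softmax, so the full query-by-key score matrix is never materialized. The sketch below illustrates that idea in plain JAX for a single head. It is not eformer's Triton or Pallas kernel, and `reference_attention` and `blockwise_attention` are hypothetical names used only to show that tiled accumulation matches the naive formula.

```python
import jax
import jax.numpy as jnp


def reference_attention(q, k, v):
    # q: (Lq, d), k and v: (Lk, d). Materializes the full (Lq, Lk) score matrix.
    scores = q @ k.T / jnp.sqrt(q.shape[-1])
    return jax.nn.softmax(scores, axis=-1) @ v


def blockwise_attention(q, k, v, block_size=16):
    # Hypothetical sketch of the online-softmax idea behind Flash Attention:
    # keys/values are consumed in blocks, so only (Lq, block_size) scores exist at once.
    scale = 1.0 / jnp.sqrt(q.shape[-1])
    lq = q.shape[0]
    m = jnp.full((lq, 1), -jnp.inf)      # running row-wise max of the scores
    l = jnp.zeros((lq, 1))               # running softmax denominator
    acc = jnp.zeros((lq, v.shape[-1]))   # running unnormalized output
    for start in range(0, k.shape[0], block_size):
        kb = k[start:start + block_size]
        vb = v[start:start + block_size]
        s = (q @ kb.T) * scale
        m_new = jnp.maximum(m, s.max(axis=-1, keepdims=True))
        p = jnp.exp(s - m_new)
        # Rescale what was accumulated under the old max to the new max.
        correction = jnp.exp(m - m_new)
        l = l * correction + p.sum(axis=-1, keepdims=True)
        acc = acc * correction + p @ vb
        m = m_new
    return acc / l
```

Called on the same random `q`, `k`, `v`, both functions should agree up to float32 rounding; a production kernel additionally tiles the queries and fuses these steps in on-chip memory.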

Installation

You can install eformer via pip:

pip install eformer

Quick Start

Customizing Arrays With ArrayValue

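import jax
import jax.numpy as jnp

from eformer.jaximus import ArrayValue, implicit
from eformer.ops.quantization.quantization_functions import (
    dequantize_row_q8_0,
    quantize_row_q8_0,
)

# A float16 array to quantize.
array = jax.random.normal(jax.random.key(0), (256, 64), jnp.float16)


class Array8B(ArrayValue):
    # Fields holding the quantized representation.
    scale: jax.Array
    weight: jax.Array

    def __init__(self, array: jax.Array):
        # Quantize once at construction time and keep the weights and their scales.
        self.weight, self.scale = quantize_row_q8_0(array)

    def materialize(self):
        # Rebuild a dense array from the quantized data.
        return dequantize_row_q8_0(self.weight, self.scale)


qarray = Array8B(array)


@jax.jit
@implicit
def sqrt(x):
    # `implicit` lets ordinary jax.numpy code accept ArrayValue inputs such as Array8B.
    # Negative entries of the normal sample yield NaN, as with any sqrt.
    return jnp.sqrt(x)


print(sqrt(qarray))
print(qarray)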
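The names `quantize_row_q8_0` and `dequantize_row_q8_0` suggest the GGML "Q8_0" layout: values are grouped into blocks of 32, and each block stores int8 codes plus one scale equal to the block's absolute maximum divided by 127. Assuming that convention (not confirmed here for eformer), the round trip can be sketched in plain JAX as follows; `quantize_q8_0_sketch` and `dequantize_q8_0_sketch` are hypothetical helpers, not eformer functions.

```python
import jax
import jax.numpy as jnp

BLOCK = 32  # values per block, as in GGML's Q8_0 layout (assumed here)


def quantize_q8_0_sketch(x):
    # Hypothetical helper. Assumes x.size is a multiple of BLOCK.
    blocks = x.astype(jnp.float32).reshape(-1, BLOCK)
    amax = jnp.max(jnp.abs(blocks), axis=-1, keepdims=True)
    scale = amax / 127.0
    # Avoid dividing by zero for all-zero blocks (their codes become 0).
    safe = jnp.where(scale == 0, 1.0, scale)
    q = jnp.round(blocks / safe).astype(jnp.int8)
    return q, scale


def dequantize_q8_0_sketch(q, scale, shape, dtype=jnp.float16):
    # Hypothetical helper: code * per-block scale, reshaped back to the input shape.
    return (q.astype(jnp.float32) * scale).reshape(shape).astype(dtype)
```

In this scheme every reconstructed value lies within half a step (`scale / 2`) of the original, and a single large outlier in a block enlarges that step for all 32 values in it.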

Optimizer Factory

from eformer.optimizers import OptimizerFactory, SchedulerConfig, AdamWConfig

# Create an AdamW optimizer with a cosine scheduler
scheduler_config = SchedulerConfig(scheduler_type="cosine", learning_rate=1e-3, steps=1000)
optimizer, scheduler = OptimizerFactory.create("adamw", scheduler_config, AdamWConfig())
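For intuition about what an AdamW optimizer driven by a cosine schedule computes at each step, here is a self-contained sketch in plain JAX. It does not use or mirror `OptimizerFactory`; the hyperparameters (`b1=0.9`, `b2=0.999`, `eps=1e-8`, `weight_decay=1e-2`) and the decay-to-zero cosine shape are assumptions, not eformer's `AdamWConfig` or `SchedulerConfig` defaults, and `cosine_lr` and `adamw_update` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def cosine_lr(step, base_lr=1e-3, total_steps=1000):
    # Assumed schedule: decays from base_lr to 0 over total_steps, then stays at 0.
    t = jnp.minimum(step, total_steps) / total_steps
    return base_lr * 0.5 * (1.0 + jnp.cos(jnp.pi * t))


def adamw_update(params, grads, m, v, step, b1=0.9, b2=0.999, eps=1e-8, weight_decay=1e-2):
    # One AdamW step over arbitrary PyTrees; `step` counts from 1 for bias correction.
    lr = cosine_lr(step)
    m = jax.tree_util.tree_map(lambda m_, g: b1 * m_ + (1 - b1) * g, m, grads)
    v = jax.tree_util.tree_map(lambda v_, g: b2 * v_ + (1 - b2) * g * g, v, grads)

    def apply(p, m_, v_):
        m_hat = m_ / (1 - b1 ** step)
        v_hat = v_ / (1 - b2 ** step)
        # Decoupled weight decay: added to the update, not folded into the gradient.
        return p - lr * (m_hat / (jnp.sqrt(v_hat) + eps) + weight_decay * p)

    params = jax.tree_util.tree_map(apply, params, m, v)
    return params, m, v
```

Applying weight decay directly to the parameters, rather than adding an L2 term to the gradient before the moment estimates, is what distinguishes AdamW from Adam with L2 regularization.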

Quantization

import jax

from eformer.quantization import Array8B, ArrayNF4

# Quantize an array to 8-bit
qarray = Array8B(jax.random.normal(jax.random.key(0), (256, 64), "f2"))

# Quantize an array to NF4 (the second argument, 64, appears to be the block size)
n4array = ArrayNF4(jax.random.normal(jax.random.key(0), (256, 64), "f2"), 64)
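NF4 ("4-bit NormalFloat", introduced with QLoRA) normalizes each block by its absolute maximum and then snaps every value to the nearest of 16 fixed levels derived from normal-distribution quantiles. The sketch below implements that idea in plain JAX using the level table published with bitsandbytes. It assumes the `64` passed to `ArrayNF4` above is a block size, which is a guess; `quantize_nf4_sketch` and `dequantize_nf4_sketch` are hypothetical, and for clarity they store one code per `uint8` instead of packing two codes per byte.

```python
import jax
import jax.numpy as jnp

# The 16 NF4 levels (normal-distribution quantiles scaled to [-1, 1]), as published with bitsandbytes.
NF4_LEVELS = jnp.array([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224,
    0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0,
], dtype=jnp.float32)


def quantize_nf4_sketch(x, block_size=64):
    # Hypothetical helper. Assumes x.size is a multiple of block_size.
    blocks = x.astype(jnp.float32).reshape(-1, block_size)
    absmax = jnp.max(jnp.abs(blocks), axis=-1, keepdims=True)
    safe = jnp.where(absmax == 0, 1.0, absmax)
    normed = blocks / safe  # every value now lies in [-1, 1]
    # Index (0..15) of the nearest NF4 level for every value.
    codes = jnp.argmin(jnp.abs(normed[..., None] - NF4_LEVELS), axis=-1).astype(jnp.uint8)
    return codes, absmax


def dequantize_nf4_sketch(codes, absmax, shape):
    # Hypothetical helper: look up each level and undo the per-block normalization.
    return (NF4_LEVELS[codes.astype(jnp.int32)] * absmax).reshape(shape)
```

Because each value snaps to its nearest level, the reconstruction error in a block is at most half the widest gap between adjacent levels, multiplied by that block's absolute maximum.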

Contributing

We welcome contributions! Please read our Contributing Guidelines to get started.

License

This project is licensed under the Apache License 2.0. See the LICENSE file for details.


This version

0.0.1

Download files


Source Distribution

eformer-0.0.1.tar.gz (56.3 kB)

Built Distribution

eformer-0.0.1-py3-none-any.whl (80.2 kB)

File details

Details for the file eformer-0.0.1.tar.gz.

File metadata

  • Download URL: eformer-0.0.1.tar.gz
  • Upload date:
  • Size: 56.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.5 CPython/3.10.12 Linux/6.8.0-52-generic

File hashes

Hashes for eformer-0.0.1.tar.gz
SHA256: cd1f3d74cc1a4103506834d94bd54e36947523b5c38aec38349f652a6143ddb2
MD5: 3ab1cf19340348fafe0bae1a50936136
BLAKE2b-256: 4d73bf33f7610bb4dbb62dc1ca8d2147dfc62785be490d0a7992b70f811a17e0


File details

Details for the file eformer-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: eformer-0.0.1-py3-none-any.whl
  • Upload date:
  • Size: 80.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.5 CPython/3.10.12 Linux/6.8.0-52-generic

File hashes

Hashes for eformer-0.0.1-py3-none-any.whl
SHA256: 0cfdfe597d582331e7fdb0d5ba21613c1d75279e6e66e460a1102bdb9c5abe38
MD5: 4bc4009262348f1d8b6281dc376c5d6d
BLAKE2b-256: 24235d93ae4299ea97f1bff52290576df05a1eaafd9ec8b0ea9112e836ce3730
