Embark on a journey of unparalleled computational prowess with FJFormer in JAX

Project description

FJFormer

FJFormer is a powerful and flexible JAX-based package designed to accelerate and simplify machine learning and deep learning workflows. It provides a comprehensive suite of tools and utilities for efficient model development, training, and deployment.

Features

1. JAX Sharding Utils

Leverage the power of distributed computing and model parallelism with our advanced JAX sharding utilities. These tools let you split and manage large models across multiple devices, improving performance and making it possible to train models that would not fit on a single device.
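
For orientation, the following minimal sketch uses plain jax.sharding, the standard JAX API that utilities like these build on (generic JAX code, not FJFormer-specific):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1-D mesh over all available devices
mesh = Mesh(np.array(jax.devices()), axis_names=("dp",))

# Shard a large matrix row-wise across the "dp" mesh axis
x = jnp.ones((8192, 8192))
x_sharded = jax.device_put(x, NamedSharding(mesh, PartitionSpec("dp", None)))
print(x_sharded.sharding)  # shows how rows are split across devices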

2. Custom Pallas / Triton Operation Kernels

Boost your model's performance with our optimized kernels for specific operations. These custom-built kernels, implemented using Pallas and Triton, provide significant speedups for common bottleneck operations in deep learning models.
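
For a flavor of what such a kernel looks like, here is a minimal element-wise addition written with Pallas (a generic sketch, not one of FJFormer's shipped kernels):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Each ref is a view into device memory; writes go to the output buffer
    o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add(x, y):
    # Runs on TPU/GPU; pass interpret=True to pallas_call to test on CPU
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, y)

x = jnp.arange(8, dtype=jnp.float32)
print(add(x, x))  # [ 0.  2.  4.  6.  8. 10. 12. 14.]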

3. Pre-built Optimizers

Jump-start your training with our collection of ready-to-use, efficiently implemented optimization algorithms (a hedged usage sketch follows the list):

  • AdamW: An Adam variant with decoupled weight decay.
  • Adafactor: A memory-efficient adaptive optimization algorithm.
  • Lion: A recently proposed optimizer combining the benefits of momentum and adaptive methods.
  • RMSprop: An adaptive learning-rate optimization algorithm.
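
The sketch below shows the common pattern using optax's AdamW; FJFormer's pre-built optimizers follow the same GradientTransformation interface (FJFormer's own entry points are not shown here):

import jax
import jax.numpy as jnp
import optax

# AdamW: Adam with decoupled weight decay
optimizer = optax.adamw(learning_rate=1e-3, weight_decay=0.01)

params = {"w": jnp.ones((4, 4))}
opt_state = optimizer.init(params)

def loss_fn(p):
    return jnp.sum(p["w"] ** 2)

grads = jax.grad(loss_fn)(params)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)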

4. Utility Functions

A rich set of utility functions to streamline your workflow (see the sketch after this list), including:

  • Various loss functions (e.g., cross-entropy)
  • Metrics calculation
  • Data preprocessing tools
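
As an example, a cross-entropy loss over integer labels takes only a few lines of JAX (a generic sketch; FJFormer's own helpers may differ in name and signature):

import jax
import jax.numpy as jnp

def cross_entropy(logits, labels):
    # logits: (batch, num_classes); labels: (batch,) integer class ids
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.mean(jnp.take_along_axis(log_probs, labels[:, None], axis=-1))

logits = jnp.array([[2.0, 0.5, -1.0]])
labels = jnp.array([0])
print(cross_entropy(logits, labels))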

5. ImplicitArray

Our innovative ImplicitArray class provides a powerful abstraction for representing and manipulating large arrays without instantiating them in memory (a conceptual sketch follows the list below). Benefits include:

  • Lazy evaluation for memory efficiency
  • Optimized array operations in JAX
  • Seamless integration with other FJFormer components
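
The sketch below illustrates the lazy-materialization idea behind this abstraction; the class and method names here are stand-ins, not FJFormer's actual API:

import dataclasses
import jax.numpy as jnp

@dataclasses.dataclass
class LazyZeros:
    # Hypothetical stand-in for an ImplicitArray subclass: it stores only
    # the recipe for an array, not the array's contents.
    shape: tuple

    def materialize(self):
        # The full array is instantiated only when explicitly requested
        return jnp.zeros(self.shape)

big = LazyZeros((4096, 4096))       # no device memory allocated yet
corner = big.materialize()[:2, :2]  # materialized on demand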

6. Custom Dtypes

  • Implement 4-bit quantization (NF4) effortlessly using our ArrayNF4 class, built on top of ImplicitArray. Reduce model size and increase inference speed without significant loss in accuracy, following the QLoRA paper.

  • Similar to ArrayNF4, our Array8Lt implementation offers 8-bit quantization via ImplicitArray, balancing model compression against precision. A conceptual sketch follows below.
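
To make the idea concrete, here is a toy blockwise 4-bit quantizer (a conceptual sketch of the technique only; real NF4 uses a 16-entry normal-distribution codebook rather than this uniform absmax scheme):

import jax.numpy as jnp

def quantize_4bit(w, block=64):
    # Toy blockwise absmax quantization to 4-bit integer levels (-8..7)
    w = w.reshape(-1, block)
    scale = jnp.max(jnp.abs(w), axis=1, keepdims=True) / 7.0
    q = jnp.clip(jnp.round(w / scale), -8, 7).astype(jnp.int8)
    return q, scale

def dequantize_4bit(q, scale, shape):
    return (q.astype(jnp.float32) * scale).reshape(shape)

w = jnp.linspace(-1.0, 1.0, 256).reshape(16, 16)
q, s = quantize_4bit(w)
w_hat = dequantize_4bit(q, s, w.shape)
print(jnp.max(jnp.abs(w - w_hat)))  # small per-block quantization error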

7. LoRA (Low-Rank Adaptation)

Efficiently fine-tune large language models with our LoRA implementation, leveraging ImplicitArray for optimal performance and memory usage.
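
The core idea in a minimal JAX sketch (not FJFormer's API): the frozen weight W is augmented with a trainable low-rank product B @ A, where the rank r is much smaller than the layer width:

import jax
import jax.numpy as jnp

d_in, d_out, r = 512, 512, 8
key_w, key_a = jax.random.split(jax.random.PRNGKey(0))
W = jax.random.normal(key_w, (d_in, d_out))      # frozen base weight
A = jax.random.normal(key_a, (r, d_out)) * 0.01  # trainable
B = jnp.zeros((d_in, r))                         # trainable; zero init makes the initial delta zero

def lora_linear(x, A, B):
    # Effective weight is W + B @ A, but the full sum is never materialized
    return x @ W + (x @ B) @ A

x = jnp.ones((1, d_in))
print(lora_linear(x, A, B).shape)  # (1, 512)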

8. JAX and Array Manipulation

A comprehensive set of tools and utilities for efficient array operations and manipulations in JAX, designed to complement and extend JAX's native capabilities.

9. Checkpoint Managers

Robust utilities for managing model checkpoints (a hedged sketch follows the list), including:

  • Efficient saving and loading of model states
  • Version control for checkpoints
  • Integration with distributed training workflows
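
A typical save/restore flow looks like the following sketch, written with flax.serialization as a stand-in (FJFormer's own checkpoint manager API is not shown here):

import jax.numpy as jnp
from flax import serialization

state = {"params": {"w": jnp.ones((4, 4))}, "step": 1000}

# Save: serialize the state pytree to bytes and write it out
with open("checkpoint.msgpack", "wb") as f:
    f.write(serialization.to_bytes(state))

# Load: restore into a template pytree with the same structure
with open("checkpoint.msgpack", "rb") as f:
    restored = serialization.from_bytes(state, f.read())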

Installation

You can install FJFormer using pip:

pip install fjformer

For the latest development version, you can install directly from GitHub:

pip install git+https://github.com/erfanzar/FJFormer.git

Documentation

For detailed documentation, including API references, please visit:

https://fjformer.readthedocs.org

License

FJFormer is released under the Apache License 2.0. See the LICENSE file for more details.


Download files

Download the file for your platform.

Source Distribution

fjformer-0.0.86.tar.gz (122.5 kB)

Built Distribution

fjformer-0.0.86-py3-none-any.whl (171.6 kB)

File details

Details for the file fjformer-0.0.86.tar.gz.

File metadata

  • Download URL: fjformer-0.0.86.tar.gz
  • Upload date:
  • Size: 122.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.4 CPython/3.10.12 Linux/6.8.0-48-generic

File hashes

Hashes for fjformer-0.0.86.tar.gz:

  • SHA256: efef11476c7938255d4e4e4984a9bf9a9b4ead6ae16020a5ec2730f39e2e3a29
  • MD5: 5402cda96b571316e08a3f7035324154
  • BLAKE2b-256: 41d04bfce2bf44aab9cf66ba1ddfb9ea78644356ea56f04111edec88b570246e


File details

Details for the file fjformer-0.0.86-py3-none-any.whl.

File metadata

  • Download URL: fjformer-0.0.86-py3-none-any.whl
  • Upload date:
  • Size: 171.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.4 CPython/3.10.12 Linux/6.8.0-48-generic

File hashes

Hashes for fjformer-0.0.86-py3-none-any.whl:

  • SHA256: 7ec4b3b4298568f665d89cf8d4eaf637d19306545a77868dafc08ee65fa2c2bc
  • MD5: 52fe45a9fcfb2144453749bd4a84cf74
  • BLAKE2b-256: d8bc24cdac4c84b5db4a7d24b8cc0198cfd111a2d2c59c79d7567fd3d1bdd03e

