An extension for modular JAX code.

Project description

+-------------------------------------+
|       __              __      __    |
|      /_/\            / /\    /_/\   |
|     /_/  \          / /  \  /_/ /   |
|    /_/ /\ \        / / /\ \/_/ /    |
|   /_/ /\_\ \      / / /\_\ \/ /     |
|  /_/ /  \_\ \    / / /__\_\  /      |
|  \_\ \  /_/ /   / / ______   \      |
|   \_\ \/_/ /   / / /   /_/ /\ \     |
|    \_\ \/_/___/ / /   /_/ /\_\ \    |
|     \_\/_______/ /   /_/ /  \_\ \   |
|      \_\_\_\_\_\/    \_\/    \_\/   |
|                                     |
+-------------------------------------+

Github | Documentation

What is OJAX

OJAX is a small extension of JAX to facilitate modular coding.

As you may have already noticed, JAX's functional nature means it does not pair well with ordinary Python class structures. Instead, people tend to write closures or "functionals", i.e., functions that return JAX functions (e.g., the Stax neural network library in the JAX codebase). This approach is far from ideal for complex projects.
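For readers unfamiliar with the closure style mentioned above, here is a minimal sketch of a Stax-like dense layer in plain JAX. It is a simplified illustration, not the actual `jax.example_libraries.stax` API (Stax's real `init_fun` also takes and returns shapes), and the name `dense` is hypothetical. Note how the parameters live outside any object and must be threaded through every call by hand.

```python
import jax
import jax.numpy as jnp


def dense(out_dim: int):
    """Closure-style layer: returns (init_fun, apply_fun) instead of an object."""

    def init_fun(rng: jax.Array, input_dim: int):
        # Parameters are created here but stored by the caller, not by a class
        w = jax.random.normal(rng, (input_dim, out_dim)) * 0.01
        b = jnp.zeros(out_dim)
        return (w, b)

    def apply_fun(params, x: jax.Array) -> jax.Array:
        # Every call must be handed the parameters explicitly
        w, b = params
        return x @ w + b

    return init_fun, apply_fun


init_fun, apply_fun = dense(3)
params = init_fun(jax.random.PRNGKey(0), 4)
y = jax.jit(apply_fun)(params, jnp.ones((2, 4)))
print(y.shape)  # (2, 3)
```

With one layer this is manageable, but every additional layer adds another pair of functions and another parameter tuple to keep in sync.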

OJAX lets you write JAX code with classes again, with full JAX compatibility, maximum flexibility, and zero worries. Here is an example of a custom OJAX class whose instances can be passed directly to jax.jit:

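import jax
import ojax


class AddWith(ojax.OTree):
    value: float

    def __call__(self, t: jax.Array) -> jax.Array:
        # Add the stored constant to the input array
        return t + self.value


# The instance is callable, so it can be handed straight to jax.jit
add42_jitted = jax.jit(AddWith(42.0))
print(add42_jitted(jax.numpy.ones(1)))  # [43.]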
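To see what a JAX base class has to take care of, here is a sketch of a comparable callable class written in plain JAX, registered by hand as a pytree so an instance can be passed as an argument to a jitted function. The names `AddWithPlain` and `apply` are hypothetical, and this is not OJAX's implementation; we only assume that OJAX automates something comparable.

```python
import dataclasses

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class AddWithPlain:
    """Hand-written equivalent: a frozen dataclass registered as a pytree."""

    value: jax.Array  # stored as a pytree leaf, so it is traced under jit

    def __call__(self, t: jax.Array) -> jax.Array:
        return t + self.value

    def tree_flatten(self):
        # Children (traced leaves) first, then static auxiliary data
        return (self.value,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild the instance from its leaves
        return cls(*children)


@jax.jit
def apply(module: AddWithPlain, t: jax.Array) -> jax.Array:
    # The module is an ordinary pytree argument of the jitted function
    return module(t)


add42 = AddWithPlain(jnp.asarray(42.0))
print(apply(add42, jnp.ones(1)))  # [43.]
```

Even for a one-field class, the flatten/unflatten boilerplate is longer than the actual logic, which is the kind of code a base class is meant to remove.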

OJAX is a small library that takes less than an hour to learn, and it will save you countless hours in your JAX projects!

Why OJAX

"Library XXX already did something similar, why reinvent the wheel?"

The short answer is: because the wheel is rounder this time ;)

Motivated by deep learning applications, many JAX libraries already propose some kind of module system: Flax, Equinox, Haiku, Simple Pytree, Treeo / Treex, and PAX, to name a few.

However, none of them offers a perfect “JAX base class” that fulfills all of the desiderata below:

  • Simple to understand and use

  • Flexible custom classes for general JAX computation

  • Compatible with JAX and its functional paradigm

OJAX strives to define what a JAX base class should look like. It provides a natural way to structure custom JAX code and steers users away from common pitfalls.
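As one concrete example of such a pitfall (our own choice of illustration, not necessarily one the OJAX authors single out), mutating `self` inside a jitted method silently misbehaves, because Python side effects run only while JAX traces the function. The names `MutableCounter` and `pure_step` are hypothetical.

```python
import jax
import jax.numpy as jnp


class MutableCounter:
    """A plain Python class that mutates its own state (the pitfall)."""

    def __init__(self):
        self.calls = 0

    def step(self, x):
        self.calls += 1  # Python side effect: runs only while JAX traces
        return x + 1


c = MutableCounter()
step_jit = jax.jit(c.step)
for _ in range(3):
    step_jit(jnp.ones(1))
print(c.calls)  # 1, not 3: later calls reuse the compiled trace


@jax.jit
def pure_step(calls, x):
    """Functional alternative: state goes in as an argument and comes back out."""
    return calls + 1, x + 1


calls, x = jnp.asarray(0), jnp.ones(1)
for _ in range(3):
    calls, x = pure_step(calls, x)
print(int(calls))  # 3
```

The functional version works because the state is an explicit input and output of the jitted function, which is the discipline an immutable, pytree-based class encourages.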

P.S.: the name “OJAX” is a tip of the hat to OCaml, an awesome functional programming language.

How to code with OJAX

OJAX is easy to install; just follow the installation guide.

You can take a look at the quickstart section to get started, and there is also a simple example script using OJAX.

Of course, check out the OJAX API reference for exact definitions.


Download files

Download the file for your platform.

Source Distribution

ojax-4.0.1.tar.gz (16.1 kB)

Uploaded Source

Built Distribution

ojax-4.0.1-py3-none-any.whl (12.5 kB)

Uploaded Python 3

File details

Details for the file ojax-4.0.1.tar.gz.

File metadata

  • Download URL: ojax-4.0.1.tar.gz
  • Upload date:
  • Size: 16.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.10.16

File hashes

Hashes for ojax-4.0.1.tar.gz
Algorithm Hash digest
SHA256 7778d1493f96619de9f6ca5780fd6047dec8ea1f6241db53b3af4da10397d86b
MD5 6f811798995c91a1dc37347433070cb9
BLAKE2b-256 989d8338c77090e1c8a466d09e088fb8f2e9c59492e69c3271c6ae7ac0e6b5d3

File details

Details for the file ojax-4.0.1-py3-none-any.whl.

File metadata

  • Download URL: ojax-4.0.1-py3-none-any.whl
  • Upload date:
  • Size: 12.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.10.16

File hashes

Hashes for ojax-4.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 f12874b733e5e8e8bd8d8d2ec219f03505c5a23fb26fb53456564521b9c25530
MD5 8004687a37f620d55905da0cd9eaa824
BLAKE2b-256 c87b80e32cde0530138a16266b323ebc649539c571a4305b38f8955542d91b5d
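The published digests can be checked locally with Python's standard `hashlib` module. This is a minimal sketch; the helper names `sha256_of` and `matches` are hypothetical, and the same check applies to the wheel using its own SHA256 value.

```python
import hashlib
from pathlib import Path

# SHA256 digest published above for the source distribution
EXPECTED_SDIST_SHA256 = "7778d1493f96619de9f6ca5780fd6047dec8ea1f6241db53b3af4da10397d86b"


def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches(path, expected_hex):
    """True if the file's SHA256 equals the published digest (case-insensitive)."""
    return sha256_of(path) == expected_hex.lower()


# Usage, assuming the sdist was downloaded to the current directory:
# matches("ojax-4.0.1.tar.gz", EXPECTED_SDIST_SHA256)
```

Reading in chunks keeps memory use constant regardless of file size; pip can also enforce digests itself via its hash-checking mode.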
