An extension for modular JAX code.
What is OJAX
OJAX is a small extension of JAX to facilitate modular coding.
As you might have already noticed, JAX's functional nature does not pair well with ordinary Python classes. People tend to write closures (functionals) instead, that is, functions that return JAX functions (e.g. the Stax NN library from the JAX codebase), which is far from ideal for complex projects.
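For illustration, the closure pattern described above can be sketched in plain JAX as follows. This is a minimal sketch, not code from OJAX or Stax; the names make_add_with and add_with are hypothetical and chosen only for this example.

```python
import jax
import jax.numpy as jnp


def make_add_with(value):
    """Return a JAX-compatible function that adds `value` to its input.

    Hypothetical helper for illustration; not part of JAX or OJAX.
    """

    def add_with(t):
        # `value` is captured from the enclosing scope rather than
        # passed in as an argument.
        return t + value

    return add_with


# The returned function is a plain callable, so it can be jitted directly.
add42_jitted = jax.jit(make_add_with(42.0))
result = add42_jitted(jnp.ones(1))
print(result)
```

In this style the state (value) is hidden inside the closure: it cannot be inspected or changed from outside without calling make_add_with again, which is one reason the pattern becomes awkward as a project grows.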
OJAX lets you write JAX code using classes again, with full JAX compatibility, maximum flexibility, and zero worry. Here is an example custom class using OJAX that can be passed directly to jax.jit:
import jax
import ojax

class AddWith(ojax.OTree):
    value: float

    def __call__(self, t: jax.Array) -> jax.Array:
        return t + self.value

add42_jitted = jax.jit(AddWith(42.0))
print(add42_jitted(jax.numpy.ones(1)))  # [43.]
OJAX is a simple library that takes less than an hour to learn, and it can save you many hours on your JAX projects!
Why OJAX
"Library XXX already did something similar, why reinvent the wheel?"
The short answer is: because the wheel is rounder this time ;)
Motivated by deep learning applications, many JAX libraries already offer some kind of module system: Flax, Equinox, Haiku, Simple Pytree, Treeo / Treex, and PAX, just to name a few.
However, none of them offers a perfect “JAX base class” that fulfills all of the desiderata below:
Simple to understand and use
Flexible custom classes for general JAX computation
Compatible with JAX and its functional paradigm
OJAX strives to define what a JAX base class should be. It provides a natural way to structure custom JAX code and steers users away from common pitfalls.
P.S.: the name “OJAX” is a chapeau-bas to OCaml, an awesome functional programming language.
How to code with OJAX
OJAX is easy to install following the installation guide.
Have a look at the quickstart section to get started; there is also a simple code example using OJAX.
Of course, check out the OJAX API reference for exact definitions.