
A minimal, explicit, and functional neural network library for JAX.


blox

A lightweight, strictly functional neural network library for JAX.

License: MIT | Python 3.10+ | JAX 0.4+

blox embraces JAX's functional paradigm without the headache.

It provides a minimal, object-oriented layer solely for organizing your code, while strictly enforcing functional state management and explicit data flow. By stripping away the "magic" found in other frameworks (implicit context managers, thread-local storage, and global state), blox keeps your code free of side effects, transparent, and directly compatible with JAX's transformations.

No wrappers needed. Because there is no hidden state, jax.jit, jax.grad, and jax.vmap work right out of the box.
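As an illustration of this claim, here is a plain-JAX sketch (it does not use blox) of a pure function with the same (params, inputs) shape as a blox module. Because it reads nothing but its arguments, jax.jit, jax.grad, and jax.vmap apply to it directly. The names linear and loss are illustrative only.

```python
import jax
import jax.numpy as jnp


def linear(params: dict, inputs: jax.Array) -> jax.Array:
  # Pure function: the output depends only on the arguments.
  return inputs @ params['w'] + params['b']


def loss(params: dict, inputs: jax.Array, targets: jax.Array) -> jax.Array:
  return jnp.mean((linear(params, inputs) - targets) ** 2)


params = {'w': jnp.ones((3, 2)), 'b': jnp.zeros((2,))}
inputs = jnp.ones((4, 3))
targets = jnp.zeros((4, 2))

# Standard transformations apply with no wrappers.
fast_linear = jax.jit(linear)
grads = jax.jit(jax.grad(loss))(params, inputs, targets)
# vmap over the leading axis of inputs while sharing params (in_axes=None).
per_example = jax.vmap(linear, in_axes=(None, 0))(params, inputs)

print(fast_linear(params, inputs).shape)  # (4, 2)
print(grads['w'].shape)                   # (3, 2)
print(per_example.shape)                  # (4, 2)
```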

⚡ Core Principles

  • Functional purity: Models are just stateless transformations. Parameters and RNG state are passed explicitly as arguments (params), never stored in self.
  • Explicit data flow: No hidden global context. You can trace the path of every single tensor just by reading the function signature.
  • Structural RNG: Random keys are derived deterministically from the graph structure. Say goodbye to manually threading keys through every layer ("refactoring hell"); blox handles the math while keeping your functions pure.
  • Visualizable: Comes with out-of-the-box Treescope integration for beautiful, interactive visualization of your model's architecture and parameters.
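One way to get keys from structure instead of threading them by hand is to fold a stable hash of a module's path into a root key, so each path always receives the same key. The sketch below assumes CRC32 as the hash and uses a hypothetical helper, key_for_path; the source does not say which derivation blox actually uses.

```python
import zlib

import jax


def key_for_path(root_key: jax.Array, path: str) -> jax.Array:
  """Derives a reproducible key from a module path (hypothetical helper)."""
  # zlib.crc32 is stable across processes, unlike Python's built-in hash()
  # for strings. Masking to 31 bits keeps the value a non-negative int32.
  path_id = zlib.crc32(path.encode('utf-8')) & 0x7FFFFFFF
  return jax.random.fold_in(root_key, path_id)


root = jax.random.PRNGKey(42)
k_hidden = key_for_path(root, 'net/mlp/hidden/w')
k_readout = key_for_path(root, 'net/readout/w')

# Same path gives the same key; different paths give different keys.
same = bool((key_for_path(root, 'net/mlp/hidden/w') == k_hidden).all())
different = not bool((k_hidden == k_readout).all())
print(same, different)
```

Because the key depends only on the root seed and the path, inserting a new layer elsewhere in the model does not shift the keys of existing layers.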

📦 Installation

git clone https://github.com/hamzamerzic/blox.git
cd blox
pip install -e .

🚀 Quick Start

In blox, a module is just a structural container (__init__) and a pure mathematical function (__call__).

Define your layers

Notice the signature: params carries the state (weights + RNG), while inputs is your data.

import jax
import jax.numpy as jnp
import blox as bx

class CustomLinear(bx.Module):

  def __init__(
      self,
      graph: bx.Graph,
      output_size: int,
  ) -> None:
    super().__init__(graph)
    self.output_size = output_size

  def __call__(
      self,
      params: bx.Params,
      inputs: jax.Array,
  ) -> tuple[jax.Array, bx.Params]:
    # Request parameters explicitly from the container.
    # The RNG key is automatically derived from the graph path.
    w_shape = (inputs.shape[-1], self.output_size)
    w, params = self.get_param(
        params, 'w', w_shape, jax.nn.initializers.glorot_uniform()
    )
    b_shape = (self.output_size,)
    b, params = self.get_param(
        params, 'b', b_shape, jax.nn.initializers.zeros
    )
    return inputs @ w + b, params
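The get_param calls above appear to behave like a get-or-create lookup: the first (initialization) pass creates each parameter with its initializer, and later passes return the stored value. Below is a plain-JAX sketch of that pattern over a flat, path-keyed dict. The helper get_or_create, the dict layout, and the CRC32-based key derivation are assumptions for illustration, not blox's implementation.

```python
import zlib

import jax
import jax.numpy as jnp


def get_or_create(params: dict, path: str, shape, init, root_key):
  """Returns (value, params); creates the entry only if it is missing."""
  if path in params:
    return params[path], params
  # Assumption: the init key is derived from the path via CRC32 + fold_in.
  key = jax.random.fold_in(root_key, zlib.crc32(path.encode()) & 0x7FFFFFFF)
  value = init(key, shape, jnp.float32)
  # Return a new dict instead of mutating the caller's dict.
  return value, {**params, path: value}


def linear(params, inputs, output_size, prefix, root_key):
  # Shapes are inferred from the inputs, as in CustomLinear above.
  w, params = get_or_create(params, f'{prefix}/w',
                            (inputs.shape[-1], output_size),
                            jax.nn.initializers.glorot_uniform(), root_key)
  b, params = get_or_create(params, f'{prefix}/b', (output_size,),
                            jax.nn.initializers.zeros, root_key)
  return inputs @ w + b, params


root_key = jax.random.PRNGKey(0)
x = jnp.ones((1, 10))
y, params = linear({}, x, 32, 'net/hidden', root_key)        # creates w, b
y2, params2 = linear(params, x, 32, 'net/hidden', root_key)  # reuses them
print(sorted(params))  # ['net/hidden/b', 'net/hidden/w']
```

The second call hands back the very same dict because nothing had to be created, which is why the same function can serve as both the initialization pass and the runtime pass.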

Composition & Dependency Injection

Because blox modules are just standard Python objects, dependency injection is a breeze.

Instead of hardcoding every layer, you can construct a module outside and pass it in. This shapes the graph hierarchy: an injected module keeps the path it was created with (here, a "sibling" of the module that receives it), while layers created internally become children.

class CustomMLP(bx.Module):

  def __init__(
      self,
      graph: bx.Graph,
      hidden_size: int,
      # Inject a pre-built module instance.
      output_projection: bx.Module,
  ) -> None:
    super().__init__(graph)
    # Internal layer: We create it here, so it lives in our scope.
    self.hidden_proj = CustomLinear(graph.child('hidden'), hidden_size)
    
    # Injected layer: It was created outside, so we just store the reference.
    self.output_projection = output_projection

  def __call__(
      self,
      params: bx.Params,
      inputs: jax.Array,
  ) -> tuple[jax.Array, bx.Params]:
    # Chain the functional transformations.
    x, params = self.hidden_proj(params, inputs)
    x = jax.nn.relu(x)
    
    # The output projection knows where to find its own params in the container.
    outputs, params = self.output_projection(params, x)
    return outputs, params

Initialization & Visualization

We cleanly separate the "Initialization phase" (traversing the graph to create parameters) from the "Runtime phase" (training the parameters).
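The two phases can be pictured as a container that accepts new entries during the initialization pass and is then frozen, so the runtime phase can only look up existing parameters and a misspelled or late-added name fails loudly instead of silently creating fresh weights. A minimal sketch with a hypothetical ParamStore class (not blox's Params; its finalize only mirrors the idea):

```python
import jax.numpy as jnp


class ParamStore:
  """Hypothetical store; every write returns a new instance."""

  def __init__(self, values=None, finalized=False):
    self.values = dict(values or {})
    self.finalized = finalized

  def get(self, name, make):
    if name in self.values:
      return self.values[name], self
    if self.finalized:
      # After initialization the structure is frozen: no new entries.
      raise KeyError(f'unknown parameter after finalize(): {name!r}')
    value = make()
    return value, ParamStore({**self.values, name: value}, self.finalized)

  def finalize(self):
    return ParamStore(self.values, finalized=True)


store = ParamStore()
w, store = store.get('net/readout/w', lambda: jnp.zeros((32, 1)))  # init pass
store = store.finalize()
w_again, _ = store.get('net/readout/w', lambda: jnp.ones((32, 1)))  # lookup

try:
  store.get('net/readout/typo', lambda: jnp.zeros(()))
  rejected = False
except KeyError:
  rejected = True  # the frozen store refuses to add a new parameter
```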

# Define the structure (Wiring).
graph = bx.Graph('net')

# Create the output layer explicitly at the root level ('net/readout').
readout = CustomLinear(graph.child('readout'), output_size=1)

# Pass it into the MLP. 
# The MLP lives at 'net/mlp', but it uses 'readout' which lives at 'net/readout'.
model = CustomMLP(graph.child('mlp'), hidden_size=32, output_projection=readout)

# Create Data and Seed.
inputs = jnp.ones((1, 10))
params = bx.Params(seed=42)

# Initialization Pass.
# We run the model once to populate the params container.
outputs, params = model(params, inputs)

# Finalize initialization.
# This prevents further changes to the parameter structure (like accidentally 
# adding new parameters after initialization).
params = params.finalize()

# Visualize.
bx.display(graph, params)

Output: Notice how readout and mlp are siblings in the graph, while hidden is nested inside mlp.

net: Graph # Param: 385 (1.5 KB)(
  rng=Param[N](
    shape=(2,),
    dtype=object,
    metadata={'tag': 'rng'},
    value=(<jax.Array...>, <jax.Array...>)
  ),
  readout=CustomLinear # Param: 33 (132.0 B)(
    output_size=1,
    w=Param[T](value=<jax.Array...>),
    b=Param[T](value=<jax.Array...>)
  ),
  mlp=CustomMLP # Param: 352 (1.4 KB)(
    hidden=CustomLinear # Param: 352 (1.4 KB)(
      output_size=32,
      w=Param[T](value=<jax.Array...>),
      b=Param[T](value=<jax.Array...>)
    )
  )
)
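The counts in this summary follow from the layer shapes: a dense layer mapping n_in to n_out features holds n_in * n_out weights plus n_out biases, and the byte sizes match 4 bytes per float32 value (the RNG entry does not appear to be included in the totals). A quick check:

```python
def dense_param_count(n_in: int, n_out: int) -> int:
  # Weight matrix (n_in x n_out) plus one bias per output unit.
  return n_in * n_out + n_out


hidden = dense_param_count(10, 32)   # inputs have 10 features
readout = dense_param_count(32, 1)   # fed by the 32 hidden units
total = hidden + readout
bytes_f32 = 4 * total                # float32 = 4 bytes per value

print(hidden, readout, total)  # 352 33 385
print(bytes_f32)               # 1540, shown as 1.5 KB
```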

⚡ Training (JIT & Gradients)

The Params container holds everything: weights, biases, RNG state, batch norm statistics, and EMA moving averages.

When training, we usually want to differentiate with respect to weights, but we still need to update the non-trainable state (like the RNG counter or batch statistics). blox makes this partitioning explicit.

@jax.jit
def train_step(params, inputs, targets):
  # Split params into two sets.
  # Trainable: weights, biases (we want gradients for these).
  # Non-trainable: RNG, batch stats, EMA (we just want the updated values).
  trainable, non_trainable = params.split_trainable()

  def loss_fn(t, nt):
    # Merge parameters to run the forward pass.
    predictions, new_params = model(t.merge(nt), inputs)

    # Calculate the loss.
    loss = jnp.mean((predictions - targets) ** 2)

    # Extract the updated non-trainable state to pass it out.
    _, new_non_trainable = new_params.split_trainable()
    return loss, new_non_trainable

  # Calculate gradients and capture the auxiliary state.
  grads, new_non_trainable = jax.grad(loss_fn, has_aux=True)(
      trainable, non_trainable
  )

  # Update the trainable weights using SGD.
  new_trainable = jax.tree.map(lambda w, g: w - 0.01 * g, trainable, grads)

  # Merge the updated weights with the updated non-trainable state.
  return new_trainable.merge(new_non_trainable)
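The same partitioning can be reproduced with plain dictionaries, which may help when reasoning about what split_trainable and merge do. In this sketch the split is decided by a hypothetical naming convention (entries under state/ are non-trainable), and a counter stands in for the RNG state; none of this is blox's actual API.

```python
import jax
import jax.numpy as jnp


def split(params: dict):
  # Hypothetical convention: keys starting with 'state/' are non-trainable.
  trainable = {k: v for k, v in params.items() if not k.startswith('state/')}
  non_trainable = {k: v for k, v in params.items() if k.startswith('state/')}
  return trainable, non_trainable


def forward(params: dict, inputs):
  # Pure forward pass that also advances a non-trainable counter.
  outputs = inputs @ params['w'] + params['b']
  new_params = {**params, 'state/counter': params['state/counter'] + 1}
  return outputs, new_params


@jax.jit
def train_step(params, inputs, targets):
  trainable, non_trainable = split(params)

  def loss_fn(t, nt):
    predictions, new_params = forward({**t, **nt}, inputs)
    loss = jnp.mean((predictions - targets) ** 2)
    # Return the updated non-trainable state as auxiliary output.
    return loss, split(new_params)[1]

  # Gradients are taken with respect to the first argument only.
  grads, new_non_trainable = jax.grad(loss_fn, has_aux=True)(
      trainable, non_trainable)
  new_trainable = jax.tree_util.tree_map(
      lambda w, g: w - 0.01 * g, trainable, grads)
  return {**new_trainable, **new_non_trainable}


params = {'w': jnp.ones((3, 1)), 'b': jnp.zeros((1,)),
          'state/counter': jnp.array(0, jnp.int32)}
new_params = train_step(params, jnp.ones((4, 3)), jnp.zeros((4, 1)))
print(int(new_params['state/counter']))  # 1
```

The integer counter never enters the differentiated argument, so jax.grad only sees floating-point leaves, while has_aux lets the updated state flow out of the loss function alongside the loss.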

🧠 Under the Hood (No Magic)

blox is designed to be fully transparent. The "abstraction" is really just automated path handling to keep your code clean and your state pure.

The Graph: This acts as a path builder. It is a lightweight object that represents a location in the model hierarchy (e.g., net/mlp/dense1). Calling graph.child('name') gives you a graph whose path has 'name' appended, which ensures that every module has a unique address space for its variables.
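A path builder of this kind needs very little code. The hypothetical PathBuilder below (a stand-in, not bx.Graph) shows how child calls compose into slash-separated addresses, and why an injected module keeps the path it was created with.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PathBuilder:
  """Immutable location in a model hierarchy (hypothetical stand-in)."""
  path: str

  def child(self, name: str) -> 'PathBuilder':
    # Each call returns a new object; the parent is left unchanged.
    return PathBuilder(f'{self.path}/{name}')

  def param_key(self, name: str) -> str:
    return f'{self.path}/{name}'


root = PathBuilder('net')
readout = root.child('readout')   # created at the root level, then injected
mlp = root.child('mlp')
hidden = mlp.child('hidden')      # created inside the MLP

print(hidden.param_key('w'))   # net/mlp/hidden/w
print(readout.param_key('w'))  # net/readout/w
```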

The Params: This is a secure vault. It holds all weights, biases, and RNG states in a single, flat, immutable dictionary keyed by the paths generated by the Graph (e.g., "net/mlp/dense1/w"). It provides methods to partition state (for gradients) and merge it back (for updates).

The RNG: Handling randomness in pure functional programming can be painful. Instead of manually threading key arguments through every single layer, Params maintains a master key and a counter.

  • When a module needs randomness (e.g., initialization or dropout), it asks Params for a key.
  • Params uses jax.random.fold_in(master_key, counter) to generate a deterministic, unique key for that specific call.
  • It increments the counter and returns a new Params object.
  • This guarantees that your model is mathematically reproducible and parallel-safe without the boilerplate.
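The counter scheme in the list above can be sketched in plain JAX by carrying (master_key, counter) as an immutable tuple. The function next_key and the tuple layout are hypothetical; blox's internal representation may differ.

```python
import jax


def next_key(rng_state):
  """Returns (key, new_state) without mutating rng_state."""
  master_key, counter = rng_state
  # fold_in mixes the counter into the master key deterministically.
  key = jax.random.fold_in(master_key, counter)
  return key, (master_key, counter + 1)


state0 = (jax.random.PRNGKey(42), 0)
k1, state1 = next_key(state0)
k2, state2 = next_key(state1)
k1_again, _ = next_key(state0)   # same state in, same key out

print(bool((k1 == k1_again).all()), bool((k1 == k2).all()))  # True False
```

Because next_key never mutates its input, replaying from the same state reproduces the same key, which is the reproducibility property described above.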

⚖️ Why blox?

blox chooses clarity over brevity.

Most frameworks rely on implicit global state or thread-local contexts to save you from passing arguments. This works great until you need to debug a side-effect, use a transformation the framework wasn't designed for, or inspect the state mid-execution.

| Standard frameworks | blox |
|---|---|
| out = layer(x) | out, params = layer(params, inputs) |
| Implicit global context | Explicit state passing |
| Hidden variable scopes | Explicit bx.Graph paths |
| Custom jit / vmap wrappers | Standard jax.jit / jax.vmap |

By accepting slightly more verbose function signatures, you gain:

  1. Total transparency: You know exactly what data your function touches.
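  2. JIT safety: blox itself holds no global state, so it cannot leak tracers or capture side effects behind your back.
  3. Performance: blox only organizes calls to ordinary JAX functions, so under jax.jit your model compiles to essentially the same XLA program as equivalent hand-written JAX.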

📄 License

MIT License. See LICENSE for details.
