
A minimal, explicit, and functional neural network library for JAX.



blox

A lightweight, strictly functional neural network library for JAX.

License: MIT · Python 3.10+ · JAX 0.4+

blox embraces JAX's functional paradigm without the headache.

It provides a minimal, object-oriented layer solely for organizing your code, while strictly enforcing functional state management and explicit data flow. By stripping away the "magic" found in other frameworks—like implicit context managers, thread-local storage, and global state—blox ensures your code remains side-effect free, transparent, and trivially compatible with JAX's powerful transformations.

No wrappers needed. Because there is no hidden state, jax.jit, jax.grad, and jax.vmap work right out of the box.
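As a plain-JAX illustration of why no wrappers are needed (this uses an ordinary dict and the hypothetical functions linear and loss, not the blox API): once every piece of state arrives as an argument, the same function can be handed straight to jax.jit, jax.grad, and jax.vmap.

```python
import jax
import jax.numpy as jnp


def linear(params: dict[str, jax.Array], inputs: jax.Array) -> jax.Array:
  """Pure function: the output depends only on the explicit arguments."""
  return inputs @ params['w'] + params['b']


def loss(params: dict[str, jax.Array], inputs: jax.Array,
         targets: jax.Array) -> jax.Array:
  """Mean squared error of the linear map."""
  return jnp.mean((linear(params, inputs) - targets) ** 2)


params = {'w': jnp.ones((3, 2)), 'b': jnp.zeros((2,))}
inputs = jnp.ones((4, 3))
targets = jnp.zeros((4, 2))

# Compile the unchanged function.
fast_linear = jax.jit(linear)
# Differentiate with respect to the first argument (the params dict).
grads = jax.grad(loss)(params, inputs, targets)
# Vectorize over a leading axis of inputs; params are shared (in_axes=None).
batched = jax.vmap(linear, in_axes=(None, 0))(params, jnp.ones((5, 4, 3)))
```

blox modules additionally return the updated params, but the principle is the same: nothing a transformation needs is hidden in self or in a global.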

⚡ Core Principles

  • Functional purity: Models are just stateless transformations. Parameters and RNG state are passed explicitly as arguments (params), never stored in self.
  • Explicit data flow: No hidden global context. You can trace the path of every single tensor just by reading the function signature.
  • Structural RNG: Random keys are derived deterministically from the graph structure. Say goodbye to manually threading keys through every layer ("refactoring hell"); blox handles the math while keeping your functions pure.
  • Visualizable: Comes with out-of-the-box Treescope integration for beautiful, interactive visualization of your model's architecture and parameters.

📦 Installation

git clone https://github.com/hamzamerzic/blox.git
cd blox
pip install -e .

🚀 Quick Start

In blox, a module is just a structural container (__init__) and a pure mathematical function (__call__).

Define your layers

Notice the signature: params carries the state (weights + RNG), while inputs is your data.

import jax
import jax.numpy as jnp
import blox as bx

class CustomLinear(bx.Module):

  def __init__(
      self,
      graph: bx.Graph,
      output_size: int,
  ) -> None:
    super().__init__(graph)
    self.output_size = output_size

  def __call__(
      self,
      params: bx.Params,
      inputs: jax.Array,
  ) -> tuple[jax.Array, bx.Params]:
    # Request parameters explicitly from the container.
    # The RNG key is automatically derived from the graph path.
    w_shape = (inputs.shape[-1], self.output_size)
    w, params = self.get_param(
        params, 'w', w_shape, jax.nn.initializers.glorot_uniform()
    )
    b_shape = (self.output_size,)
    b, params = self.get_param(
        params, 'b', b_shape, jax.nn.initializers.zeros
    )
    return inputs @ w + b, params

Composition & Dependency Injection

Because blox modules are just standard Python objects, dependency injection is a breeze.

Instead of hardcoding layers, you can create modules outside and pass them in. This changes the graph hierarchy: the injected module keeps its original path (it's a "sibling"), while internal layers become children.

class CustomMLP(bx.Module):

  def __init__(
      self,
      graph: bx.Graph,
      hidden_size: int,
      # Inject a pre-built module instance.
      output_projection: bx.Module,
  ) -> None:
    super().__init__(graph)
    # Internal layer: We create it here, so it lives in our scope.
    self.hidden_proj = CustomLinear(graph.child('hidden'), hidden_size)
    
    # Injected layer: It was created outside, so we just store the reference.
    self.output_projection = output_projection

  def __call__(
      self,
      params: bx.Params,
      inputs: jax.Array,
  ) -> tuple[jax.Array, bx.Params]:
    # Chain the functional transformations.
    x, params = self.hidden_proj(params, inputs)
    x = jax.nn.relu(x)
    
    # The output projection knows where to find its own params in the container.
    outputs, params = self.output_projection(params, x)
    return outputs, params

Initialization & Visualization

We cleanly separate the "Initialization phase" (traversing the graph to create parameters) from the "Runtime phase" (training the parameters).

# Define the structure (Wiring).
graph = bx.Graph('net')

# Create the output layer explicitly at the root level ('net/readout').
readout = CustomLinear(graph.child('readout'), output_size=1)

# Pass it into the MLP. 
# The MLP lives at 'net/mlp', but it uses 'readout' which lives at 'net/readout'.
model = CustomMLP(graph.child('mlp'), hidden_size=32, output_projection=readout)

# Create Data and Seed.
inputs = jnp.ones((1, 10))
params = bx.Params(seed=42)

# Initialization Pass.
# We run the model once to populate the params container.
outputs, params = model(params, inputs)

# Finalize initialization.
# This prevents further changes to the parameter structure (like accidentally 
# adding new parameters after initialization).
params = params.finalize()

# Visualize.
bx.display(graph, params)

Output: Notice how readout and mlp are siblings in the graph, while hidden is nested inside mlp.

net: Graph # Param: 385 (1.5 KB)(
  rng=Param[N](
    shape=(2,),
    dtype=object,
    metadata={'tag': 'rng'},
    value=(<jax.Array...>, <jax.Array...>)
  ),
  readout=CustomLinear # Param: 33 (132.0 B)(
    output_size=1,
    w=Param[T](value=<jax.Array...>),
    b=Param[T](value=<jax.Array...>)
  ),
  mlp=CustomMLP # Param: 352 (1.4 KB)(
    hidden=CustomLinear # Param: 352 (1.4 KB)(
      output_size=32,
      w=Param[T](value=<jax.Array...>),
      b=Param[T](value=<jax.Array...>)
    )
  )
)
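The counts in this summary can be checked by hand. Assuming float32 parameters (4 bytes each) and that the RNG entry is excluded from the count, as the totals imply, a dense layer holds fan_in x fan_out weights plus fan_out biases:

```python
def dense_param_count(fan_in: int, fan_out: int) -> int:
  """Weight matrix (fan_in x fan_out) plus one bias per output unit."""
  return fan_in * fan_out + fan_out


# net/mlp/hidden maps the 10 input features to 32 hidden units.
hidden = dense_param_count(10, 32)
# net/readout maps the 32 hidden units to a single output.
readout = dense_param_count(32, 1)
total = hidden + readout

BYTES_PER_FLOAT32 = 4
print(hidden, readout, total, total * BYTES_PER_FLOAT32)  # 352 33 385 1540
```

1540 bytes is the 1.5 KB reported for net, and 33 x 4 = 132 bytes matches the readout entry.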

⚡ Training (JIT & Gradients)

The Params container holds everything: weights, biases, RNG state, batch norm statistics, and exponential moving averages (EMA).

When training, we usually want to differentiate with respect to weights, but we still need to update the non-trainable state (like the RNG counter or batch statistics). blox makes this partitioning explicit.

@jax.jit
def train_step(params, inputs, targets):
  # Split params into two sets.
  # Trainable: weights, biases (we want gradients for these).
  # Non-trainable: RNG, batch stats, EMA (we just want the updated values).
  trainable, non_trainable = params.split_trainable()

  def loss_fn(t, nt):
    # Merge parameters to run the forward pass.
    predictions, new_params = model(t.merge(nt), inputs)

    # Calculate the loss.
    loss = jnp.mean((predictions - targets) ** 2)

    # Extract the updated non-trainable state to pass it out.
    _, new_non_trainable = new_params.split_trainable()
    return loss, new_non_trainable

  # Calculate gradients and capture the auxiliary state.
  grads, new_non_trainable = jax.grad(loss_fn, has_aux=True)(
      trainable, non_trainable
  )

  # Update the trainable weights using SGD.
  new_trainable = jax.tree.map(lambda w, g: w - 0.01 * g, trainable, grads)

  # Merge the updated weights with the updated non-trainable state.
  return new_trainable.merge(new_non_trainable)
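The split / grad / merge pattern itself is not blox-specific. The sketch below reproduces it with a plain dict keyed by paths; the store layout, the NON_TRAINABLE set, split_trainable, and forward are hypothetical stand-ins for bx.Params, its split_trainable and merge methods, and a real model. jax.tree_util.tree_map performs the SGD update.

```python
import jax
import jax.numpy as jnp

# Hypothetical flat store keyed by path; the counter is non-trainable state.
params = {
    'net/readout/w': jnp.full((3, 1), 0.5),
    'net/readout/b': jnp.zeros((1,)),
    'net/rng/counter': jnp.array(0, dtype=jnp.int32),
}
NON_TRAINABLE = {'net/rng/counter'}  # assumption: selected by path here


def split_trainable(p):
  """Partitions the store into (trainable, non_trainable) dicts."""
  trainable = {k: v for k, v in p.items() if k not in NON_TRAINABLE}
  non_trainable = {k: v for k, v in p.items() if k in NON_TRAINABLE}
  return trainable, non_trainable


def forward(p, inputs):
  """Linear readout that also 'consumes' one RNG step, returning new state."""
  out = inputs @ p['net/readout/w'] + p['net/readout/b']
  new_p = {**p, 'net/rng/counter': p['net/rng/counter'] + 1}
  return out, new_p


@jax.jit
def train_step(p, inputs, targets):
  trainable, non_trainable = split_trainable(p)

  def loss_fn(t, nt):
    # Merge by dict union, run the model, keep only the updated state.
    preds, new_p = forward({**t, **nt}, inputs)
    _, new_nt = split_trainable(new_p)
    return jnp.mean((preds - targets) ** 2), new_nt

  # Gradients flow only to the first argument (the trainable dict).
  grads, new_nt = jax.grad(loss_fn, has_aux=True)(trainable, non_trainable)
  new_t = jax.tree_util.tree_map(lambda w, g: w - 0.01 * g, trainable, grads)
  return {**new_t, **new_nt}


new_params = train_step(params, jnp.ones((4, 3)), jnp.zeros((4, 1)))
```

The integer counter never enters jax.grad as a differentiated input; it travels out through has_aux and is merged back alongside the updated weights.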

🧠 Under the Hood (No Magic)

blox is designed to be fully transparent. The "abstraction" is really just automated path handling to keep your code clean and your state pure.

The Graph: This acts as a Path Builder. It is a lightweight object that represents a location in the model hierarchy (e.g., net/mlp/dense1). When you call graph.child('name'), it appends to the path. This ensures that every module has a unique address space for its variables.
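A path builder of this kind takes only a few lines. The sketch below is hypothetical (PathGraph is not bx.Graph, and the real class may carry more than a string); it shows how child() yields nested, unique prefixes, and why an injected module keeps the path it was created with.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PathGraph:
  """Hypothetical path builder illustrating the idea behind bx.Graph."""
  path: str

  def child(self, name: str) -> 'PathGraph':
    # Return a NEW node; the parent is left untouched.
    return PathGraph(f'{self.path}/{name}')

  def param_path(self, name: str) -> str:
    # Full key under which a parameter of this node would be stored.
    return f'{self.path}/{name}'


net = PathGraph('net')
readout = net.child('readout')             # created at the root level
hidden = net.child('mlp').child('hidden')  # created inside the MLP
# Passing readout into another module does not change readout.path.
```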

The Params: This is a Secure Vault. It holds all weights, biases, and RNG states in a single, flat, immutable dictionary keyed by the paths generated by the Graph (e.g., "net/mlp/dense1/w"). It provides methods to partition state (for gradients) and merge it back (for updates).
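The initialization and runtime phases can be illustrated with a hypothetical FlatParams class (a simplified stand-in, not the real bx.Params): get_param creates a value on first request and returns a new container, returns the stored value on later requests, and refuses to create new entries after finalize(). For simplicity it draws keys by splitting and carrying a single key, which stands in for blox's own RNG scheme.

```python
from typing import Callable

import jax
import jax.numpy as jnp

# Signature shared by jax.nn.initializers.*: init(key, shape) -> array.
Initializer = Callable[[jax.Array, tuple[int, ...]], jax.Array]


class FlatParams:
  """Hypothetical, simplified stand-in for bx.Params (not the real class)."""

  def __init__(self, values: dict[str, jax.Array], key: jax.Array,
               finalized: bool = False) -> None:
    self._values = dict(values)  # private copy, never mutated afterwards
    self.key = key
    self.finalized = finalized

  def get_param(self, path: str, shape: tuple[int, ...],
                init: Initializer) -> tuple[jax.Array, 'FlatParams']:
    if path in self._values:
      # Runtime phase: the parameter exists, so reuse it unchanged.
      return self._values[path], self
    if self.finalized:
      raise KeyError(f'cannot create {path!r} after finalize()')
    # Initialization phase: split off a fresh key and build a NEW container.
    key, subkey = jax.random.split(self.key)
    value = init(subkey, shape)
    return value, FlatParams({**self._values, path: value}, key)

  def finalize(self) -> 'FlatParams':
    return FlatParams(self._values, self.key, finalized=True)

  def paths(self) -> list[str]:
    return sorted(self._values)


p0 = FlatParams({}, jax.random.PRNGKey(42))
w, p1 = p0.get_param('net/readout/w', (32, 1),
                     jax.nn.initializers.glorot_uniform())
b, p2 = p1.get_param('net/readout/b', (1,), jax.nn.initializers.zeros)
p2 = p2.finalize()

# A second request for an existing path returns the stored value.
w_again, _ = p2.get_param('net/readout/w', (32, 1),
                          jax.nn.initializers.glorot_uniform())

# Creating a new parameter after finalize() is rejected.
rejected = False
try:
  p2.get_param('net/readout/extra', (4,), jax.nn.initializers.zeros)
except KeyError:
  rejected = True
```

Because every call returns a new container, p0 still has no entries after the initialization pass; only the returned object carries the new parameters.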

The RNG: Handling randomness in pure functional programming can be painful. Instead of manually threading key arguments through every single layer, Params maintains a master key and a counter.

  • When a module needs randomness (e.g., initialization or dropout), it asks Params for a key.
  • Params uses jax.random.fold_in(master_key, counter) to generate a deterministic, unique key for that specific call.
  • It increments the counter and returns a new Params object.
  • This guarantees that your model is mathematically reproducible and parallel-safe without the boilerplate.
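The key-derivation steps above can be sketched in plain JAX. RngState, next_key, and fresh_state are hypothetical names, and this state layout is an assumption rather than how bx.Params actually stores its RNG; the point is that fold_in with an explicit counter yields distinct, reproducible keys without mutating anything.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class RngState(NamedTuple):
  """Hypothetical RNG state: a master key plus a call counter."""
  master_key: jax.Array
  counter: jax.Array  # int32 scalar


def next_key(state: RngState) -> tuple[jax.Array, RngState]:
  """Returns a per-call key and a NEW state; the input state is not modified."""
  key = jax.random.fold_in(state.master_key, state.counter)
  return key, RngState(state.master_key, state.counter + 1)


def fresh_state(seed: int) -> RngState:
  return RngState(jax.random.PRNGKey(seed), jnp.array(0, dtype=jnp.int32))


state = fresh_state(42)
k0, state = next_key(state)
k1, state = next_key(state)

# Replaying from the same seed reproduces the same key sequence. A NamedTuple
# is a pytree, so the same function also runs unchanged under jax.jit.
r0, _ = jax.jit(next_key)(fresh_state(42))
```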

⚖️ Why blox?

blox chooses clarity over brevity.

Most frameworks rely on implicit global state or thread-local contexts to save you from passing arguments. This works great until you need to debug a side-effect, use a transformation the framework wasn't designed for, or inspect the state mid-execution.

| Standard frameworks | blox |
| --- | --- |
| out = layer(x) | out, params = layer(params, inputs) |
| Implicit global context | Explicit state passing |
| Hidden variable scopes | Explicit bx.Graph paths |
| Custom jit / vmap wrappers | Standard jax.jit / jax.vmap |

By accepting slightly more verbose function signatures, you gain:

  1. Total transparency: You know exactly what data your function touches.
  2. JIT safety: blox itself keeps no global state, so the library cannot leak tracers or capture side effects on your behalf.
  3. Performance: The library compiles down to the exact same XLA kernels as raw JAX code.

📄 License

MIT License. See LICENSE for details.



Download files


Source Distribution

jax_blox-0.1.1.tar.gz (5.0 MB)

Uploaded Source

Built Distribution


jax_blox-0.1.1-py3-none-any.whl (17.3 kB)

Uploaded Python 3

File details

Details for the file jax_blox-0.1.1.tar.gz.

File metadata

  • Download URL: jax_blox-0.1.1.tar.gz
  • Size: 5.0 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.1

File hashes

Hashes for jax_blox-0.1.1.tar.gz
| Algorithm | Hash digest |
| --- | --- |
| SHA256 | 9be1e3cce01eed9deb382874fbd935e3d09ab8aba739b354cfbd7104998dcbea |
| MD5 | 25009b75c9f00f6e3dbc1bf389deda20 |
| BLAKE2b-256 | fce8d48fcc38391880c5e670d29621e3fcaaa0fd89ac8e87f74b6f6ae4105559 |


File details

Details for the file jax_blox-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: jax_blox-0.1.1-py3-none-any.whl
  • Size: 17.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.1

File hashes

Hashes for jax_blox-0.1.1-py3-none-any.whl
| Algorithm | Hash digest |
| --- | --- |
| SHA256 | 8e361d1e3d9b3912c269374846881e7b61fd7ba14918e9579620999559a34bff |
| MD5 | a8e2bbaefa05f9314c4812fae317f1af |
| BLAKE2b-256 | 805100eb9aac0abf7272fd3541e04d3e03c00dd79d04024796ded2851743dfbc |

