
Dynamic neural networks and function transformations in Python + Mojo

Project description

Nabla: High-Performance Distributed ML

A JAX-inspired autodiff library with factor-based SPMD sharding, built on Mojo & MAX.

Active Development: This is the main development branch with distributed SPMD execution and a refined lazy, MAX-native execution model. Read the docs: https://nablaml.com.

Development Status · Python 3.12+ · License: Apache 2.0


Feature Showcase

1. Tensors & Autodiff

Define Python functions and compute gradients using trace-based automatic differentiation. Read more

import nabla

# Use Accelerator (GPU) or CPU for execution
with nabla.default_device(nabla.Accelerator()):
    x = nabla.uniform((4, 8))
    w = nabla.uniform((8, 16))

    # Define loss function
    def compute_loss(x, w):
        return nabla.mean(nabla.relu(x @ w))

    # Compute loss (implicit .realize() on print)
    loss = compute_loss(x, w)
    print("Loss:", loss)

    # Compute gradients via backward replay
    grad_x, grad_w = nabla.grad(compute_loss, argnums=(0, 1))(x, w)
    print("Gradients:", grad_x.shape, grad_w.shape)

2. SPMD Sharding

Shard tensors on a logical mesh; operations automatically propagate sharding constraints. Read more

# Define 2×4 device mesh (Logical DP × TP)
mesh = nabla.DeviceMesh("my_mini_pod", (2, 4), ("dp", "tp"))

# Shard x on 'dp' (rows), w on 'tp' (columns)
x = nabla.shard(nabla.uniform((32, 128)), mesh, nabla.P("dp", None))
w = nabla.shard(nabla.uniform((128, 256)), mesh, nabla.P(None, "tp"))

def compute_loss(x, w):
    return nabla.mean(nabla.relu(x @ w))

# Automatic AllReduce is inserted for 'tp' sum
loss = compute_loss(x, w)
print("Loss (Sharded):", loss)

3. Mojo Integration

Nabla's core strength is its ability to drop down to Mojo for high-performance custom kernels, bridging the gap between high-level Python and bare-metal execution. Read more

Mojo Kernel (kernels/custom_kernel.mojo)

@compiler.register("my_kernel")
struct MyKernel:
    @staticmethod
    def execute[target: StaticString](
        output: OutputTensor,
        x: InputTensor[dtype = output.dtype, rank = output.rank],
        ctx: DeviceContextPtr,
    ):
        @parameter
        fn add_one[W: Int](idx: IndexList[x.rank]) -> SIMD[x.dtype, W]:
            return x.load[W](idx) + 1

        foreach[add_one, target=target](output, ctx)

Python Usage

class AddOneOp(nabla.UnaryOperation):
    name = "my_kernel"

    def kernel(self, x, **kwargs):
        # Concise invocation: (func_name, path, inputs, out_types)
        return nabla.call_custom_kernel("my_kernel", "./kernels", x, x.type)

x = nabla.Tensor.constant([1., 2., 3.])
y = AddOneOp()(x)
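
As a quick sanity check, the custom op should increment every element by one; this assumes the result prints its realized values the same way the built-in ops above do.

print(y)  # expected: [2.0, 3.0, 4.0], since my_kernel adds 1 to each element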

4. Distributed Pipeline Parallelism (GPipe)

Define complex distributed schedules like GPipe using vmap for parallel execution and ppermute for explicit data movement. Read more

# Parallel execution across 'num_stages'
@nabla.vmap(in_axes=(0, 0), spmd_axis_name="stage")
def stage_compute(x, w): 
    return nabla.relu(x @ w)

def pipeline_step(current_state, fresh_input, weights, mask_0):
    # 1. Compute: Run all stages in parallel
    computed = stage_compute(current_state, weights)

    # 2. Communicate: Shift activations to the next stage (i -> i+1)
    shifted = nabla.ppermute(computed, perm=[(i, (i + 1) % stages) for i in range(stages)])

    # 3. Control: Stage 0 takes fresh input; others take shifted data
    return nabla.where(mask_0, fresh_input, shifted)
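
A possible driver loop for the step function above. Everything below (the stage count, tensor shapes, mask construction, and microbatch loop) is an illustrative sketch, not an API or schedule prescribed by Nabla.

stages = 4  # hypothetical pipeline depth; pipeline_step above reads this name

# Per-stage weights and activations, stacked along the vmapped 'stage' axis.
weights = nabla.uniform((stages, 64, 64))
state = nabla.uniform((stages, 8, 64))

# Mask of shape (stages, 1, 1) that selects stage 0 inside nabla.where.
mask_0 = nabla.Tensor.constant([[[1.0]]] + [[[0.0]]] * (stages - 1))

# Feed microbatches through the pipeline, one step per iteration.
for _ in range(8):  # the microbatch count is illustrative
    fresh = nabla.uniform((stages, 8, 64))  # only the stage-0 slice is consumed
    state = pipeline_step(state, fresh, weights, mask_0)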

5. Dynamic Shape Compilation

Compile functions once with symbolic dimensions to handle varying input sizes without recompilation.

# Compile once for ANY batch size (dim 0)
@nabla.compile(dynamic_dims={0: {0: "batch"}})
def square(x):
    return x * x

x_small = nabla.uniform((2, 10))
x_large = nabla.uniform((128, 10))

res1 = square(x_small) # Triggers compilation
res2 = square(x_large) # Reuses compiled graph!
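
As a quick check (results realize on print, as in the first example), the compiled function preserves each input's leading dimension even though it was compiled only once:

print(res1.shape, res2.shape)  # expected: (2, 10) and (128, 10)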

Architecture Overview

Nabla relies on three core principles:

  1. Lazy Execution: Shapes are computed eagerly, but the computation graph is built and compiled only when .realize() is called (see the sketch after this list).
  2. Trace-Based Autodiff: Gradients are computed by tracing the forward pass and replaying operations in reverse.
  3. Factor-Based SPMD: Sharding is propagated using "semantic factors" (e.g., batch, heads) rather than physical mesh axes.
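
Principle 1 can be seen directly with the calls already used above. A minimal sketch, assuming .realize() forces graph construction, compilation, and execution as described:

# Nothing executes yet: shapes are known eagerly, the graph stays pending.
a = nabla.uniform((4, 8))
b = nabla.uniform((8, 4))
c = nabla.relu(a @ b)
print(c.shape)  # shape is available without running anything

c.realize()     # graph is built, compiled, and executed here
print(c)        # printing would also trigger an implicit realize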

Development Setup

Prerequisites

  • Python 3.12+
  • Modular MAX SDK (via requirements.txt)

Installation

git clone https://github.com/nabla-ml/nabla.git
cd nabla
python -m venv venv
source venv/bin/activate
pip install -r requirements-dev.txt
pip install -e ".[dev]"

PyPI Install

Stable (default index):

pip install nabla-ml

Latest Modular nightly (recommended for newest MAX ops):

pip install --pre --extra-index-url https://whl.modular.com/nightly/simple/ nabla-ml

GPU Support:

  • Linux (AMD/NVIDIA): Supported natively via Modular MAX.
  • macOS (Apple Silicon): Requires Xcode Metal toolchain (xcode-select --install).

Stable Release (v25.7)

For the simpler, single-device version:

pip install nabla-ml

(See v25.7 branch)


Contributing

  • Bugs/Docs: Submit a PR directly.
  • Features: Open an Issue first.
  • New Ops: See nabla/ops/README.md.

License

Apache-2.0 — see LICENSE

