Dynamic neural networks and function transformations in Python + Mojo
# Nabla: High-Performance Distributed ML
A JAX-inspired autodiff library with factor-based SPMD sharding, built on Mojo & MAX.
**Active Development:** This is the `main` development branch, with distributed SPMD execution and a refined lazy, MAX-native execution model. Read the docs: https://nablaml.com.
## Feature Showcase
### 1. Tensors & Autodiff
Define Python functions and compute gradients using trace-based automatic differentiation.
```python
import nabla

# Use Accelerator (GPU) or CPU for execution
with nabla.default_device(nabla.Accelerator()):
    x = nabla.uniform((4, 8))
    w = nabla.uniform((8, 16))

    # Define loss function
    def compute_loss(x, w):
        return nabla.mean(nabla.relu(x @ w))

    # Compute loss (implicit .realize() on print)
    loss = compute_loss(x, w)
    print("Loss:", loss)

    # Compute gradients via backward replay
    grad_x, grad_w = nabla.grad(compute_loss, argnums=(0, 1))(x, w)
    print("Gradients:", grad_x.shape, grad_w.shape)
```
### 2. SPMD Sharding
Shard tensors on a logical mesh; operations automatically propagate sharding constraints.
```python
# Define 2×4 device mesh (Logical DP × TP)
mesh = nabla.DeviceMesh("my_mini_pod", (2, 4), ("dp", "tp"))

# Shard x on 'dp' (rows), w on 'tp' (columns)
x = nabla.shard(nabla.uniform((32, 128)), mesh, nabla.P("dp", None))
w = nabla.shard(nabla.uniform((128, 256)), mesh, nabla.P(None, "tp"))

def compute_loss(x, w):
    return nabla.mean(nabla.relu(x @ w))

# Automatic AllReduce is inserted for 'tp' sum
loss = compute_loss(x, w)
print("Loss (Sharded):", loss)
```
### 3. Mojo Integration
Nabla's core strength is its ability to drop down to Mojo for high-performance custom kernels, bridging the gap between high-level Python and bare-metal execution.
**Mojo Kernel** (`kernels/custom_kernel.mojo`):
```mojo
@compiler.register("my_kernel")
struct MyKernel:
    @staticmethod
    def execute[target: StaticString](
        output: OutputTensor,
        x: InputTensor[dtype = output.dtype, rank = output.rank],
        ctx: DeviceContextPtr,
    ):
        @parameter
        fn add_one[W: Int](idx: IndexList[x.rank]) -> SIMD[x.dtype, W]:
            return x.load[W](idx) + 1

        foreach[add_one, target=target](output, ctx)
```
**Python Usage:**
```python
import nabla

class AddOneOp(nabla.UnaryOperation):
    name = "my_kernel"

    def kernel(self, x, **kwargs):
        # Concise invocation: (func_name, path, inputs, out_types)
        return nabla.call_custom_kernel("my_kernel", "./kernels", x, x.type)

x = nabla.Tensor.constant([1., 2., 3.])
y = AddOneOp()(x)
```
### 4. Distributed Pipeline Parallelism (GPipe)
Define complex distributed schedules like GPipe using `vmap` for parallel execution and `ppermute` for explicit data movement.
```python
# Parallel execution across 'num_stages'
@nabla.vmap(in_axes=(0, 0), spmd_axis_name="stage")
def stage_compute(x, w):
    return nabla.relu(x @ w)

def pipeline_step(current_state, fresh_input, weights, mask_0):
    # 1. Compute: Run all stages in parallel
    computed = stage_compute(current_state, weights)
    # 2. Communicate: Shift activations to the next stage (i -> i+1)
    shifted = nabla.ppermute(computed, perm=[(i, (i + 1) % stages) for i in range(stages)])
    # 3. Control: Stage 0 takes fresh input; others take shifted data
    return nabla.where(mask_0, fresh_input, shifted)
```
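For orientation, a hypothetical driver loop around `pipeline_step`. The stage count, tensor shapes, `microbatches`, and the construction of `mask_0` are all illustrative assumptions, not Nabla API guarantees:

```python
stages = 4  # hypothetical stage count referenced by pipeline_step above

# mask_0 selects stage 0 along the leading (stage) axis; the exact
# shape/dtype nabla.where expects here is an assumption
mask_0 = nabla.Tensor.constant([[[True]]] + [[[False]]] * (stages - 1))

state = nabla.uniform((stages, 32, 128))     # per-stage activations (assumed)
weights = nabla.uniform((stages, 128, 128))  # per-stage weights (assumed)
microbatches = [nabla.uniform((32, 128)) for _ in range(8)]  # assumed inputs

for fresh in microbatches:
    state = pipeline_step(state, fresh, weights, mask_0)
```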
### 5. Dynamic Shape Compilation
Compile functions once with symbolic dimensions to handle varying input sizes without recompilation.
```python
# Compile once for ANY batch size (dim 0)
@nabla.compile(dynamic_dims={0: {0: "batch"}})
def square(x):
    return x * x

x_small = nabla.uniform((2, 10))
x_large = nabla.uniform((128, 10))

res1 = square(x_small)  # Triggers compilation
res2 = square(x_large)  # Reuses compiled graph!
```
## Architecture Overview
Nabla relies on three core principles:

- **Lazy Execution**: Shapes are computed eagerly, but the computation graph is built and compiled only when `.realize()` is called (see the sketch after this list).
- **Trace-Based Autodiff**: Gradients are computed by tracing the forward pass and replaying operations in reverse.
- **Factor-Based SPMD**: Sharding is propagated using "semantic factors" (e.g., batch, heads) rather than physical mesh axes.
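A minimal sketch of the lazy model, using only APIs that appear in this README (`nabla.uniform`, `.realize()`); exactly when Nabla compiles is an implementation detail:

```python
import nabla

a = nabla.uniform((4, 4))
b = a @ a        # builds graph nodes; nothing executes yet
print(b.shape)   # shapes are computed eagerly: (4, 4)

b.realize()      # the graph is compiled and executed here
print(b)         # printing also triggers an implicit .realize()
```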
## Development Setup
### Prerequisites
- Python 3.12+
- Modular MAX SDK (via `requirements.txt`)
### Installation
```bash
git clone https://github.com/nabla-ml/nabla.git
cd nabla
python -m venv venv
source venv/bin/activate
pip install -r requirements-dev.txt
pip install -e ".[dev]"
```
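A quick smoke test after installing (the `__version__` attribute is an assumption; a clean import is the real check):

```python
import nabla  # should import without errors
print(getattr(nabla, "__version__", "nabla imported OK"))
```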
### PyPI Install

Stable (default index):

```bash
pip install nabla-ml
```

Latest Modular nightly (recommended for newest MAX ops):

```bash
pip install --pre --extra-index-url https://whl.modular.com/nightly/simple/ nabla-ml
```
GPU Support:

- Linux (AMD/NVIDIA): Supported natively via Modular MAX.
- macOS (Apple Silicon): Requires the Xcode Metal toolchain (`xcode-select --install`).
### Stable Release (v25.7)

For the simpler, single-device version:

```bash
pip install nabla-ml
```

(See the v25.7 branch.)
## Contributing

- Bugs/Docs: Submit a PR directly.
- Features: Open an Issue first.
- New Ops: See `nabla/ops/README.md`.
## License
Apache-2.0 — see LICENSE
## Download files
### File details: nabla_ml-26.2171409.tar.gz

File metadata:

- Download URL: nabla_ml-26.2171409.tar.gz
- Upload date:
- Size: 164.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.6

File hashes:

| Algorithm | Hash digest |
|---|---|
| SHA256 | `e926638d6ff49292296866e71f5845316c692b72b04745a8215c3162494857d8` |
| MD5 | `2e9c793f3e339a26ae73de28612ffe88` |
| BLAKE2b-256 | `9d3e7a4a094533bf883a666b7aa539e3a32c1006231ad3193bf4c633ca5a0a9f` |
### File details: nabla_ml-26.2171409-py3-none-any.whl

File metadata:

- Download URL: nabla_ml-26.2171409-py3-none-any.whl
- Upload date:
- Size: 203.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.6

File hashes:

| Algorithm | Hash digest |
|---|---|
| SHA256 | `0fde40d5059f14d0e27646e6c012427654ec60e5297fbe94dfa56274a8ccb80e` |
| MD5 | `98eab8323522f4dbfa3ca135e6b8a92f` |
| BLAKE2b-256 | `b7e8601c6dc2f1dff5da2578652abfef8ff346b9e49be5c09204013380cb4448` |