Dynamic neural networks and function transformations in Python + Mojo
Nabla is a Machine Learning library for the emerging Mojo/Python ecosystem, featuring:
- Gradient computation the PyTorch way (imperatively via .backward())
- Purely-functional, JAX-like composable function transformations: grad, vmap, jit, etc.
- Custom differentiable CPU/GPU kernels
For tutorials and API reference, visit: nablaml.com
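To make "composable function transformations" concrete, here is a minimal pure-Python sketch of the idea. The helpers `toy_grad` and `toy_vmap` are hypothetical stand-ins written for this illustration; they are not Nabla's implementation and use finite differences and a plain loop rather than automatic differentiation or real vectorization.

```python
from typing import Callable, List


def toy_grad(f: Callable[[float], float], eps: float = 1e-6) -> Callable[[float], float]:
    """Return a function that approximates df/dx by central differences."""
    def df(x: float) -> float:
        return (f(x + eps) - f(x - eps)) / (2 * eps)
    return df


def toy_vmap(f: Callable[[float], float]) -> Callable[[List[float]], List[float]]:
    """Return a function that applies f to every element of a batch."""
    def batched(xs: List[float]) -> List[float]:
        return [f(x) for x in xs]
    return batched


def cube(x: float) -> float:
    return x ** 3


# Transformations take a function and return a function, so they compose:
# differentiate first, then batch the derivative.
batched_derivative = toy_vmap(toy_grad(cube))
derivatives = batched_derivative([0.0, 1.0, 2.0])  # approximately [0.0, 3.0, 12.0]
```

The point of the sketch is only the shape of the API: each transformation is a higher-order function, which is what lets grad, vmap, and jit be stacked in any order.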
Installation
pip install nabla-ml
Quick Start
A minimal but fully functional neural network training setup:
import nabla as nb

# Defines MLP forward pass and loss.
def loss_fn(params, x, y):
    for i in range(0, len(params) - 2, 2):
        x = nb.relu(x @ params[i] + params[i + 1])
    predictions = x @ params[-2] + params[-1]
    return nb.mean((predictions - y) ** 2)

# JIT-compiled training step via SGD
@nb.jit(auto_device=True)
def train_step(params, x, y, lr):
    loss, grads = nb.value_and_grad(loss_fn)(params, x, y)
    return loss, [p - g * lr for p, g in zip(params, grads)]

# Setup network (hyper)parameters.
LAYERS = [1, 32, 64, 32, 1]
params = [
    p
    for i in range(len(LAYERS) - 1)
    for p in (
        nb.glorot_uniform((LAYERS[i], LAYERS[i + 1])),
        nb.zeros((1, LAYERS[i + 1])),
    )
]

# Run training loop.
x, y = nb.rand((256, 1)), nb.rand((256, 1))
for i in range(1001):
    loss, params = train_step(params, x, y, 0.01)
    if i % 100 == 0:
        print(i, loss.to_numpy())
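The Quick Start stores all weights and biases in one flat list, which is why `loss_fn` steps through `params` two entries at a time and treats the last two entries as the output layer. The following sketch reproduces only that layout in plain Python, without Nabla. The helper `param_shapes` is a hypothetical name introduced here, and the comment about the bias broadcasting over the batch is an assumption based on the `(1, n)` bias shape.

```python
from typing import List, Tuple

LAYERS = [1, 32, 64, 32, 1]  # same layer sizes as in the Quick Start


def param_shapes(layers: List[int]) -> List[Tuple[int, int]]:
    """Shapes of the flat [W0, b0, W1, b1, ...] parameter list."""
    shapes = []
    for fan_in, fan_out in zip(layers[:-1], layers[1:]):
        shapes.append((fan_in, fan_out))  # weight matrix
        shapes.append((1, fan_out))       # bias row (assumed to broadcast over the batch)
    return shapes


shapes = param_shapes(LAYERS)

# Indices of the hidden-layer weights, as iterated in loss_fn;
# shapes[-2] and shapes[-1] belong to the final linear layer.
hidden_weight_indices = list(range(0, len(shapes) - 2, 2))

# Total number of scalar parameters in the network.
num_params = sum(rows * cols for rows, cols in shapes)

for index, shape in enumerate(shapes):
    print(index, shape)
```

With four layers this gives eight entries: three hidden weight/bias pairs at indices 0 to 5 and the output pair at indices 6 and 7.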
For Developers
- Clone the repository
- Create a virtual environment (recommended)
- Install dependencies
git clone https://github.com/nabla-ml/nabla.git
cd nabla
python3 -m venv venv
source venv/bin/activate
pip install -r requirements-dev.txt
pip install -e ".[dev]"
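After the editable install, a quick way to confirm that the environment picked it up is to ask Python's packaging metadata for the installed version. This is a small sketch using only the standard library; it assumes the distribution is registered under the name `nabla-ml` (the name used with `pip install` above), and `installed_version` is a hypothetical helper.

```python
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


# Assumption: the development install registers itself as "nabla-ml".
version = installed_version("nabla-ml")
print("nabla-ml:", version if version is not None else "not installed")
```

If this prints "not installed" inside the activated virtual environment, the `pip install -e ".[dev]"` step most likely did not complete.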
Repository Structure
nabla/
├── nabla/ # Core Python library
│ ├── core/ # Tensor class and MAX compiler integration
│ ├── nn/ # Neural network modules and models
│ ├── ops/ # Mathematical operations (binary, unary, linalg, etc.)
│ ├── transforms/ # Function transformations (vmap, grad, jit, etc.)
│ └── utils/ # Utilities (formatting, types, MAX-interop, etc.)
├── tests/ # Comprehensive test suite
├── tutorials/ # Notebooks on Nabla usage for ML tasks
└── examples/ # Example scripts for common use cases
Contributing
Contributions welcome! Discuss significant changes in Issues first. Submit PRs for bugs, docs, and smaller features.
License
Nabla is licensed under the Apache-2.0 license.
Download files
File details
Details for the file nabla_ml-25.10271359.tar.gz.
File metadata
- Download URL: nabla_ml-25.10271359.tar.gz
- Upload date:
- Size: 99.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.6
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 18ecde5ee33bcbc26c399fa06c0d139e68f2559bd1a6d6ca66bca335d73356a7 |
| MD5 | 85df579d6dacfd4b11770721b8095f65 |
| BLAKE2b-256 | e3ec642ab32f5dc3d7f724e632edf8d386b1f1768b80556779ff96f20d730432 |
File details
Details for the file nabla_ml-25.10271359-py3-none-any.whl.
File metadata
- Download URL: nabla_ml-25.10271359-py3-none-any.whl
- Upload date:
- Size: 127.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.6
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | f9cc8f3afa38886f1095e9015ea038ac89c1313aa31f5361891e0b2d0e7259b8 |
| MD5 | 0b8189822141e45616210da026d685a2 |
| BLAKE2b-256 | 09fa5f27732ed5e209063e752d66057faf57e6bb86ca0450e667eb04b8a6bbad |
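The digests listed for each file can be checked against a local download. Below is a minimal sketch using Python's `hashlib`. It assumes that the "BLAKE2b-256" row is BLAKE2b with a 32-byte digest. The helpers `file_digests` and `matches` are hypothetical names, and the demo runs on a throwaway file rather than on a real download.

```python
import hashlib
import tempfile
from pathlib import Path
from typing import Dict


def file_digests(path: Path, chunk_size: int = 1 << 20) -> Dict[str, str]:
    """Compute SHA256, MD5 and BLAKE2b-256 hex digests of a file in one pass."""
    hashers = {
        "SHA256": hashlib.sha256(),
        "MD5": hashlib.md5(),
        "BLAKE2b-256": hashlib.blake2b(digest_size=32),  # assumed PyPI variant
    }
    with path.open("rb") as fh:
        # Read in chunks so large archives are not loaded into memory at once.
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            for hasher in hashers.values():
                hasher.update(chunk)
    return {name: hasher.hexdigest() for name, hasher in hashers.items()}


def matches(path: Path, algorithm: str, expected_hex: str) -> bool:
    """True if the file's digest for `algorithm` equals the published hex string."""
    return file_digests(path)[algorithm] == expected_hex.strip().lower()


# Demo on a temporary file; replace `sample` with the path of a downloaded
# archive and the expected value with the digest from the table to verify it.
with tempfile.TemporaryDirectory() as tmp:
    sample = Path(tmp) / "sample.bin"
    sample.write_bytes(b"nabla")
    demo_digests = file_digests(sample)
    demo_ok = matches(sample, "SHA256", hashlib.sha256(b"nabla").hexdigest())
```

Comparing SHA256 alone is usually sufficient; the other two digests are computed here only because they appear in the tables.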