Dynamic neural networks and function transformations in Python + Mojo

Project description

Nabla is a Machine Learning library for the emerging Mojo/Python ecosystem, featuring:

  • Gradient computation the PyTorch way (imperatively via .backward())
  • Purely-functional, JAX-like composable function transformations: grad, vmap, jit, etc.
  • Custom differentiable CPU/GPU kernels
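The transformations in the second bullet are higher-order functions: they take a function and return a new one. As a library-agnostic sketch of that idea (plain Python, not Nabla's actual implementation — `nb.grad` uses automatic differentiation, while this toy version uses central finite differences):

```python
# A "function transformation" in miniature: numerical_grad takes a scalar
# function f and returns a new function that approximates df/dx.
# Nabla's grad computes exact derivatives via autodiff; this is only an
# illustration of the transform-style API.
def numerical_grad(f, eps=1e-6):
    def df(x):
        return (f(x + eps) - f(x - eps)) / (2 * eps)
    return df

def loss(w):
    return (w - 3.0) ** 2

dloss = numerical_grad(loss)
print(dloss(0.0))  # close to -6.0
```

Because the transform returns an ordinary function, such transforms compose: a grad of a grad, a vmap of a grad, and so on, which is the design Nabla borrows from JAX.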

For tutorials and API reference, visit: nablaml.com

Installation

Nabla requires Python 3.12 or newer.

pip install nabla-ml

Quick Start

A minimal but fully functional neural network training setup:

import nabla as nb

# Define MLP forward pass and MSE loss.
def loss_fn(params, x, y):
    for i in range(0, len(params) - 2, 2):
        x = nb.relu(x @ params[i] + params[i + 1])
    predictions = x @ params[-2] + params[-1]
    return nb.mean((predictions - y) ** 2)

# JIT-compiled training step via SGD
@nb.jit(auto_device=True)
def train_step(params, x, y, lr):
    loss, grads = nb.value_and_grad(loss_fn)(params, x, y)
    return loss, [p - g * lr for p, g in zip(params, grads)]

# Set up network (hyper)parameters.
LAYERS = [1, 32, 64, 32, 1]
params = []
for i in range(len(LAYERS) - 1):
    params.append(nb.glorot_uniform((LAYERS[i], LAYERS[i + 1])))  # weights
    params.append(nb.zeros((1, LAYERS[i + 1])))                   # biases

# Run training loop.
x, y = nb.rand((256, 1)), nb.rand((256, 1))
for i in range(1001):
    loss, params = train_step(params, x, y, 0.01)
    if i % 100 == 0: print(i, loss.to_numpy())
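The update inside `train_step` is plain SGD: each parameter moves against its gradient, p ← p − lr·∇p. For intuition, here is the same loop written in NumPy only, for a single linear layer with the gradients of the MSE loss derived by hand (a standalone sketch, not Nabla code; the target function y = 2x + 1 is made up for the example):

```python
import numpy as np

# NumPy-only sketch of the SGD update p <- p - lr * grad for one linear
# layer trained on MSE, with gradients written out manually.
rng = np.random.default_rng(0)
x = rng.random((256, 1))
y = 2.0 * x + 1.0                    # synthetic target: y = 2x + 1

W = np.zeros((1, 1))
b = np.zeros((1,))
lr = 0.1
for step in range(1000):
    pred = x @ W + b                 # forward pass
    err = pred - y
    loss = np.mean(err ** 2)
    gW = 2.0 * x.T @ err / len(x)    # d(loss)/dW
    gb = 2.0 * err.mean(axis=0)      # d(loss)/db
    W -= lr * gW                     # SGD update
    b -= lr * gb

print(W.item(), b.item())            # approaches W ~ 2, b ~ 1
```

`nb.value_and_grad` produces the `gW`/`gb` terms automatically for an arbitrary parameter list, and `@nb.jit` compiles the whole step, but the arithmetic is the same.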

For Developers

  1. Clone the repository

git clone https://github.com/nabla-ml/nabla.git
cd nabla

  2. Create a virtual environment (recommended)

python3 -m venv venv
source venv/bin/activate

  3. Install development dependencies

pip install -r requirements-dev.txt
pip install -e ".[dev]"

Repository Structure

nabla/
├── nabla/                     # Core Python library
│   ├── core/                  # Tensor class and MAX compiler integration
│   ├── nn/                    # Neural network modules and models
│   ├── ops/                   # Mathematical operations (binary, unary, linalg, etc.)
│   ├── transforms/            # Function transformations (vmap, grad, jit, etc.)
│   └── utils/                 # Utilities (formatting, types, MAX-interop, etc.)
├── tests/                     # Comprehensive test suite
├── tutorials/                 # Notebooks on Nabla usage for ML tasks
└── examples/                  # Example scripts for common use cases

Contributing

Contributions welcome! Discuss significant changes in Issues first. Submit PRs for bugs, docs, and smaller features.

License

Nabla is licensed under the Apache-2.0 license.

