Dynamic neural networks and function transformations in Python + Mojo
NABLA
Nabla is a Python library that provides three key features:
- Multidimensional Array computation (like NumPy) with strong GPU acceleration
- Composable Function Transformations: vmap, grad, jit, and other Automatic Differentiation tools
- Deep integration with MAX and (custom) Mojo kernels
For tutorials and API reference, visit: nablaml.com
Installation
Now available on PyPI!
pip install nabla-ml
Quick Start
import nabla as nb
# Example function using Nabla's array operations
def foo(input):
    return nb.sum(input * input, axes=0)
# Vectorize, differentiate, accelerate
foo_grads = nb.jit(nb.vmap(nb.grad(foo)))
gradients = foo_grads(nb.randn((10, 5)))
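To make the result of the Quick Start concrete, here is a minimal sketch in plain NumPy (not Nabla) of what the composed transformation is expected to compute. It assumes that vmap maps over the leading axis of the input, as in JAX; the helper `numerical_grad` is hypothetical and exists only for this illustration. Since the gradient of sum(x * x) is 2 * x, each of the 10 rows should produce a gradient of 5 values.

```python
import numpy as np

def foo(row):
    # Same computation as the Quick Start function, applied to a single row.
    return np.sum(row * row, axis=0)

def numerical_grad(f, x, eps=1e-6):
    # Hypothetical helper: central finite differences, one coordinate at a time.
    grad = np.zeros_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = eps
        grad[i] = (f(x + step) - f(x - step)) / (2 * eps)
    return grad

rng = np.random.default_rng(0)
batch = rng.standard_normal((10, 5))

# Emulate vmap(grad(foo)) by looping over the leading (batch) axis.
expected = np.stack([numerical_grad(foo, row) for row in batch])

# The analytical gradient of sum(x * x) is 2 * x, row by row.
assert expected.shape == (10, 5)
assert np.allclose(expected, 2 * batch, atol=1e-6)
```

The loop is only a reference for the semantics. The point of composing vmap, grad, and jit is to get the same values without a Python-level loop.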
Development Setup and Reproducibility
This guide is for contributors or for reproducing the validation and benchmark results presented in the thesis.
1. Initial Setup
First, clone the repository and set up a virtual environment with all necessary dependencies.
# Clone the repository
git clone https://github.com/nabla-ml/nabla.git
cd nabla
# Create and activate a virtual environment (recommended)
python3 -m venv venv
source venv/bin/activate
# Install all core and development dependencies
pip install -r requirements-dev.txt
2. Run the Correctness Validation Suite
This runs the full test suite to verify Nabla's correctness against JAX.
# Navigate to the unit test directory from the project root
cd nabla/tests/unit
# Execute the unified test script
python unified.py all -all-configs
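The idea behind validating against a reference implementation can be sketched as follows. This is not the code in unified.py; it is a hedged illustration in plain NumPy, and `check_against_reference`, `candidate_grad`, and `reference_grad` are hypothetical names. In a real check, the candidate would be a Nabla function and the reference the equivalent JAX function.

```python
import numpy as np

def check_against_reference(candidate_fn, reference_fn, shapes,
                            rtol=1e-5, atol=1e-6, seed=0):
    """Run both functions on random inputs and compare outputs elementwise."""
    rng = np.random.default_rng(seed)
    for shape in shapes:
        x = rng.standard_normal(shape)
        got = np.asarray(candidate_fn(x))
        want = np.asarray(reference_fn(x))
        # Raises AssertionError with a readable report on the first mismatch.
        np.testing.assert_allclose(
            got, want, rtol=rtol, atol=atol,
            err_msg=f"mismatch for input shape {shape}",
        )
    return len(shapes)

# Stand-ins for illustration: two ways to write the gradient of sum(x * x).
def candidate_grad(x):
    return x + x

def reference_grad(x):
    return 2.0 * x

checked = check_against_reference(
    candidate_grad, reference_grad, [(3,), (4, 5), (2, 3, 4)]
)
print(f"{checked} input shapes matched the reference")
```

Comparing with relative and absolute tolerances, rather than exact equality, allows for the small floating-point differences that different backends usually produce.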
3. Run the Performance Benchmarks
This script reproduces the performance benchmarks for Nabla, JAX, and PyTorch.
# Navigate to the benchmark directory from the project root
cd nabla/tests/benchmarks
# Run the benchmark script
python benchmark.py
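For readers who want to time their own functions, the following is a minimal sketch of a timing harness. It is an assumption-laden illustration, not the contents of benchmark.py: the `benchmark` helper and the NumPy `workload` are hypothetical placeholders. Warm-up runs are excluded from the measurement so that one-time costs, such as JIT compilation, do not distort the numbers.

```python
import statistics
import time

import numpy as np

def benchmark(fn, *args, warmup=3, repeats=10):
    """Return the best and median wall-clock time of fn(*args) in seconds."""
    for _ in range(warmup):
        fn(*args)  # warm-up runs absorb one-time costs such as compilation
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        timings.append(time.perf_counter() - start)
    return min(timings), statistics.median(timings)

def workload(x):
    # Placeholder workload; swap in the function under test.
    return np.sum(x * x, axis=0)

x = np.random.default_rng(0).standard_normal((256, 256))
best, median = benchmark(workload, x)
print(f"best: {best * 1e6:.1f} us, median: {median * 1e6:.1f} us")
```

When timing frameworks that dispatch work asynchronously, the timed call also has to wait for the result (in JAX, for example, by calling `block_until_ready()` on the output); otherwise only the dispatch time is measured.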
Repository Structure
nabla/
├── nabla/ # Core Python library
│ ├── core/ # Array class and MAX compiler integration
│ ├── nn/ # Neural network modules and models
│ ├── ops/ # Mathematical operations (binary, unary, linalg, etc.)
│ ├── transforms/ # Function transformations (vmap, grad, jit, etc.)
│ └── utils/ # Utilities (formatting, types, MAX-interop, etc.)
├── tests/ # Comprehensive test suite
├── tutorials/ # Notebooks on Nabla usage for ML tasks
├── examples/ # Example scripts for common use cases
└── experimental/ # Core (pure) Mojo library (WIP!)
Contributing
Contributions welcome! Discuss significant changes in Issues first. Submit PRs for bugs, docs, and smaller features.
License
Nabla is licensed under the Apache-2.0 license.
Thank you for checking out Nabla!