Enabling Unit-aware Computations for AI-driven Scientific Computing.
Motivation
SAIUnit (/saɪ ˈjuːnɪt/) is designed to provide physical units and unit-aware mathematical systems tailored for Scientific AI within JAX. In this context, Scientific AI refers to the use of AI models or tools to advance scientific computations. SAIUnit evolved from BrainUnit, our unit framework originally developed for brain dynamics modeling, and extends its capabilities to a broader range of scientific computing applications. SAIUnit is committed to providing a rigorous, automatic system for physical unit conversion and analysis in general AI-driven scientific computing.
Features
Compared to existing unit libraries, such as Quantities and Pint, SAIUnit introduces a rigorous physical unit system specifically designed to support AI computations (e.g., automatic differentiation, just-in-time compilation, and parallelization). Its unique advantages include:
- Integration of over 2,000 commonly used physical units and constants
- Implementation of more than 500 unit-aware mathematical functions
- Deep integration with JAX, providing comprehensive support for modern AI framework features including automatic differentiation (autograd), just-in-time compilation (JIT), vectorization, and parallel computation
- Unit conversion and analysis are performed at compilation time, resulting in zero runtime overhead
- A strict physical unit type-checking and dimensional inference system that detects unit inconsistencies during compilation
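The dimensional inference in the last bullet can be illustrated with a deliberately tiny model: each quantity carries an array mantissa plus a vector of exponents over the seven SI base units, multiplication adds exponents, powers scale them, and addition requires them to match. This is a hypothetical sketch in plain Python and NumPy (`ToyQuantity`, `METER` and `SECOND` are illustrative names), not how SAIUnit is actually implemented.

```python
from dataclasses import dataclass

import numpy as np

# Exponents over the SI base units (m, kg, s, A, K, mol, cd); toy model only.
Dim = tuple[int, ...]


@dataclass(frozen=True)
class ToyQuantity:
    """Hypothetical toy: an array mantissa paired with a dimension vector."""
    mantissa: np.ndarray
    dim: Dim

    def __add__(self, other):
        # Addition is only defined for quantities of identical dimension.
        if self.dim != other.dim:
            raise TypeError(f"dimension mismatch: {self.dim} vs {other.dim}")
        return ToyQuantity(self.mantissa + other.mantissa, self.dim)

    def __mul__(self, other):
        # Multiplication infers the result dimension by adding exponents.
        dim = tuple(a + b for a, b in zip(self.dim, other.dim))
        return ToyQuantity(self.mantissa * other.mantissa, dim)

    def __pow__(self, n):
        # An integer power scales every exponent by n.
        return ToyQuantity(self.mantissa ** n, tuple(n * a for a in self.dim))


METER = (1, 0, 0, 0, 0, 0, 0)
SECOND = (0, 0, 1, 0, 0, 0, 0)

x = ToyQuantity(np.array([1.0, 2.0, 3.0]), METER)
t = ToyQuantity(np.array(2.0), SECOND)

volume = x ** 3  # dim (3, 0, 0, 0, 0, 0, 0), i.e. m^3
area = x * x     # dim (2, 0, 0, 0, 0, 0, 0), i.e. m^2

try:
    x + t  # metre + second has no physical meaning
except TypeError as err:
    mismatch = str(err)
print(mismatch)  # dimension mismatch: (1, 0, 0, 0, 0, 0, 0) vs (0, 0, 1, 0, 0, 0, 0)
```

Because the dimension lives in ordinary Python metadata rather than in the array, a class like this registered as a JAX pytree with `dim` as static auxiliary data would expose only the mantissa to a tracer such as `jax.jit`. That separation is one plausible way unit checks can run at trace time without runtime cost; SAIUnit's real mechanism may differ.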
```mermaid
graph TD
    A[SAIUnit] --> B[Physical Units]
    A --> C[Mathematical Functions]
    A --> D[JAX Integration]
    B --> B1[2000+ Units]
    B --> B2[Physical Constants]
    C --> C1[500+ Unit-aware Functions]
    D --> D1[Autograd]
    D --> D2[JIT Compilation]
    D --> D3[Vectorization]
    D --> D4[Parallelization]
```
We hope these features establish SAIUnit as a reliable physical unit handling solution for general AI-driven scientific computing scenarios.
A quick example:
```python
import jax
import saiunit as u

# Define a physical quantity
x = 3.0 * u.meter
x
# [out] 3. * meter

# autograd
f = lambda x: x ** 3
u.autograd.grad(f)(x)
# [out] 27. * meter2

# JIT
jax.jit(f)(x)
# [out] 27. * klitre

# vmap
jax.vmap(f)(u.math.arange(0. * u.mV, 10. * u.mV, 1. * u.mV))
# [out] ArrayImpl([ 0., 1., 8., 27., 64., 125., 216., 343., 512., 729.]) * mvolt3
```
Multiple-backend support
saiunit is backend-agnostic: a Quantity pairs a unit with an array
mantissa, and that mantissa can live on any of the supported array libraries.
Every unit-aware operation dispatches to the matching backend, so you can stay
in one library end-to-end or convert with a single method call.
| Backend | Mantissa | Install | When to use |
|---|---|---|---|
| numpy | numpy.ndarray | core (always installed) | eager CPU, scipy/pandas/sklearn interop |
| jax | jax.Array | saiunit[jax] (or [cpu]/[cuda12]/[cuda13]/[tpu]) | autograd, JIT, vmap, accelerators |
| cupy | cupy.ndarray | saiunit[cupy] | NVIDIA GPU arrays |
| torch | torch.Tensor | saiunit[torch] | PyTorch models, torch autograd |
| dask | dask.array.Array | saiunit[dask] | out-of-core / parallel, lazy compute |
| ndonnx | ndonnx.Array | saiunit[ndonnx] | symbolic graph for ONNX export |
Select or override the backend explicitly with u.using_backend(...) or
u.set_default_backend(...); convert with q.to_jax() / q.to_numpy() /
q.to_cupy() / q.to_torch() / q.to_dask() / q.to_ndonnx(). Requesting an
uninstalled backend raises saiunit.BackendError with the install command
rather than a bare ImportError. See the Backends documentation for details.
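Raising an error that carries the install command, instead of a bare `ImportError`, is a common pattern for optional dependencies. A minimal sketch of that pattern follows, assuming a name-to-extra mapping that mirrors the backend table; `BackendError`, `BACKENDS` and `load_backend` here are hypothetical stand-ins, not saiunit's internals.

```python
import importlib


class BackendError(ImportError):
    """Hypothetical stand-in for saiunit.BackendError; the real class may differ."""


# Hypothetical registry: backend name -> (module to import, pip extra or None).
BACKENDS = {
    "numpy": ("numpy", None),
    "jax": ("jax", "jax"),
    "cupy": ("cupy", "cupy"),
    "torch": ("torch", "torch"),
    "dask": ("dask.array", "dask"),
    "ndonnx": ("ndonnx", "ndonnx"),
}


def load_backend(name):
    """Return the array module behind `name`, or raise BackendError with a hint."""
    if name not in BACKENDS:
        raise ValueError(f"unknown backend {name!r}; choose one of {sorted(BACKENDS)}")
    module_name, extra = BACKENDS[name]
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        # NumPy ships with the core package; every other backend has its own extra.
        package = f"saiunit[{extra}]" if extra else "saiunit"
        raise BackendError(
            f"backend {name!r} is not installed; run: pip install -U \"{package}\""
        ) from exc


if __name__ == "__main__":
    xp = load_backend("numpy")
    print(xp.__name__)
```

Subclassing `ImportError` in this sketch keeps existing `except ImportError` handlers working while the message tells the user exactly which extra to install.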
Installation
saiunit is well tested on Python 3.10 and later and can be installed on
Windows, Linux, and macOS. The core package depends only on NumPy. JAX is
optional; install it to enable the saiunit.autograd, saiunit.lax,
and saiunit.sparse submodules, the custom exprel primitive, and the
"jax" backend.
Install the NumPy-only core:
```shell
pip install saiunit --upgrade
```
Or pull in JAX with the accelerator build that matches your hardware (the quotes keep shells such as zsh from expanding the brackets):
```shell
pip install -U "saiunit[jax]"     # plain JAX
pip install -U "saiunit[cpu]"     # pinned JAX CPU wheels
pip install -U "saiunit[cuda12]"  # JAX on CUDA 12
pip install -U "saiunit[cuda13]"  # JAX on CUDA 13
pip install -U "saiunit[tpu]"     # JAX on TPU
```
Opt into additional array backends with the matching extra:
```shell
pip install -U "saiunit[cupy]"    # CuPy (NVIDIA GPU)
pip install -U "saiunit[torch]"   # PyTorch
pip install -U "saiunit[dask]"    # Dask
pip install -U "saiunit[ndonnx]"  # ndonnx
pip install -U "saiunit[all]"     # jax + cupy + torch + dask + ndonnx
```
Without JAX, the NumPy backend is auto-selected and any access to a
JAX-only submodule (saiunit.autograd, saiunit.lax, saiunit.sparse)
raises saiunit.BackendError with an install hint. The optional extras are
independent and can be combined freely.
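To see which of these packages are already present before relying on automatic backend selection, one option is a stdlib-only check such as the one below. It is a sketch, not part of saiunit; `installed_backends` is a hypothetical helper name, and it only reports whether each package can be found on the current interpreter's path, not whether its version is one saiunit supports.

```python
import importlib.util

# Top-level modules behind each backend; numpy is the always-installed core.
MODULES = ("numpy", "jax", "cupy", "torch", "dask", "ndonnx")


def installed_backends(modules=MODULES):
    """Map each module name to True if it can be found, without importing it."""
    return {name: importlib.util.find_spec(name) is not None for name in modules}


if __name__ == "__main__":
    for name, ok in installed_backends().items():
        print(f"{name:<7} {'installed' if ok else 'missing'}")
```

Using `importlib.util.find_spec` on top-level names locates each package without executing it, so the check stays cheap even when a heavy library such as CUDA-enabled PyTorch is installed.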
To install the latest version from source:
```shell
git clone https://github.com/chaobrain/saiunit.git
cd saiunit
pip install -e .
```
Alternatively, you can install BrainX, which bundles saiunit with other compatible packages for a comprehensive brain modeling ecosystem:
```shell
pip install BrainX -U
```
Documentation
The official documentation is hosted on Read the Docs: https://saiunit.readthedocs.io
Citation
```bibtex
@article{wang2025integrating,
  title={Integrating physical units into high-performance AI-driven scientific computing},
  author={Wang, Chaoming and He, Sichao and Luo, Shouwei and Huan, Yuxiang and Wu, Si},
  journal={Nature Communications},
  volume={16},
  number={1},
  pages={3609},
  year={2025},
  publisher={Nature Publishing Group UK London},
  url={https://doi.org/10.1038/s41467-025-58626-4}
}
```
Ecosystem
saiunit has been integrated into the following projects:
- brainstate: A State-based Transformation System for Program Compilation and Augmentation
- braintaichi: Leveraging Taichi Lang to customize brain dynamics operators
- braintools: The Common Toolbox for Brain Dynamics Programming
- dendritex: Dendritic Modeling in JAX
- pinnx: Physics-Informed Neural Networks for Scientific Machine Learning in JAX
Other unofficial projects include:
- diffrax: Numerical differential equation solvers in JAX
- jax-md: Differentiable Molecular Dynamics in JAX
- Catalax: JAX-based framework to model biological systems
- ...
Acknowledgement
The initial version of the project benefited greatly from the following projects:
- astropy.units: physical units in astropy
- brian2.units: physical units in brian2
Download files
File details
Details for the file saiunit-0.3.1.tar.gz.
File metadata
- Download URL: saiunit-0.3.1.tar.gz
- Upload date:
- Size: 450.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | fd48794ffafc3467be9ed4861d4ade1d3b71e08a97fbd2b7dd86a745af6bbee8 |
| MD5 | e3a545e7f785d77d24dd2b671ff42a7e |
| BLAKE2b-256 | 50691ada91ffb68db3bd64e8d241c64e79dd5815ac74a14f731aeb7410f4dac2 |
File details
Details for the file saiunit-0.3.1-py3-none-any.whl.
File metadata
- Download URL: saiunit-0.3.1-py3-none-any.whl
- Upload date:
- Size: 527.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | fe6318d92c4d4290045383c072fbe87fb2016ca80da886fc39f782518a3d387f |
| MD5 | 278f61080acb581179d89a970e51f6d1 |
| BLAKE2b-256 | 1c0fdddac33c7122ecac98f30cef81f25042e62aaa4b26b76371f67481cb6c2f |