Quantities in JAX
unxt
Unitful Quantities in JAX
unxt provides unitful quantities and calculations in JAX, built on Equinox and Quax.
unxt supports JAX's compelling features:
- JIT compilation (jit)
- vectorization (vmap, etc.)
- auto-differentiation (grad, jacobian, hessian)
- GPU/TPU/multi-host acceleration
And best of all, unxt doesn't force you to use special unit-compatible
re-exports of JAX libraries. You can use unxt with existing JAX code: with
quax's quaxify decorator, JAX functions work with unxt.Quantity.
Installation
pip install unxt
using uv
uv add unxt
from source, using pip
pip install git+https://github.com/GalacticDynamics/unxt.git
building from source
cd /path/to/parent
git clone https://github.com/GalacticDynamics/unxt.git
cd unxt
pip install -e . # editable mode
Documentation
For full documentation, including installation instructions, tutorials, and API reference, please see the unxt docs. This README provides a brief overview and some quick examples.
Dimensions
Dimensions represent the physical type of a quantity, such as length, time, or mass.
import unxt as u
Create dimensions from strings:
length_dim = u.dimension("length")
print(length_dim)
# PhysicalType('length')
Dimensions support mathematical expressions:
speed_dim = u.dimension("length / time")
print(speed_dim)
# PhysicalType({'speed', 'velocity'})
Multi-word dimension names require parentheses in expressions:
activity_dim = u.dimension("(amount of substance) / (time)")
print(activity_dim)
# PhysicalType('catalytic activity')
Units
Units specify the scale and dimension of measurements.
meter = u.unit("m")
print(meter)
# Unit("m")
Units can be combined:
velocity_unit = u.unit("km/h")  # parsed from a string expression
print(velocity_unit)
# Unit("km / h")
velocity_unit2 = u.unit("km") / u.unit("h") # via arithmetic
print(velocity_unit2)
# Unit("km / h")
Get the dimension of a unit:
print(u.dimension_of(meter))
# PhysicalType('length')
Unit Systems
Unit systems define consistent sets of base units for specific domains. unxt
provides built-in unit systems and tools for creating custom ones.
Built-in Unit Systems
# SI (International System of Units)
si = u.unitsystem("si")
print(si)
# unitsystem(m, kg, s, mol, A, K, cd, rad)
# CGS (centimeter-gram-second)
cgs = u.unitsystem("cgs")
print(cgs)
# unitsystem(cm, g, s, dyn, erg, Ba, P, St, rad)
# Galactic (astrophysics)
galactic = u.unitsystem("galactic")
print(galactic)
# unitsystem(kpc, Myr, solMass, rad)
Composing Units from a Unit System
Once you have a unit system, you can get units for any physical dimension by indexing the system:
usys = u.unitsystem("si")
# Get specific units
print(usys["length"])
# Unit("m")
Custom Unit Systems
Create custom unit systems by specifying base units:
import unxt as u
# Define a custom unit system
custom_usys = u.unitsystem("km", "h", "tonne", "degree")
print(custom_usys)
# unitsystem(km, h, t, deg)
# Access derived units
print(custom_usys["velocity"])
# Unit("km / h")
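To make the indexing idea concrete, here is a toy model, not unxt's implementation: a dimension is a set of exponents over base dimensions, and a unit system maps each base dimension to a base unit, so the unit for any derived dimension follows from those exponents. The names BASE_UNITS, DIMENSIONS, and derived_unit are hypothetical and exist only for this sketch.

```python
# Toy model (hypothetical, not unxt's internals): a dimension is a dict of
# exponents over base dimensions, and a unit system maps each base
# dimension to a base unit.
BASE_UNITS = {"length": "km", "time": "h", "mass": "t", "angle": "deg"}

# Exponents of each base dimension for a few derived dimensions.
DIMENSIONS = {
    "length": {"length": 1},
    "velocity": {"length": 1, "time": -1},
    "acceleration": {"length": 1, "time": -2},
    "force": {"mass": 1, "length": 1, "time": -2},
}


def derived_unit(dimension, base_units):
    """Compose a unit string such as 'km / h' from base-unit exponents."""
    numerator, denominator = [], []
    for base, power in DIMENSIONS[dimension].items():
        unit = base_units[base]
        # Write exponents other than +/-1 as a suffix, e.g. "h2".
        term = unit if abs(power) == 1 else f"{unit}{abs(power)}"
        (numerator if power > 0 else denominator).append(term)
    text = " ".join(numerator) or "1"
    if denominator:
        text += " / " + " ".join(denominator)
    return text


print(derived_unit("velocity", BASE_UNITS))  # km / h
print(derived_unit("acceleration", BASE_UNITS))  # km / h2
```

The velocity case reproduces the "km / h" that the custom unit system reports above; the real library handles unit parsing, scaling, and formatting far more completely than this string-building sketch.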
Dynamical Unit Systems
For domains like gravitational dynamics, use dynamical unit systems where $G = 1$:
from unxt.unitsystems import DynamicalSimUSysFlag
# Create a dynamical system where G=1
# Only specify 2 of (length, time, mass)
usys = u.unitsystem(DynamicalSimUSysFlag, "kpc", "Myr")
print(usys)
# unitsystem(kpc, Myr, ...)
# The third dimension (mass) is computed to make G=1
print(usys["mass"])
# a mass unit of about 2.22e11 solMass, chosen so that G = 1
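As a sanity check on the computed mass unit: G = 1 in a system with length unit L, time unit T, and mass unit M means G, expressed in those units, has numerical value 1, so M = L^3 / (G T^2). The stdlib-only sketch below evaluates that with hard-coded constants (CODATA 2018 G, the kpc in metres, a Myr of Julian years, and the IAU 2015 nominal solar mass). These constants are our assumption about what the unit system uses, so treat the result as approximate.

```python
# Worked check of the "G = 1" condition: the numerical value of G in units
# (L, T, M) is G_SI * M * T**2 / L**3; setting it to 1 gives M = L**3 / (G * T**2).
G_SI = 6.67430e-11  # m^3 kg^-1 s^-2 (CODATA 2018)
KPC_M = 3.0856775814913673e19  # 1 kpc in metres
MYR_S = 1e6 * 365.25 * 86400.0  # 1 Myr in seconds (Julian years, assumed)
MSUN_KG = 1.988409870698051e30  # IAU 2015 nominal solar mass in kg

mass_unit_kg = KPC_M**3 / (G_SI * MYR_S**2)
mass_unit_msun = mass_unit_kg / MSUN_KG
print(f"{mass_unit_msun:.3e} solMass")  # about 2.223e+11 solMass
```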
Quantities
Quantities combine values with units, providing type-safe unitful arithmetic.
Basic Quantities
import jax.numpy as jnp
x = u.Quantity(jnp.arange(1, 5, dtype=float), "km")
print(x)
# Quantity['length']([1., 2., 3., 4.], unit='km')
The constituent value and unit are accessible as attributes:
print(repr(x.value))
# Array([1., 2., 3., 4.], dtype=float64)  (dtype is float32 unless jax_enable_x64 is enabled)
print(repr(x.unit))
# Unit("km")
Quantity objects obey the rules of unitful arithmetic.
# Addition / Subtraction
print(x + x)
# Quantity['length']([2., 4., 6., 8.], unit='km')
# Multiplication / Division
print(2 * x)
# Quantity['length']([2., 4., 6., 8.], unit='km')
y = u.Quantity(jnp.arange(4, 8, dtype=float), "yr")
print(x / y)
# Quantity['speed']([0.25, 0.4 , 0.5 , 0.57142857], unit='km / yr')
# Exponentiation
print(x**2)
# Quantity['area']([ 1., 4., 9., 16.], unit='km2')
# Unit checking on operations
try:
    x + y  # adding a length to a time raises an error
except Exception as e:
    print(e)
# 'yr' (time) and 'km' (length) are not convertible
Quantities can be converted to different units:
print(u.uconvert("m", x)) # via function
# Quantity['length']([1000., 2000., 3000., 4000.], unit='m')
print(x.uconvert("m")) # via method
# Quantity['length']([1000., 2000., 3000., 4000.], unit='m')
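Converting between units amounts to multiplying the value by the ratio of the two units' scales. The plain-NumPy sketch below spells out that arithmetic for x and for the speed x / y, with the conversion factors hard-coded by hand; treating yr as the Julian year (365.25 days) is an assumption about the unit definition. With unxt itself you would write something like (x / y).uconvert("m / s") instead.

```python
import numpy as np

# Hand-written unit scales (assumption: "yr" is the Julian year).
KM_IN_M = 1_000.0
YR_IN_S = 365.25 * 86_400.0

# Same values as the x quantity above, in km.
x_km = np.arange(1, 5, dtype=float)
print(x_km * KM_IN_M)  # km -> m: multiply by 1000

# Same values as x / y above, in km / yr; convert to m / s.
speed_km_per_yr = x_km / np.arange(4, 8, dtype=float)
speed_m_per_s = speed_km_per_yr * KM_IN_M / YR_IN_S
print(speed_m_per_s)
```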
Since Quantity is parametric in its dimension, it can check dimensions at runtime:
LengthQuantity = u.Quantity["length"]
print(LengthQuantity(2, "km"))
# Quantity['length'](2, unit='km')
try:
    LengthQuantity(2, "s")  # seconds are not a length
except ValueError as e:
    print(e)
# Physical type mismatch.
BareQuantity
For performance-critical code where you don't need dimension checking, use
BareQuantity:
import unxt as u
import jax.numpy as jnp
# BareQuantity skips dimension checks for better performance
bq = u.quantity.BareQuantity(jnp.array([1.0, 2.0, 3.0]), "m")
print(bq)
# BareQuantity([1., 2., 3.], unit='m')
# Works just like Quantity but without dimension validation
print(bq * 2)
# BareQuantity([2., 4., 6.], unit='m')
Angle
Angle is a specialized quantity with wrapping support for angular values:
import unxt as u
import jax.numpy as jnp
# Create angles; no wrapping is applied on construction
theta = u.Angle(jnp.array([0.0, 90.0, 180.0, 270.0, 360.0]), "deg")
print(theta)
# Angle([0., 90., 180., 270., 360.], unit='deg')
# Optional wrapping to a specified range
angle = u.Angle(jnp.array([370.0, -10.0]), "deg")
wrapped = angle.wrap_to(u.Q(0, "deg"), u.Q(360, "deg"))
print(wrapped)
# Angle([10., 350.], unit='deg')
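The wrapped values above match the usual modular formula lower + (value - lower) mod (upper - lower). The NumPy sketch below reproduces [10., 350.] for the same inputs. The function wrap_to here is a hypothetical stand-in, not Angle.wrap_to itself, and treating the upper endpoint as open (so 360 maps to 0) is an assumption this README does not state.

```python
import numpy as np


def wrap_to(values, lower, upper):
    """Hypothetical helper: map values into the half-open range [lower, upper)."""
    # np.mod takes the sign of the divisor, so negative offsets wrap upward.
    return lower + np.mod(np.asarray(values) - lower, upper - lower)


angles_deg = np.array([370.0, -10.0, 360.0])
print(wrap_to(angles_deg, 0.0, 360.0))  # -> 10, 350, 0 degrees
```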
StaticQuantity
For static configuration values (e.g., JAX static arguments), use
StaticQuantity, which stores NumPy values and rejects JAX arrays:
import numpy as np
from functools import partial
import jax
import jax.numpy as jnp
import unxt as u
cfg = u.StaticQuantity(np.array([1.0, 2.0]), "m")
@partial(jax.jit, static_argnames=("q",))
def add(x, q):
    # q is static: its NumPy value is treated as a compile-time constant
    return x + jnp.asarray(q.value)
print(add(1.0, cfg))
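jax.jit requires static arguments to be hashable (implementing both __hash__ and __eq__), because it uses them to look up cached compiled code. A plain NumPy array is not hashable, which is one reason a dedicated static wrapper is useful. The HashableArray class below is a hypothetical illustration of that requirement, not how StaticQuantity is actually implemented.

```python
import numpy as np

arr = np.array([1.0, 2.0])
try:
    hash(arr)
except TypeError as err:
    print(err)  # NumPy arrays are unhashable


class HashableArray:
    """Hypothetical wrapper: hash and compare by dtype, shape and raw bytes."""

    def __init__(self, value):
        self.value = np.asarray(value)

    def _key(self):
        return (self.value.dtype.str, self.value.shape, self.value.tobytes())

    def __hash__(self):
        return hash(self._key())

    def __eq__(self, other):
        return isinstance(other, HashableArray) and self._key() == other._key()


a = HashableArray([1.0, 2.0])
b = HashableArray(np.array([1.0, 2.0]))
print(a == b, hash(a) == hash(b))  # True True
```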
StaticValue
If you want a Quantity that keeps a static value but still participates in
regular arithmetic, wrap the value with StaticValue. Arithmetic behaves like
the wrapped array, and StaticValue + StaticValue returns a StaticValue.
Comparison operators (==, !=, <, <=, >, >=) return NumPy boolean
arrays for element-wise comparison:
import numpy as np
import jax.numpy as jnp
import unxt as u
sv = u.quantity.StaticValue(np.array([1.0, 2.0]))
q_static = u.Q(sv, "m")
q = u.Q(jnp.array([3.0, 4.0]), "m")
print(q_static + q)
# Comparisons return NumPy boolean arrays (element-wise)
sv2 = u.quantity.StaticValue(np.array([2.0, 1.0]))
print(sv < sv2) # array([ True, False])
print(sv == np.array([1.0, 2.0])) # array([ True, True])
JAX Integration
unxt is built on quax, which enables custom array-like objects in
JAX. For convenience we use the quaxed library, which wraps jax functions
with quax.quaxify to avoid boilerplate code.
[!NOTE]
Using quaxed is optional. You can use quaxify directly, and even apply it to a top-level function instead of to individual functions.
from quaxed import grad, vmap
import quaxed.numpy as jnp
# Using the x quantity from earlier examples
print(jnp.square(x))
# Quantity['area']([ 1., 4., 9., 16.], unit='km2')
print(jnp.power(x, 3))
# Quantity['volume']([ 1., 8., 27., 64.], unit='km3')
print(vmap(grad(lambda x: x**3))(x))
# Quantity['area']([ 3., 12., 27., 48.], unit='km2')
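The last result can be checked by hand: the derivative of x^3 is 3x^2, and a derivative carries the units of the output divided by the units of the input.

```latex
\frac{d}{dx}\,x^{3} = 3x^{2},
\qquad
\left[\frac{d(x^{3})}{dx}\right] = \frac{\mathrm{km}^{3}}{\mathrm{km}} = \mathrm{km}^{2},
\qquad
x = [1, 2, 3, 4]\,\mathrm{km} \;\Rightarrow\; 3x^{2} = [3, 12, 27, 48]\,\mathrm{km}^{2}.
```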
See the documentation for more examples and for details on JIT compilation and automatic differentiation.
Citation
If you found this library to be useful and want to support the development and maintenance of lower-level code libraries for the scientific community, please consider citing this work.
Contributing and Development
We welcome contributions! Contributions are how open source projects improve and grow.
To contribute to unxt, please fork the repository, make a development branch, develop on that branch, then open a pull request from the branch in your fork to main.
To report bugs, request features, or suggest other ideas, please open an issue.
For more information, see CONTRIBUTING.md.
Project details
Download files
Source Distribution: unxt-1.10.3.tar.gz
Built Distribution: unxt-1.10.3-py3-none-any.whl
File details
Details for the file unxt-1.10.3.tar.gz.
File metadata
- Download URL: unxt-1.10.3.tar.gz
- Upload date:
- Size: 1.0 MB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 522318b9d19171f52faeaceb5aff9ceaa73c39ba7f7dfea434b42c1d126bc164 |
| MD5 | f9c155a4f3031d31dd15e44e5240f3e6 |
| BLAKE2b-256 | 252f78436064bc5112882dcf871f5a51f16a6a2b2c27c7e36e8ca146e51cab34 |
Provenance
The following attestation bundles were made for unxt-1.10.3.tar.gz:
Publisher: cd-unxt.yml on GalacticDynamics/unxt
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: unxt-1.10.3.tar.gz
- Subject digest: 522318b9d19171f52faeaceb5aff9ceaa73c39ba7f7dfea434b42c1d126bc164
- Sigstore transparency entry: 884122767
- Sigstore integration time:
- Permalink: GalacticDynamics/unxt@8773e25de923b4fafbd7f7438c447c446a692a85
- Branch / Tag: refs/tags/unxt-v1.10.3
- Owner: https://github.com/GalacticDynamics
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: cd-unxt.yml@8773e25de923b4fafbd7f7438c447c446a692a85
- Trigger Event: release
File details
Details for the file unxt-1.10.3-py3-none-any.whl.
File metadata
- Download URL: unxt-1.10.3-py3-none-any.whl
- Upload date:
- Size: 83.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 127c6a66e08126bba40e9f9383af31c089a9112d06eebabdf9eb1cdeff0f1359 |
| MD5 | fc91787f995716b11bdc2188334eb619 |
| BLAKE2b-256 | 15e72635b0f644deb4cf9e92b3d3c10f6230ecde9814b85808b4302c8d7cf7fd |
Provenance
The following attestation bundles were made for unxt-1.10.3-py3-none-any.whl:
Publisher: cd-unxt.yml on GalacticDynamics/unxt
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: unxt-1.10.3-py3-none-any.whl
- Subject digest: 127c6a66e08126bba40e9f9383af31c089a9112d06eebabdf9eb1cdeff0f1359
- Sigstore transparency entry: 884122872
- Sigstore integration time:
- Permalink: GalacticDynamics/unxt@8773e25de923b4fafbd7f7438c447c446a692a85
- Branch / Tag: refs/tags/unxt-v1.10.3
- Owner: https://github.com/GalacticDynamics
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: cd-unxt.yml@8773e25de923b4fafbd7f7438c447c446a692a85
- Trigger Event: release