
Quantities in JAX


unxt

Unitful Quantities in JAX



unxt provides unitful quantities and calculations in JAX, built on Equinox and Quax.

unxt supports JAX's key features:

  • JIT compilation (jit)
  • vectorization (vmap, etc.)
  • auto-differentiation (grad, jacobian, hessian)
  • GPU/TPU/multi-host acceleration

Best of all, unxt doesn't force you to use special unit-compatible re-exports of JAX libraries. You can use unxt with existing JAX code: wrap that code with quax's quaxify decorator, and JAX will work with unxt.Quantity.

Installation


using pip
pip install unxt
using uv
uv add unxt
from source, using pip
pip install git+https://github.com/GalacticDynamics/unxt.git
building from source
cd /path/to/parent
git clone https://github.com/GalacticDynamics/unxt.git
cd unxt
pip install -e .  # editable mode

Documentation


For full documentation, including installation instructions, tutorials, and API reference, please see the unxt docs. This README provides a brief overview and some quick examples.

Dimensions

Dimensions represent the physical type of a quantity, such as length, time, or mass.

import unxt as u

Create dimensions from strings:

length_dim = u.dimension("length")
print(length_dim)
# PhysicalType('length')

Dimensions support mathematical expressions:

speed_dim = u.dimension("length / time")
print(speed_dim)
# PhysicalType({'speed', 'velocity'})

Multi-word dimension names require parentheses in expressions:

activity_dim = u.dimension("(amount of substance) / (time)")
print(activity_dim)
# PhysicalType('catalytic activity')
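Conceptually, a dimension expression such as "length / time" can be read as a set of exponents over base dimensions: multiplying dimensions adds exponents and dividing subtracts them. The plain-Python sketch below illustrates only that bookkeeping. The `Dim` class is hypothetical and is not how unxt represents dimensions (the printed `PhysicalType` objects above appear to come from astropy).

```python
from collections import Counter


class Dim:
    """Hypothetical toy dimension: a mapping of base-dimension name -> exponent."""

    def __init__(self, exps=None):
        # Drop zero exponents so equal dimensions compare equal
        self.exps = {k: v for k, v in (exps or {}).items() if v != 0}

    @classmethod
    def base(cls, name):
        return cls({name: 1})

    def __mul__(self, other):
        c = Counter(self.exps)
        c.update(other.exps)  # multiplication adds exponents
        return Dim(dict(c))

    def __truediv__(self, other):
        c = Counter(self.exps)
        c.subtract(other.exps)  # division subtracts exponents
        return Dim(dict(c))

    def __eq__(self, other):
        return isinstance(other, Dim) and self.exps == other.exps

    def __repr__(self):
        return f"Dim({self.exps})"


length, time = Dim.base("length"), Dim.base("time")
speed = length / time
acceleration = speed / time
print(speed)         # Dim({'length': 1, 'time': -1})
print(acceleration)  # Dim({'length': 1, 'time': -2})
```

Multiplying `speed` by `time` cancels the time exponent and gives back `length`, which is the same cancellation that makes "length / time" a speed.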

Units

Units specify the scale and dimension of measurements.

meter = u.unit("m")
print(meter)
# Unit("m")

Units can be combined:

velocity_unit = u.unit("km/h")  # in the expression
print(velocity_unit)
# Unit("km / h")

velocity_unit2 = u.unit("km") / u.unit("h")  # via arithmetic
print(velocity_unit2)
# Unit("km / h")

Get the dimension of a unit:

print(u.dimension_of(meter))
# PhysicalType('length')
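A unit can be thought of as a scale factor relative to a reference unit together with a dimension; composing units multiplies or divides the scales, and converting between units is only allowed when the dimensions agree. The plain-Python sketch below uses exact `fractions.Fraction` arithmetic to show why 1 km/h equals 5/18 m/s. The `UNITS` table and the `composite` and `conversion_factor` helpers are hypothetical illustrations, not unxt API.

```python
from fractions import Fraction

# Hypothetical table: unit symbol -> (scale relative to the SI unit, dimension)
UNITS = {
    "m": (Fraction(1), "length"),
    "km": (Fraction(1000), "length"),
    "s": (Fraction(1), "time"),
    "h": (Fraction(3600), "time"),
}


def composite(numerator, denominator):
    """Scale (in SI) and dimension pair of the unit numerator / denominator."""
    (s_num, d_num), (s_den, d_den) = UNITS[numerator], UNITS[denominator]
    return s_num / s_den, (d_num, d_den)


def conversion_factor(src, dst):
    """Factor that converts a value expressed in unit src into unit dst."""
    (s_src, dim_src), (s_dst, dim_dst) = src, dst
    if dim_src != dim_dst:
        raise ValueError(f"{dim_src} and {dim_dst} are not convertible")
    return s_src / s_dst


factor = conversion_factor(composite("km", "h"), composite("m", "s"))
print(factor)              # 5/18
print(float(36 * factor))  # 10.0, i.e. 36 km/h is 10 m/s
```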

Unit Systems

Unit systems define consistent sets of base units for specific domains. unxt provides built-in unit systems and tools for creating custom ones.

Built-in Unit Systems

# SI (International System of Units)
si = u.unitsystem("si")
print(si)
# unitsystem(m, kg, s, mol, A, K, cd, rad)

# CGS (centimeter-gram-second)
cgs = u.unitsystem("cgs")
print(cgs)
# unitsystem(cm, g, s, dyn, erg, Ba, P, St, rad)

# Galactic (astrophysics)
galactic = u.unitsystem("galactic")
print(galactic)
# unitsystem(kpc, Myr, solMass, rad)

Composing Units from a Unit System

Once you have a unit system, you can get units for any physical dimension by indexing the system:

usys = u.unitsystem("si")

# Get specific units
print(usys["length"])
# Unit("m")

Custom Unit Systems

Create custom unit systems by specifying base units:

import unxt as u

# Define a custom unit system
custom_usys = u.unitsystem("km", "h", "tonne", "degree")
print(custom_usys)
# unitsystem(km, h, t, deg)

# Access derived units
print(custom_usys["velocity"])
# Unit("km / h")
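Indexing a unit system by a dimension amounts to raising each base unit to that dimension's exponent and multiplying the results. The plain-Python sketch below illustrates the idea for the custom (km, h, tonne, degree) system. The `BASE` and `DIMENSIONS` tables and the `derived_unit` helper are hypothetical, and unxt's own output (for example `Unit("km / h")`) is produced by its unit backend, not by this code.

```python
# Hypothetical base-unit table mirroring u.unitsystem("km", "h", "tonne", "degree")
BASE = {"length": "km", "time": "h", "mass": "t", "angle": "deg"}

# Dimensions written as base-dimension exponents
DIMENSIONS = {
    "velocity": {"length": 1, "time": -1},
    "acceleration": {"length": 1, "time": -2},
    "momentum": {"mass": 1, "length": 1, "time": -1},
}


def derived_unit(dimension):
    """Compose a unit string from base units raised to the dimension's exponents."""
    num, den = [], []
    for base_dim, exp in DIMENSIONS[dimension].items():
        unit = BASE[base_dim]
        term = unit if abs(exp) == 1 else f"{unit}{abs(exp)}"
        # Positive exponents go in the numerator, negative ones in the denominator
        (num if exp > 0 else den).append(term)
    out = " ".join(num) or "1"
    return f"{out} / {' '.join(den)}" if den else out


print(derived_unit("velocity"))      # km / h
print(derived_unit("acceleration"))  # km / h2
print(derived_unit("momentum"))      # t km / h
```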

Dynamical Unit Systems

For domains like gravitational dynamics, use dynamical unit systems where $G = 1$:

from unxt.unitsystems import DynamicalSimUSysFlag

# Create a dynamical system where G=1
# Only specify 2 of (length, time, mass)
usys = u.unitsystem(DynamicalSimUSysFlag, "kpc", "Myr")
print(usys)
# unitsystem(kpc, Myr, ...)

# The third dimension (mass) is computed to make G=1
print(usys["mass"])
# a mass unit of about 2.2e11 solMass, computed so that G = 1
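The mass unit follows from dimensional analysis: G has units of length^3 / (mass * time^2), so requiring G = 1 fixes the mass unit at M = L^3 / (G T^2). The plain-Python check below evaluates this for L = 1 kpc and T = 1 Myr. The constants are our own inputs (CODATA 2018 G, the solar mass value used by astropy, and Julian years), and unxt may format the resulting unit differently; the order of magnitude, roughly 2.2e11 solar masses, is what matters.

```python
import math

# Inputs in SI units (our assumptions, not read from unxt)
G = 6.67430e-11               # m^3 kg^-1 s^-2 (CODATA 2018)
M_SUN = 1.988409870698051e30  # kg
KPC = 3.0856775814913673e19   # m
MYR = 1e6 * 365.25 * 86400.0  # s (Julian years)

# G = 1 in units (L, T, M) requires M = L^3 / (G T^2)
mass_unit_kg = KPC**3 / (G * MYR**2)
mass_unit_msun = mass_unit_kg / M_SUN
print(f"{mass_unit_msun:.3g} solMass")  # 2.22e+11 solMass

# Sanity check: expressed in the new units, G equals 1
G_in_new_units = G * mass_unit_kg * MYR**2 / KPC**3
print(math.isclose(G_in_new_units, 1.0))  # True
```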

Quantities

Quantities combine values with units, providing type-safe unitful arithmetic.

Basic Quantities

import jax.numpy as jnp

x = u.Quantity(jnp.arange(1, 5, dtype=float), "km")
print(x)
# Quantity['length']([1., 2., 3., 4.], unit='km')

The constituent value and unit are accessible as attributes:

repr(x.value)
# Array([1., 2., 3., 4.], dtype=float64)

repr(x.unit)
# Unit("km")

Quantity objects obey the rules of unitful arithmetic.

# Addition / Subtraction
print(x + x)
# Quantity['length']([2., 4., 6., 8.], unit='km')

# Multiplication / Division
print(2 * x)
# Quantity['length']([2., 4., 6., 8.], unit='km')

y = u.Quantity(jnp.arange(4, 8, dtype=float), "yr")

print(x / y)
# Quantity['speed']([0.25, 0.4 , 0.5 , 0.57142857], unit='km / yr')

# Exponentiation
print(x**2)
# Quantity['area']([ 1.,  4.,  9., 16.], unit='km2')

# Unit checking on operations
try:
    x + y
except Exception as e:
    print(e)
# 'yr' (time) and 'km' (length) are not convertible
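The rules above (addition needs compatible units, while division and exponentiation combine unit exponents) can be captured in a few lines of plain Python. `ToyQuantity` below is a hypothetical illustration that stores units as exponent dictionaries and, unlike a real library, requires identical units for addition instead of converting compatible ones; it is not how unxt implements `Quantity`.

```python
from collections import Counter


class ToyQuantity:
    """Hypothetical toy: a list of values plus a unit stored as {symbol: exponent}."""

    def __init__(self, values, unit):
        self.values = list(values)
        self.unit = {k: v for k, v in unit.items() if v != 0}

    def __add__(self, other):
        # This toy requires identical units; a real library would first
        # convert between compatible units (e.g. km and m).
        if self.unit != other.unit:
            raise ValueError(f"{other.unit} and {self.unit} are not convertible")
        return ToyQuantity([a + b for a, b in zip(self.values, other.values)], self.unit)

    def __truediv__(self, other):
        unit = Counter(self.unit)
        unit.subtract(other.unit)  # dividing quantities subtracts unit exponents
        return ToyQuantity([a / b for a, b in zip(self.values, other.values)], dict(unit))

    def __pow__(self, n):
        # Raising to a power multiplies every unit exponent by n
        return ToyQuantity([a**n for a in self.values], {k: v * n for k, v in self.unit.items()})


x = ToyQuantity([1.0, 2.0, 3.0, 4.0], {"km": 1})
y = ToyQuantity([4.0, 5.0, 6.0, 7.0], {"yr": 1})

print((x + x).values)  # [2.0, 4.0, 6.0, 8.0]
print((x / y).unit)    # {'km': 1, 'yr': -1}
print((x**2).unit)     # {'km': 2}

try:
    x + y
    message = None
except ValueError as err:
    message = str(err)
print(message)  # {'yr': 1} and {'km': 1} are not convertible
```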

Quantities can be converted to different units:

print(u.uconvert("m", x))  # via function
# Quantity['length']([1000., 2000., 3000., 4000.], unit='m')

print(x.uconvert("m"))  # via method
# Quantity['length']([1000., 2000., 3000., 4000.], unit='m')

Since Quantity is parametric, it can do runtime dimension checking!

LengthQuantity = u.Quantity["length"]
print(LengthQuantity(2, "km"))
# Quantity['length'](2, unit='km')

try:
    LengthQuantity(2, "s")
except ValueError as e:
    print(e)
# Physical type mismatch.
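Runtime dimension checking through a subscripted class such as `Quantity["length"]` can be emulated in plain Python with `__class_getitem__`, returning a subclass that validates its unit on construction. The sketch below is a hypothetical toy: `ToyQ`, its cache, and its tiny unit-to-dimension table are our own, and unxt's actual parametric machinery is different. Only the error message text is borrowed from the example above.

```python
# Hypothetical lookup from unit symbol to physical type
DIMENSION_OF = {"m": "length", "km": "length", "s": "time", "yr": "time"}

_PARAMETRIZED = {}  # cache so ToyQ["length"] always returns the same class


class ToyQ:
    """Hypothetical toy quantity; ToyQ[dim] validates units at construction."""

    dimension = None  # the unparametrized class accepts any unit

    def __init__(self, value, unit):
        if self.dimension is not None and DIMENSION_OF[unit] != self.dimension:
            raise ValueError("Physical type mismatch.")
        self.value, self.unit = value, unit

    def __class_getitem__(cls, dimension):
        # Build a subclass bound to one physical type on first use
        if dimension not in _PARAMETRIZED:
            name = f"ToyQ[{dimension!r}]"
            _PARAMETRIZED[dimension] = type(name, (cls,), {"dimension": dimension})
        return _PARAMETRIZED[dimension]


LengthQ = ToyQ["length"]
q = LengthQ(2, "km")
print(type(q).__name__, q.value, q.unit)  # ToyQ['length'] 2 km

try:
    LengthQ(2, "s")
    message = None
except ValueError as err:
    message = str(err)
print(message)  # Physical type mismatch.
```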

BareQuantity

For performance-critical code where you don't need dimension checking, use BareQuantity:

import unxt as u
import jax.numpy as jnp

# BareQuantity skips dimension checks for better performance
bq = u.quantity.BareQuantity(jnp.array([1.0, 2.0, 3.0]), "m")
print(bq)
# BareQuantity([1., 2., 3.], unit='m')

# Works just like Quantity but without dimension validation
print(bq * 2)
# BareQuantity([2., 4., 6.], unit='m')

Angle

Angle is a specialized quantity with wrapping support for angular values:

import unxt as u
import jax.numpy as jnp

# Create an Angle from degree values (no wrapping is applied here)
theta = u.Angle(jnp.array([0, 90, 180, 270, 360]), "deg")
print(theta)
# Angle([0., 90., 180., 270., 360.], unit='deg')

# Optional wrapping to a specified range
angle = u.Angle(jnp.array([370, -10]), "deg")
wrapped = angle.wrap_to(u.Q(0, "deg"), u.Q(360, "deg"))
print(wrapped)
# Angle([10., 350.], unit='deg')
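Wrapping into a half-open interval [lo, hi) is modular arithmetic: shift by lo, take the remainder modulo the interval width, and shift back. The NumPy sketch below reproduces the 370 to 10 and -10 to 350 results above. The `wrap_to` function here is our own illustration, and we do not claim it matches how `Angle.wrap_to` treats the interval endpoints.

```python
import numpy as np


def wrap_to(values, lo, hi):
    """Map values into [lo, hi) via a floored modulo (hypothetical helper)."""
    values = np.asarray(values, dtype=float)
    # Shift so the interval starts at 0, reduce modulo its width, shift back
    return (values - lo) % (hi - lo) + lo


print(wrap_to([370, -10], 0, 360).tolist())      # [10.0, 350.0]
print(wrap_to([190, -190], -180, 180).tolist())  # [-170.0, 170.0]
```

NumPy's `%` follows Python's floored semantics, so negative inputs land in the interval without a separate branch.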

StaticQuantity

For static configuration values (e.g., JAX static arguments), use StaticQuantity, which stores NumPy values and rejects JAX arrays:

import numpy as np
from functools import partial
import jax
import jax.numpy as jnp
import unxt as u

cfg = u.StaticQuantity(np.array([1.0, 2.0]), "m")


@partial(jax.jit, static_argnames=("q",))
def add(x, q):
    return x + jnp.asarray(q.value)


print(add(1.0, cfg))
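Static arguments need special handling because `jax.jit` requires them to be hashable and comparable with `==` so it can decide whether to reuse a compiled function, and NumPy arrays are unhashable. The sketch below shows one way to make an array usable as such a cache key. `HashableArray` is a hypothetical helper and says nothing about how `StaticQuantity` is implemented internally.

```python
import numpy as np


class HashableArray:
    """Hypothetical wrapper: hash and compare a NumPy array by its contents."""

    def __init__(self, array):
        self.array = np.array(array)       # private copy
        self.array.setflags(write=False)   # freeze so the hash stays valid

    def __hash__(self):
        return hash((self.array.shape, self.array.dtype.str, self.array.tobytes()))

    def __eq__(self, other):
        return (
            isinstance(other, HashableArray)
            and self.array.shape == other.array.shape
            and self.array.dtype == other.array.dtype
            and bool(np.array_equal(self.array, other.array))
        )


try:
    hash(np.array([1.0, 2.0]))
    plain_hashable = True
except TypeError:
    plain_hashable = False
print(plain_hashable)  # False

a = HashableArray([1.0, 2.0])
b = HashableArray(np.array([1.0, 2.0]))
cache = {a: "compiled"}  # stands in for a compilation cache keyed on static args
print(cache[b])  # compiled
```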

StaticValue

If you want a Quantity that keeps a static value but still participates in regular arithmetic, wrap the value with StaticValue. Arithmetic behaves like the wrapped array, and StaticValue + StaticValue returns a StaticValue. Comparison operators (==, !=, <, <=, >, >=) return NumPy boolean arrays for element-wise comparison:

import numpy as np
import jax.numpy as jnp
import unxt as u

sv = u.quantity.StaticValue(np.array([1.0, 2.0]))
q_static = u.Q(sv, "m")
q = u.Q(jnp.array([3.0, 4.0]), "m")

print(q_static + q)

# Comparisons return NumPy boolean arrays (element-wise)
sv2 = u.quantity.StaticValue(np.array([2.0, 1.0]))
print(sv < sv2)  # array([ True, False])
print(sv == np.array([1.0, 2.0]))  # array([ True,  True])

JAX Integration

unxt is built on quax, which enables custom array-ish objects in JAX. For convenience, we use the quaxed library, which wraps JAX functions with quax.quaxify to avoid boilerplate code.

Note: Using quaxed is optional. You can use quax.quaxify directly, and even apply it to a top-level function instead of to individual functions.

from quaxed import grad, vmap
import quaxed.numpy as jnp

# Using the x quantity from earlier examples
print(jnp.square(x))
# Quantity['area']([ 1.,  4.,  9., 16.], unit='km2')

print(jnp.power(x, 3))
# Quantity['volume']([ 1.,  8., 27., 64.], unit='km3')

print(vmap(grad(lambda x: x**3))(x))
# Quantity['area']([ 3., 12., 27., 48.], unit='km2')

See the documentation for more examples and details on JIT compilation and auto-differentiation (AD).
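As a sanity check on the gradient example: d(x^3)/dx = 3x^2, so at x = 1, 2, 3, 4 km the values are 3, 12, 27, 48, and the unit is km^3 / km = km2. The plain-Python central-difference sketch below confirms the numbers; it is our own check and uses neither JAX nor unxt.

```python
def central_difference(f, x, h=1e-5):
    """Approximate f'(x) with a symmetric finite difference."""
    return (f(x + h) - f(x - h)) / (2 * h)


xs = [1.0, 2.0, 3.0, 4.0]  # magnitudes in km
grads = [central_difference(lambda x: x**3, x) for x in xs]
print([round(g, 6) for g in grads])  # [3.0, 12.0, 27.0, 48.0]
```

For a cubic, the central difference has a truncation error of exactly h^2, so with h = 1e-5 the approximation agrees with 3x^2 to far better than six decimal places.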

Citation

JOSS DOI

If you found this library to be useful and want to support the development and maintenance of lower-level code libraries for the scientific community, please consider citing this work.

Contributing and Development


We welcome contributions! Contributions are how open source projects improve and grow.

To contribute to unxt, please fork the repository, make a development branch, develop on that branch, then open a pull request from the branch in your fork to main.

To report bugs, request features, or suggest other ideas, please open an issue.

For more information, see CONTRIBUTING.md.

Project details



Download files


Source Distribution

unxt-1.10.3.tar.gz (1.0 MB)

Uploaded Source

Built Distribution


unxt-1.10.3-py3-none-any.whl (83.5 kB)

Uploaded Python 3

File details

Details for the file unxt-1.10.3.tar.gz.

File metadata

  • Download URL: unxt-1.10.3.tar.gz
  • Upload date:
  • Size: 1.0 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for unxt-1.10.3.tar.gz
Algorithm Hash digest
SHA256 522318b9d19171f52faeaceb5aff9ceaa73c39ba7f7dfea434b42c1d126bc164
MD5 f9c155a4f3031d31dd15e44e5240f3e6
BLAKE2b-256 252f78436064bc5112882dcf871f5a51f16a6a2b2c27c7e36e8ca146e51cab34


Provenance

The following attestation bundles were made for unxt-1.10.3.tar.gz:

Publisher: cd-unxt.yml on GalacticDynamics/unxt

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file unxt-1.10.3-py3-none-any.whl.

File metadata

  • Download URL: unxt-1.10.3-py3-none-any.whl
  • Upload date:
  • Size: 83.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for unxt-1.10.3-py3-none-any.whl
Algorithm Hash digest
SHA256 127c6a66e08126bba40e9f9383af31c089a9112d06eebabdf9eb1cdeff0f1359
MD5 fc91787f995716b11bdc2188334eb619
BLAKE2b-256 15e72635b0f644deb4cf9e92b3d3c10f6230ecde9814b85808b4302c8d7cf7fd


Provenance

The following attestation bundles were made for unxt-1.10.3-py3-none-any.whl:

Publisher: cd-unxt.yml on GalacticDynamics/unxt

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
