Automatic differentiation for numerical relativity using JAX
autograv
Bridging numerical relativity and automatic differentiation using JAX
autograv is a Python library that uses JAX and automatic differentiation to compute tensors and other quantities from Einstein's general theory of relativity. Given a metric function, it can calculate Christoffel symbols and curvature tensors. It can also evaluate the Einstein field equations, for example by obtaining the stress-energy-momentum tensor that a given metric implies. All results are computed with high numerical precision.
Features
- Automatic Differentiation: Uses JAX's jax.jacfwd for forward-mode automatic differentiation to compute derivatives of metric tensors that are exact up to floating-point round-off
- Tensor Calculus: Leverages jax.numpy.einsum for efficient Einstein summation notation operations
- High Precision: Configured to use 64-bit floating point arithmetic for maximum accuracy
- Pure Functions: All computations are functional and composable
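The README does not show how autograv switches JAX to 64-bit mode. Here is a minimal sketch that assumes it uses JAX's standard jax_enable_x64 flag. Without that flag, JAX defaults to 32-bit floats:

```python
import jax
import jax.numpy as jnp

# JAX creates 32-bit floats by default; this flag makes 64-bit the default.
# Set it at startup, before any arrays are created.
jax.config.update("jax_enable_x64", True)

coordinates = jnp.array([5.0, jnp.pi / 3, jnp.pi / 2])
print(coordinates.dtype)  # float64
```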
What can you compute?
Given a metric tensor function, autograv can compute:
- Christoffel symbols (affine connection coefficients)
- Torsion tensor (checks that the connection is symmetric)
- Riemann curvature tensor (intrinsic curvature of spacetime)
- Ricci tensor and Ricci scalar (curvature related to volume change)
- Einstein tensor (left-hand side of Einstein field equations)
- Stress-energy-momentum tensor (mass-energy content)
- Kretschmann invariant (the curvature scalar R_abcd R^abcd, used to detect genuine curvature singularities)
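Each quantity above builds on the previous ones. The Riemann tensor comes from the Christoffel symbols and their derivatives. The Ricci tensor is a contraction of the Riemann tensor, and the Ricci scalar is the metric trace of the Ricci tensor.

The sketch below is an independent illustration, not autograv's source. It assumes the convention R^ρ_σμν = ∂_μ Γ^ρ_νσ - ∂_ν Γ^ρ_μσ + Γ^ρ_μλ Γ^λ_νσ - Γ^ρ_νλ Γ^λ_μσ. It tests this on a hypothetical sphere_metric for a 2-sphere of radius a, whose Ricci scalar should be 2/a²:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def sphere_metric(x, a=1.0):
    # Hypothetical example metric: 2-sphere of radius a in coordinates (theta, phi).
    theta = x[0]
    return jnp.diag(jnp.array([a**2, (a * jnp.sin(theta)) ** 2]))


def christoffel_symbols(x, metric):
    # G[j, k, l] = Gamma^j_kl, same formula as in "How it Works"
    g_inv = jnp.linalg.inv(metric(x))
    jac = jax.jacfwd(metric)(x)  # jac[k, l, m] = d g_kl / d x^m
    return 0.5 * jnp.einsum("jm,klm->jkl", g_inv,
                            jnp.einsum("klm->mkl", jac)
                            + jnp.einsum("klm->lmk", jac) - jac)


def torsion_tensor(x, metric):
    # T^j_kl = Gamma^j_kl - Gamma^j_lk; zero for the symmetric Levi-Civita connection
    G = christoffel_symbols(x, metric)
    return G - jnp.swapaxes(G, 1, 2)


def riemann_tensor(x, metric):
    # R[r, s, m, n] = R^r_smn = d_m G^r_ns - d_n G^r_ms + G^r_ml G^l_ns - G^r_nl G^l_ms
    G = christoffel_symbols(x, metric)
    dG = jax.jacfwd(lambda y: christoffel_symbols(y, metric))(x)  # dG[j, k, l, m] = d G^j_kl / d x^m
    return (jnp.einsum("rnsm->rsmn", dG) - jnp.einsum("rmsn->rsmn", dG)
            + jnp.einsum("rml,lns->rsmn", G, G) - jnp.einsum("rnl,lms->rsmn", G, G))


def ricci_tensor(x, metric):
    return jnp.einsum("rsrn->sn", riemann_tensor(x, metric))  # R_sn = R^r_srn


def ricci_scalar(x, metric):
    return jnp.einsum("sn,sn->", jnp.linalg.inv(metric(x)), ricci_tensor(x, metric))


x = jnp.array([jnp.pi / 3, 0.4])
print(ricci_scalar(x, sphere_metric))                      # close to 2 / a**2 = 2
print(ricci_scalar(x, lambda y: sphere_metric(y, a=2.0)))  # close to 2 / a**2 = 0.5
print(jnp.max(jnp.abs(torsion_tensor(x, sphere_metric))))  # zero up to round-off
```

The second derivatives of the metric come from applying jax.jacfwd to a function that already calls jax.jacfwd. No finite differences are involved at any stage.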
Installation
# Using uv
uv pip install -e .
# Or using pip
pip install -e .
Quick Start
import jax.numpy as jnp
from autograv import (
spherical_polar_metric,
christoffel_symbols,
riemann_tensor,
einstein_tensor,
)
# Define coordinates
coordinates = jnp.array([5, jnp.pi/3, jnp.pi/2], dtype=jnp.float64)
# Compute Christoffel symbols for the 2-sphere
christoffels = christoffel_symbols(coordinates, spherical_polar_metric)
print(christoffels)
# Compute Riemann tensor
riemann = riemann_tensor(coordinates, spherical_polar_metric)
print(riemann)
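The Quick Start uses autograv's bundled spherical_polar_metric. A metric function in this style takes a 1-D coordinate array and returns the n × n matrix g_ij at that point. Below is a hypothetical polar_plane_metric for flat 2-D space in (r, θ). The sketch also shows the array layout that jax.jacfwd produces, which is the layout the Christoffel formula in How it Works relies on:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def polar_plane_metric(coordinates):
    # Hypothetical example: flat plane in polar coordinates (r, theta),
    # ds^2 = dr^2 + r^2 dtheta^2.
    r = coordinates[0]
    return jnp.diag(jnp.array([1.0, r**2]))


coordinates = jnp.array([3.0, 0.5])
g = polar_plane_metric(coordinates)                     # shape (2, 2)
jacobian = jax.jacfwd(polar_plane_metric)(coordinates)  # shape (2, 2, 2)

# jacobian[k, l, m] holds d g_kl / d x^m, so d g_theta_theta / d r = 2r = 6.
print(jacobian[1, 1, 0])
```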
Examples
The examples/ directory contains complete examples:
- sphere_example.py: Computing quantities for a 2-sphere metric
- schwarzschild_example.py: Computing quantities for the Schwarzschild black hole metric
Run them with:
uv run python examples/sphere_example.py
uv run python examples/schwarzschild_example.py
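This README does not show the contents of schwarzschild_example.py. Here is a hypothetical sketch of a Schwarzschild metric function in the same style. It assumes coordinates (t, r, θ, φ), geometric units G = c = 1, and a mass parameter M. The final line checks the known result det g = -r⁴ sin²θ:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def schwarzschild_metric(coordinates, M=1.0):
    # Hypothetical sketch: coordinates (t, r, theta, phi), geometric units G = c = 1.
    t, r, theta, phi = coordinates
    f = 1.0 - 2.0 * M / r  # vanishes at the horizon r = 2M
    return jnp.diag(jnp.array([-f, 1.0 / f, r**2, (r * jnp.sin(theta)) ** 2]))


coordinates = jnp.array([0.0, 10.0, jnp.pi / 3, 0.0])  # outside the horizon
g = schwarzschild_metric(coordinates)
# Known result: det g = -r**4 sin(theta)**2, here -10**4 * 3/4 = -7500
print(jnp.linalg.det(g))
```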
How it Works
Automatic Differentiation
Traditional approaches to computing derivatives in physics use either:
- Symbolic differentiation: Exact but computationally expensive
- Numerical differentiation (finite differences): Fast but prone to truncation and round-off errors
Automatic differentiation (autodiff) combines the best of both worlds by:
- Tracing computational operations to build a directed acyclic graph (DAG)
- Computing gradients via the chain rule by traversing the graph
- Producing derivatives that are exact up to floating-point round-off, with no truncation error
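To illustrate the accuracy claim, the sketch below compares two derivatives of f(x) = sin(x) eˣ against the analytic result. One comes from jax.jacfwd and the other from a central finite difference. The test function and the step size h are arbitrary choices for illustration. The autodiff error should stay at round-off level, while the finite difference also carries truncation and cancellation error:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def f(x):
    return jnp.sin(x) * jnp.exp(x)


def df_analytic(x):
    # d/dx [sin(x) e^x] = e^x (cos(x) + sin(x))
    return jnp.exp(x) * (jnp.cos(x) + jnp.sin(x))


x0 = 0.7
h = 1e-6  # finite-difference step (illustrative choice)

autodiff = jax.jacfwd(f)(x0)
central_difference = (f(x0 + h) - f(x0 - h)) / (2 * h)

autodiff_error = abs(autodiff - df_analytic(x0))
fd_error = abs(central_difference - df_analytic(x0))
print(autodiff_error, fd_error)
```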
JAX Integration
JAX provides:
- jax.jacfwd: Forward-mode autodiff for computing Jacobians
- jax.numpy.einsum: Efficient Einstein summation for tensor operations
- NumPy-compatible API with GPU/TPU acceleration support
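As an example of the einsum style, the sketch below raises an index of a rank-2 tensor and takes its full contraction with the inverse metric. The diagonal Minkowski metric with signature (-, +, +, +) and the tensor T are illustrative assumptions. They are not taken from autograv's minkowski_metric:

```python
import jax.numpy as jnp

g = jnp.diag(jnp.array([-1.0, 1.0, 1.0, 1.0]))  # Minkowski metric, signature (-, +, +, +)
g_inv = jnp.linalg.inv(g)
T = jnp.arange(16.0).reshape(4, 4)  # arbitrary covariant tensor T_ij

# Raise the first index: T^i_j = g^ik T_kj (same as g_inv @ T).
T_mixed = jnp.einsum("ik,kj->ij", g_inv, T)

# Full contraction g^ij T_ij; only the diagonal of T contributes here: -0 + 5 + 10 + 15.
trace = jnp.einsum("ij,ij->", g_inv, T)
print(trace)  # 30.0
```

The subscript string names every index explicitly. This keeps the code close to the written tensor expression and avoids chains of transposes and reshapes.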
Example: Christoffel Symbols
Given a metric tensor g_ij, the Christoffel symbols are:
Γ^j_kl = (1/2) g^jm (∂g_mk/∂x^l + ∂g_lm/∂x^k - ∂g_kl/∂x^m)
In code:
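import jax
import jax.numpy as jnp

def christoffel_symbols(coordinates, metric):
    g = metric(coordinates)                     # metric g_ij at the given point
    g_inv = jnp.linalg.inv(g)                   # inverse metric g^ij
    jacobian = jax.jacfwd(metric)(coordinates)  # Automatic differentiation! jacobian[k, l, m] = ∂g_kl/∂x^m
    # Γ^j_kl = (1/2) g^jm (∂g_mk/∂x^l + ∂g_lm/∂x^k - ∂g_kl/∂x^m), contracted over m
    return 0.5 * jnp.einsum('jm,klm->jkl', g_inv,
                            jnp.einsum('klm->mkl', jacobian) +
                            jnp.einsum('klm->lmk', jacobian) - jacobian)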
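A quick check of this function is the flat plane in polar coordinates (r, θ). There, the only nonzero Christoffel symbols are Γ^r_θθ = -r and Γ^θ_rθ = Γ^θ_θr = 1/r. The sketch repeats the function so it runs on its own. Note that polar_plane_metric is a hypothetical helper and is not part of autograv:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def christoffel_symbols(coordinates, metric):
    g_inv = jnp.linalg.inv(metric(coordinates))
    jacobian = jax.jacfwd(metric)(coordinates)  # jacobian[k, l, m] = d g_kl / d x^m
    return 0.5 * jnp.einsum("jm,klm->jkl", g_inv,
                            jnp.einsum("klm->mkl", jacobian)
                            + jnp.einsum("klm->lmk", jacobian) - jacobian)


def polar_plane_metric(coordinates):
    # Hypothetical helper: flat plane in (r, theta), ds^2 = dr^2 + r^2 dtheta^2
    r = coordinates[0]
    return jnp.diag(jnp.array([1.0, r**2]))


gamma = christoffel_symbols(jnp.array([2.0, 0.3]), polar_plane_metric)
print(gamma[0, 1, 1])                  # Gamma^r_theta_theta = -r = -2
print(gamma[1, 0, 1], gamma[1, 1, 0])  # Gamma^theta_r_theta = 1/r = 0.5
```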
API Reference
Metrics
- minkowski_metric(coordinates): Flat spacetime metric
- spherical_polar_metric(coordinates): 2-sphere metric in (r, θ, φ)
Core Functions
- christoffel_symbols(coordinates, metric): Affine connection coefficients
- torsion_tensor(coordinates, metric): Antisymmetric part of connection
- riemann_tensor(coordinates, metric): Curvature tensor
- ricci_tensor(coordinates, metric): Trace of Riemann tensor
- ricci_scalar(coordinates, metric): Scalar curvature
- kretschmann_invariant(coordinates, metric): Curvature invariant
- einstein_tensor(coordinates, metric): G_ij = R_ij - (1/2)g_ij R
- stress_energy_momentum_tensor(coordinates, metric): T_ij from Einstein equations
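For a vacuum solution such as Schwarzschild, the Einstein tensor (and hence the stress-energy-momentum tensor) should vanish. The Kretschmann invariant R_abcd R^abcd should equal 48 M²/r⁶ in geometric units G = c = 1. Both checks are independent of the overall sign convention for the Riemann tensor.

The sketch below verifies both. It is an independent illustration, not autograv's source. Its helper names mirror the list above, but their bodies are assumptions:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def schwarzschild_metric(x, M=1.0):
    # Hypothetical metric: coordinates (t, r, theta, phi), units G = c = 1.
    t, r, theta, phi = x
    f = 1.0 - 2.0 * M / r
    return jnp.diag(jnp.array([-f, 1.0 / f, r**2, (r * jnp.sin(theta)) ** 2]))


def christoffel_symbols(x, metric):
    g_inv = jnp.linalg.inv(metric(x))
    jac = jax.jacfwd(metric)(x)  # jac[k, l, m] = d g_kl / d x^m
    return 0.5 * jnp.einsum("jm,klm->jkl", g_inv,
                            jnp.einsum("klm->mkl", jac)
                            + jnp.einsum("klm->lmk", jac) - jac)


def riemann_tensor(x, metric):
    # R[r, s, m, n] = R^r_smn
    G = christoffel_symbols(x, metric)
    dG = jax.jacfwd(lambda y: christoffel_symbols(y, metric))(x)  # dG[j, k, l, m] = d G^j_kl / d x^m
    return (jnp.einsum("rnsm->rsmn", dG) - jnp.einsum("rmsn->rsmn", dG)
            + jnp.einsum("rml,lns->rsmn", G, G) - jnp.einsum("rnl,lms->rsmn", G, G))


def einstein_tensor(x, metric):
    g = metric(x)
    ricci = jnp.einsum("rsrn->sn", riemann_tensor(x, metric))  # R_sn = R^r_srn
    scalar = jnp.einsum("ij,ij->", jnp.linalg.inv(g), ricci)   # Ricci scalar R
    return ricci - 0.5 * g * scalar                            # G_ij = R_ij - (1/2) g_ij R


def kretschmann_invariant(x, metric):
    g = metric(x)
    g_inv = jnp.linalg.inv(g)
    R = riemann_tensor(x, metric)
    # K = g_ae R^e_bcd R^a_fgh g^bf g^cg g^dh = R_abcd R^abcd
    return jnp.einsum("ebcd,afgh,ae,bf,cg,dh->", R, R, g, g_inv, g_inv, g_inv)


x = jnp.array([0.0, 10.0, jnp.pi / 3, 0.0])  # t = 0, r = 10 M
print(jnp.max(jnp.abs(einstein_tensor(x, schwarzschild_metric))))  # expected near zero (vacuum)
print(kretschmann_invariant(x, schwarzschild_metric))              # close to 48 / 10**6
```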
Utilities
- close_to_zero(func): Decorator to suppress near-zero numerical noise
- TOLERANCE: Threshold for zero suppression (default: 1e-8)
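The README does not show how close_to_zero works. Here is a minimal sketch of such a decorator. It assumes the decorator replaces every entry whose absolute value is below TOLERANCE with zero and leaves all other entries unchanged. The noisy function exists only for illustration:

```python
import functools

import jax.numpy as jnp

TOLERANCE = 1e-8  # default threshold stated above


def close_to_zero(func):
    """Hypothetical sketch: zero out entries whose magnitude is below TOLERANCE."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        result = func(*args, **kwargs)
        return jnp.where(jnp.abs(result) < TOLERANCE, 0.0, result)
    return wrapper


@close_to_zero
def noisy():
    # Illustrative output with round-off-sized entries next to a real value.
    return jnp.array([1e-12, -3e-10, 0.5])


print(noisy())
```

With this threshold, the first two entries are set to zero and 0.5 passes through unchanged. Because the threshold is absolute, it assumes the tensor components are of order one. Components that are legitimately very small would also be suppressed.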
Requirements
- Python 3.11+
- JAX (CPU-only on native Windows; GPU/TPU acceleration depends on platform support, mainly Linux)
- NumPy
Future Work
- Add more standard metrics (Kerr, Kerr-Newman, FRW, etc.)
- Implement Weyl tensor and Weyl invariant
- Support for JIT compilation with @jax.jit
- GPU/TPU acceleration examples
- Integration with differential equation solvers
- Visualization tools for curvature
License
MIT
Acknowledgments
Based on concepts from the blog post "Bridging numerical relativity and automatic differentiation using JAX". This project demonstrates the synergy between modern machine learning tools and classical physics computations.
File details
Details for the file autograv-0.1.0.tar.gz.
File metadata
- Download URL: autograv-0.1.0.tar.gz
- Upload date:
- Size: 6.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.9.17
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | e280929301c836c81d7ecde1bd72867893fee8c5249448d720cbd5530d94afa0 |
| MD5 | 646e07846bb29b6af3585fbc5e14dca2 |
| BLAKE2b-256 | f8c829edc7ca3ab4f344cf00bdf7faffa81ab36831b890788f4431382cb388a1 |
File details
Details for the file autograv-0.1.0-py3-none-any.whl.
File metadata
- Download URL: autograv-0.1.0-py3-none-any.whl
- Upload date:
- Size: 7.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.9.17
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 6f13a1293ccb3440c2f685d4699a99304d8b03f37407751ac58661bd30f46be3 |
| MD5 | 92ea54e13e9e5fb2e3ddd6a14f3ac527 |
| BLAKE2b-256 | b9c0bcd390a9961b646ffb45c530fcbb9f56391de03d3c2db1a575d619f2bde6 |