
Automatic differentiation for numerical relativity using JAX


autograv

Bridging numerical relativity and automatic differentiation using JAX

autograv is a Python library that uses JAX and automatic differentiation to compute tensors and other quantities from Einstein's general theory of relativity. Given a metric function, it can calculate Christoffel symbols and curvature tensors, and evaluate the Einstein field equations (the Einstein tensor and the stress-energy-momentum tensor that a given metric implies) with high numerical precision.

Features

  • Automatic Differentiation: Uses JAX's jax.jacfwd for forward-mode automatic differentiation to compute derivatives of metric tensors that are exact up to floating-point round-off (no finite-difference truncation error)
  • Tensor Calculus: Leverages jax.numpy.einsum for efficient Einstein summation notation operations
  • High Precision: Configured to use 64-bit floating point arithmetic for maximum accuracy
  • Pure Functions: All computations are functional and composable
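JAX computes in float32 by default; 64-bit arithmetic is enabled through JAX's jax_enable_x64 configuration flag. Whether autograv sets this flag on import is not stated here, so the minimal sketch below (plain JAX, no autograv calls) enables it explicitly and checks that both the inputs and the jacfwd derivatives come out as float64.

```python
import jax
import jax.numpy as jnp

# JAX defaults to float32; set this flag before any arrays are created.
jax.config.update("jax_enable_x64", True)

x = jnp.array([5.0, jnp.pi / 3, jnp.pi / 2])
print(x.dtype)

# Derivatives computed with jacfwd inherit the precision of the input.
dsin = jax.jacfwd(jnp.sin)(x)  # 3x3 Jacobian, diagonal entries cos(x_i)
print(dsin.dtype, dsin.shape)
```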

What can you compute?

Given a metric tensor function, autograv can compute:

  • Christoffel symbols (affine connection coefficients)
  • Torsion tensor (verifies that the connection is symmetric)
  • Riemann curvature tensor (intrinsic curvature of spacetime)
  • Ricci tensor and Ricci scalar (curvature related to volume change)
  • Einstein tensor (left-hand side of Einstein field equations)
  • Stress-energy-momentum tensor (mass-energy content)
  • Kretschmann invariant (a curvature scalar whose divergence signals a genuine curvature singularity rather than a coordinate artifact)

Installation

# Editable install from a local clone of the repository, using uv
uv pip install -e .

# Or using pip
pip install -e .

Quick Start

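import jax
import jax.numpy as jnp
from autograv import (
    spherical_polar_metric,
    christoffel_symbols,
    riemann_tensor,
    einstein_tensor,
)

# Enable 64-bit floats before creating arrays (JAX defaults to float32;
# repeating this is harmless if autograv already sets it)
jax.config.update("jax_enable_x64", True)

# Define coordinates (r, θ, φ)
coordinates = jnp.array([5.0, jnp.pi/3, jnp.pi/2], dtype=jnp.float64)

# Compute Christoffel symbols for the spherical polar metric
christoffels = christoffel_symbols(coordinates, spherical_polar_metric)
print(christoffels)

# Compute Riemann tensor
riemann = riemann_tensor(coordinates, spherical_polar_metric)
print(riemann)

# Compute Einstein tensor
einstein = einstein_tensor(coordinates, spherical_polar_metric)
print(einstein)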

Examples

The examples/ directory contains complete examples:

  • sphere_example.py: Computing quantities for a 2-sphere metric
  • schwarzschild_example.py: Computing quantities for the Schwarzschild black hole metric

Run them with:

uv run python examples/sphere_example.py
uv run python examples/schwarzschild_example.py
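The contents of schwarzschild_example.py are not reproduced here, so the following is a hypothetical, self-contained sketch of one check a Schwarzschild example could make: the Kretschmann invariant, whose known closed form is K = 48 M²/r⁶ in geometrized units (G = c = 1). It re-implements the Christoffel and Riemann steps in plain JAX rather than calling autograv; the Riemann sign convention is an assumption, but K is quadratic in the Riemann tensor and so does not depend on the overall sign. The names M, schwarzschild_metric, christoffel and kretschmann are illustrative.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

M = 1.0  # hypothetical mass, geometrized units G = c = 1

def schwarzschild_metric(x):
    # Coordinates (t, r, θ, φ), signature (−, +, +, +)
    r, theta = x[1], x[2]
    f = 1.0 - 2.0 * M / r
    return jnp.diag(jnp.array([-f, 1.0 / f, r**2, (r * jnp.sin(theta))**2]))

def christoffel(x, metric):
    g_inv = jnp.linalg.inv(metric(x))
    dg = jax.jacfwd(metric)(x)  # dg[a, b, c] = ∂_c g_ab
    # Γ^j_kl = ½ g^jm (∂_l g_mk + ∂_k g_lm − ∂_m g_kl)
    return 0.5 * (jnp.einsum('jm,mkl->jkl', g_inv, dg)
                  + jnp.einsum('jm,lmk->jkl', g_inv, dg)
                  - jnp.einsum('jm,klm->jkl', g_inv, dg))

def kretschmann(x, metric):
    gamma = christoffel(x, metric)
    # dgamma[a, b, c, d] = ∂_d Γ^a_bc (nested forward-mode autodiff)
    dgamma = jax.jacfwd(lambda y: christoffel(y, metric))(x)
    # R^a_bcd = ∂_c Γ^a_db − ∂_d Γ^a_cb + Γ^a_ce Γ^e_db − Γ^a_de Γ^e_cb
    riem = (jnp.einsum('adbc->abcd', dgamma)
            - jnp.einsum('acbd->abcd', dgamma)
            + jnp.einsum('ace,edb->abcd', gamma, gamma)
            - jnp.einsum('ade,ecb->abcd', gamma, gamma))
    g = metric(x)
    g_inv = jnp.linalg.inv(g)
    riem_down = jnp.einsum('ae,ebcd->abcd', g, riem)  # R_abcd
    riem_up = jnp.einsum('bf,cg,dh,afgh->abcd', g_inv, g_inv, g_inv, riem)  # R^abcd
    return jnp.einsum('abcd,abcd->', riem_down, riem_up)  # K = R_abcd R^abcd

r = 10.0
x = jnp.array([0.0, r, jnp.pi / 3, 0.0])
K = kretschmann(x, schwarzschild_metric)
print(K, 48 * M**2 / r**6)
```

Because K depends only on r, comparing it with 48 M²/r⁶ at a few radii is a quick way to catch index or permutation mistakes in the tensor code.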

How it Works

Automatic Differentiation

Traditional approaches to computing derivatives in physics use either:

  • Symbolic differentiation: Exact but computationally expensive
  • Numerical differentiation (finite differences): Fast but subject to truncation and floating-point round-off errors

Automatic differentiation (autodiff) combines the best of both worlds by:

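  1. Tracing computational operations to build a directed acyclic graph (DAG) of elementary operations
  2. Applying the chain rule to each elementary operation (in forward mode, as with jax.jacfwd, derivatives are propagated alongside the values)
  3. Producing derivatives that are exact up to floating-point round-off, with no truncation error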
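The accuracy difference between autodiff and finite differences can be seen on a function with a known derivative. The minimal sketch below, using plain JAX and an illustrative test function f, compares jax.jacfwd with a forward difference using step h = 1e-6.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def f(x):
    # Test function with a known derivative: f'(x) = cos(x) * exp(sin(x))
    return jnp.exp(jnp.sin(x))

x0 = jnp.array(1.0)
exact = jnp.cos(x0) * jnp.exp(jnp.sin(x0))

# Forward-mode autodiff: applies the chain rule to exp and sin exactly,
# so the only error is floating-point round-off
ad = jax.jacfwd(f)(x0)

# Forward finite difference: truncation error O(h) plus round-off O(eps / h)
h = 1e-6
fd = (f(x0 + h) - f(x0)) / h

print("autodiff error:         ", abs(ad - exact))
print("finite-difference error:", abs(fd - exact))
```

Shrinking h reduces the truncation error but amplifies round-off, so finite differences cannot approach machine precision; autodiff has no step size to tune.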

JAX Integration

JAX provides:

  • jax.jacfwd: Forward-mode autodiff for computing Jacobians
  • jax.numpy.einsum: Efficient Einstein summation for tensor operations
  • NumPy-compatible API with GPU/TPU acceleration support
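To make the index conventions concrete: jax.jacfwd applied to a function returning a matrix appends the derivative axis last, so for a metric g the result satisfies dg[a, b, c] = ∂g_ab/∂x^c, which is the layout the einsum calls in this README rely on. The minimal sketch below uses a hypothetical polar_metric (flat 2D space in polar coordinates), not one of autograv's metrics, and contracts it with jnp.einsum.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def polar_metric(x):
    # Hypothetical example: flat 2D metric in polar coordinates (r, φ),
    # ds² = dr² + r² dφ²
    r = x[0]
    return jnp.diag(jnp.array([1.0, r**2]))

x = jnp.array([3.0, 0.5])
dg = jax.jacfwd(polar_metric)(x)

# jacfwd appends the derivative axis last: dg[a, b, c] = ∂g_ab / ∂x^c
print(dg.shape)     # (2, 2, 2)
print(dg[1, 1, 0])  # ∂(r²)/∂r = 2r = 6 at r = 3

# einsum contracts repeated indices, e.g. g^ab ∂_c g_ab = ∂_c ln(det g)
g_inv = jnp.linalg.inv(polar_metric(x))
trace_term = jnp.einsum('ab,abc->c', g_inv, dg)
print(trace_term)   # ≈ [2/r, 0] = [0.667, 0.0]
```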

Example: Christoffel Symbols

Given a metric tensor g_ij with inverse g^jm, the Christoffel symbols are (summing over the repeated index m):

Γ^j_kl = (1/2) g^jm (∂g_mk/∂x^l + ∂g_lm/∂x^k - ∂g_kl/∂x^m)

In code:

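import jax
import jax.numpy as jnp

def christoffel_symbols(coordinates, metric):
    g = metric(coordinates)
    g_inv = jnp.linalg.inv(g)
    # jacobian[k, l, m] = ∂g_kl/∂x^m (jacfwd appends the derivative axis last)
    jacobian = jax.jacfwd(metric)(coordinates)  # Automatic differentiation!

    # Permute axes so that, read as [k, l, m], the three terms are
    # ∂g_lm/∂x^k, ∂g_mk/∂x^l and ∂g_kl/∂x^m; then contract m with g^jm
    return 0.5 * jnp.einsum('jm, klm -> jkl', g_inv,
                            jnp.einsum('klm -> mkl', jacobian) +
                            jnp.einsum('klm -> lmk', jacobian) - jacobian)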
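A quick way to validate an implementation of this formula is to compare it with textbook values. The standalone sketch below re-implements the formula with one einsum per term (a restatement, not autograv's code) and applies it to a hypothetical spherical_metric for flat 3D space in spherical polar coordinates, where Γ^r_θθ = −r, Γ^θ_rθ = 1/r and Γ^φ_θφ = cot θ. It also checks that the result is symmetric in its lower indices, i.e. that the torsion vanishes.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def spherical_metric(x):
    # Hypothetical stand-in: flat 3D space in spherical polar coordinates
    # (r, θ, φ), ds² = dr² + r² dθ² + r² sin²θ dφ²
    r, theta = x[0], x[1]
    return jnp.diag(jnp.array([1.0, r**2, (r * jnp.sin(theta))**2]))

def christoffel(x, metric):
    g_inv = jnp.linalg.inv(metric(x))
    dg = jax.jacfwd(metric)(x)  # dg[a, b, c] = ∂_c g_ab
    # Γ^j_kl = ½ g^jm (∂_l g_mk + ∂_k g_lm − ∂_m g_kl), one einsum per term
    return 0.5 * (jnp.einsum('jm,mkl->jkl', g_inv, dg)
                  + jnp.einsum('jm,lmk->jkl', g_inv, dg)
                  - jnp.einsum('jm,klm->jkl', g_inv, dg))

r, theta = 5.0, jnp.pi / 3
gamma = christoffel(jnp.array([r, theta, jnp.pi / 2]), spherical_metric)

# Compare with analytic values (index order gamma[j, k, l] = Γ^j_kl)
print(gamma[0, 1, 1], -r)                               # Γ^r_θθ = −r
print(gamma[1, 0, 1], 1 / r)                            # Γ^θ_rθ = 1/r
print(gamma[2, 1, 2], jnp.cos(theta) / jnp.sin(theta))  # Γ^φ_θφ = cot θ

# Levi-Civita connection: symmetric in the lower indices (zero torsion)
print(jnp.max(jnp.abs(gamma - jnp.swapaxes(gamma, 1, 2))))
```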

API Reference

Metrics

  • minkowski_metric(coordinates): Flat spacetime metric
  • spherical_polar_metric(coordinates): 2-sphere metric in (r, θ, φ)

Core Functions

  • christoffel_symbols(coordinates, metric): Affine connection coefficients
  • torsion_tensor(coordinates, metric): Antisymmetric part of connection
  • riemann_tensor(coordinates, metric): Curvature tensor
  • ricci_tensor(coordinates, metric): Trace of Riemann tensor
  • ricci_scalar(coordinates, metric): Scalar curvature
  • kretschmann_invariant(coordinates, metric): Curvature invariant
  • einstein_tensor(coordinates, metric): G_ij = R_ij - (1/2)g_ij R
  • stress_energy_momentum_tensor(coordinates, metric): T_ij from Einstein equations
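In the standard construction, each quantity in this list follows from the previous ones by differentiation or contraction. The self-contained sketch below illustrates that chain in plain JAX; it is not autograv's implementation, its names (sphere_metric, christoffel, curvature, the radius A) are hypothetical, and it assumes one common Riemann convention, R^a_bcd = ∂_c Γ^a_db − ∂_d Γ^a_cb + Γ^a_ce Γ^e_db − Γ^a_de Γ^e_cb, with R_bd = R^a_bad. It checks the result on a true 2-sphere of radius A, for which the Ricci scalar is 2/A² and, as for any 2D metric, the Einstein tensor vanishes.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

A = 2.0  # hypothetical sphere radius

def sphere_metric(x):
    # 2-sphere of radius A in coordinates (θ, φ): ds² = A² dθ² + A² sin²θ dφ²
    theta = x[0]
    return jnp.diag(jnp.array([A**2, (A * jnp.sin(theta))**2]))

def christoffel(x, metric):
    g_inv = jnp.linalg.inv(metric(x))
    dg = jax.jacfwd(metric)(x)  # dg[a, b, c] = ∂_c g_ab
    return 0.5 * (jnp.einsum('jm,mkl->jkl', g_inv, dg)
                  + jnp.einsum('jm,lmk->jkl', g_inv, dg)
                  - jnp.einsum('jm,klm->jkl', g_inv, dg))

def curvature(x, metric):
    gamma = christoffel(x, metric)
    # dgamma[a, b, c, d] = ∂_d Γ^a_bc (nested forward-mode autodiff)
    dgamma = jax.jacfwd(lambda y: christoffel(y, metric))(x)
    # R^a_bcd = ∂_c Γ^a_db − ∂_d Γ^a_cb + Γ^a_ce Γ^e_db − Γ^a_de Γ^e_cb
    riemann = (jnp.einsum('adbc->abcd', dgamma)
               - jnp.einsum('acbd->abcd', dgamma)
               + jnp.einsum('ace,edb->abcd', gamma, gamma)
               - jnp.einsum('ade,ecb->abcd', gamma, gamma))
    ricci = jnp.einsum('abad->bd', riemann)                   # R_bd = R^a_bad
    g = metric(x)
    scalar = jnp.einsum('bd,bd->', jnp.linalg.inv(g), ricci)  # R = g^bd R_bd
    einstein = ricci - 0.5 * g * scalar                       # G_bd = R_bd − ½ g_bd R
    return riemann, ricci, scalar, einstein

x = jnp.array([jnp.pi / 3, 0.7])
riemann, ricci, scalar, einstein = curvature(x, sphere_metric)
print(scalar)                      # ≈ 2 / A² = 0.5
print(jnp.max(jnp.abs(einstein)))  # Einstein tensor vanishes in 2D
```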

Utilities

  • close_to_zero(func): Decorator to suppress near-zero numerical noise
  • TOLERANCE: Threshold for zero suppression (default: 1e-8)
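The description above only says that close_to_zero is a decorator that suppresses numerical noise below TOLERANCE. The sketch below is a hypothetical illustration of how such a decorator might work, assuming an absolute threshold applied elementwise with jnp.where; it is not autograv's code, and noisy_values is an illustrative function.

```python
import functools

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

TOLERANCE = 1e-8  # mirrors the documented default

def close_to_zero(func):
    """Hypothetical sketch: set entries with magnitude below TOLERANCE to zero."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        result = func(*args, **kwargs)
        return jnp.where(jnp.abs(result) < TOLERANCE, 0.0, result)
    return wrapper

@close_to_zero
def noisy_values(x):
    # Analytically [0, 1, 0], but sin(π) and cos(π/2) leave ~1e-16 residuals
    return jnp.array([jnp.sin(jnp.pi * x), 1.0, jnp.cos(jnp.pi / 2) * x])

print(noisy_values(1.0))
```

With an absolute threshold like this, a genuinely small quantity is also zeroed: for example 48 M²/r⁶ at r = 100 with M = 1 is about 4.8e-11, below the 1e-8 default.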

Requirements

  • Python 3.11+
  • JAX (runs on CPU on all platforms; GPU/TPU acceleration is primarily supported on Linux)
  • NumPy

Future Work

  • Add more standard metrics (Kerr, Kerr-Newman, FRW, etc.)
  • Implement Weyl tensor and Weyl invariant
  • Support for JIT compilation with @jax.jit
  • GPU/TPU acceleration examples
  • Integration with differential equation solvers
  • Visualization tools for curvature

License

MIT

Acknowledgments

Based on concepts from the blog post "Bridging numerical relativity and automatic differentiation using JAX". This project demonstrates the synergy between modern machine learning tools and classical physics computations.


Download files


Source Distribution

autograv-0.1.0.tar.gz (6.0 kB)

Built Distribution

autograv-0.1.0-py3-none-any.whl (7.2 kB)

File details

Details for the file autograv-0.1.0.tar.gz.

File metadata

  • Download URL: autograv-0.1.0.tar.gz
  • Size: 6.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.17

File hashes

Hashes for autograv-0.1.0.tar.gz
  • SHA256: e280929301c836c81d7ecde1bd72867893fee8c5249448d720cbd5530d94afa0
  • MD5: 646e07846bb29b6af3585fbc5e14dca2
  • BLAKE2b-256: f8c829edc7ca3ab4f344cf00bdf7faffa81ab36831b890788f4431382cb388a1


File details

Details for the file autograv-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: autograv-0.1.0-py3-none-any.whl
  • Size: 7.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.17

File hashes

Hashes for autograv-0.1.0-py3-none-any.whl
  • SHA256: 6f13a1293ccb3440c2f685d4699a99304d8b03f37407751ac58661bd30f46be3
  • MD5: 92ea54e13e9e5fb2e3ddd6a14f3ac527
  • BLAKE2b-256: b9c0bcd390a9961b646ffb45c530fcbb9f56391de03d3c2db1a575d619f2bde6

