
Numerical relativity surrogate model for gravitational waveforms in Jax

Project description

JaxNRSur

Numerical relativity surrogate waveform in Jax


Quickstart

Installation

The recommended way to install jaxNRSur is via uv, a Python package and project manager written in Rust. You can find its installation instructions here. Once you have uv installed, add jaxNRSur to the project you are developing with uv add JaxNRSur. If you just want to try the package out, clone this repository, cd into the directory, and run uv sync --dev; this creates an environment in the directory, which you can activate with source .venv/bin/activate.

Alternatively, you can install jaxNRSur with pip: clone the repository, then run pip install . from the repository root.

If you want to use a GPU, you will need to run uv sync --all-extras or pip install -U "jax[cuda12]" to install a version of jax that is compatible with an Nvidia GPU.
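To confirm that jax can actually see the GPU after installation, a minimal check (not specific to jaxNRSur) is:

import jax

print(jax.devices())  # should include a CUDA/GPU device if the GPU build is installed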

Basic Usage

jaxNRSur has a simple interface. At its core, a surrogate waveform is parameterized as h(t, theta), where h is the strain, t is the array of time samples, and theta contains the gravitational-wave source parameters, such as the masses and spins.

Right now, jaxNRSur supports the following models:

  • NRHybSur3dq8Model
  • NRSur7dq4Model
import jax.numpy as jnp
from jaxnrsur.NRHybSur3dq8 import NRHybSur3dq8Model
from jaxnrsur.NRSur7dq4 import NRSur7dq4Model

# time grid on which to evaluate the waveform
time = jnp.linspace(-1000, 100, 100000)

# NRHybSur3dq8 takes three source parameters
params = jnp.array([0.9, 0.1, 0.1])
model = NRHybSur3dq8Model()
h = model(time, params)

# NRSur7dq4 takes seven source parameters
params = jnp.array([0.9, 0.0, 0.5, 0.0, 0.5, 0.0, 0.3])
model = NRSur7dq4Model()
h = model(time, params)
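For a quick visual check of the output, a minimal sketch (assuming the model returns a one-dimensional complex strain on the supplied time grid, consistent with the gradient example below taking the real part):

import matplotlib.pyplot as plt

plt.plot(time, h.real, label="Re h")
plt.plot(time, h.imag, label="Im h")
plt.xlabel("time")
plt.ylabel("strain")
plt.legend()
plt.show()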

Jax features

JIT Compilation

jax not only supports JIT compilation like numba, it does so in an accelerator-aware manner: once the source code is written in jax, it is immediately compatible with accelerators such as GPUs. To use JIT to speed up the code, all you have to do is the following:

# Let's use NRHybSur3dq8 as an example
import equinox as eqx

params = jnp.array([0.9, 0.1, 0.1])  # three parameters, matching NRHybSur3dq8
model = NRHybSur3dq8Model()
jitted_model = eqx.filter_jit(model)
h = jitted_model(time, params)

Note that since our NRHybSur3dq8 model contains some parameters that are not compatible with the jax JIT transformation, we built the package on top of equinox, a JAX-compatible library for building neural networks. It allows us to write object-oriented code that knows how to handle the parameters associated with each model, without needing to write purely functional code. Instead of the default jax JIT transformation jax.jit, we use equinox.filter_jit, which separates out the parameters that are not compatible with jax.jit before running the transformation under the hood.
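Compilation happens on the first call, so a quick timing check looks like the following (a hedged sketch; block_until_ready is used to account for jax's asynchronous dispatch):

import time as timer  # avoid shadowing the `time` array defined above

_ = jitted_model(time, params)  # first call triggers compilation
start = timer.perf_counter()
h = jitted_model(time, params)
h.block_until_ready()  # wait for the asynchronous computation to finish
print(f"jitted evaluation took {timer.perf_counter() - start:.4f} s")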

Automatic Differentiation

The next feature offered by jax is automatic differentiation, which allows us to compute the gradient of a function with respect to its input parameters. This is the cornerstone of deep learning, since it allows us to optimize the parameters of a model with gradient descent. In our case, someone interested in parameter estimation may want the gradient of the posterior function with respect to the parameters, while someone interested in optimizing the waveform parameters can compute the gradient of the waveform with respect to the parameters. Here we give two examples of building gradient functions for NRHybSur3dq8: one with respect to the time grid, and one with respect to the parameters.

import jax

# scalar objective: the real part of the summed strain
def target(time, params):
    return jnp.sum(model(time, params)).real

grad_target_time = jax.grad(target, argnums=0)    # gradient w.r.t. the time grid
grad_target_params = jax.grad(target, argnums=1)  # gradient w.r.t. the source parameters
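Both gradient functions can then be evaluated on the arrays from the Basic Usage section (a minimal sketch; the 3-parameter NRHybSur3dq8 array is reused):

params = jnp.array([0.9, 0.1, 0.1])
dtarget_dtime = grad_target_time(time, params)      # same shape as the time grid
dtarget_dparams = grad_target_params(time, params)  # same shape as params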

Vectorization

jax offers vmap to vectorize functions. This is different from a for loop, since jax will fuse some of the operations under the hood to achieve better performance. On a CPU this makes a small but noticeable difference, since modern CPUs offer some vectorization capabilities, but where this really shines is on accelerators. From some initial microbenchmarks on a 5070 Ti (16 GB VRAM), the NRSur7dq4 waveform can be evaluated for 500 parameter sets in ~300 ms.

Similar to JIT, we need to use eqx.filter_vmap instead of jax.vmap in our case. Another note is that vmap does not compile the code, so one needs to apply eqx.filter_jit on top of eqx.filter_vmap to compile the code.

# batch of identical parameter points; in_axes=(None, 0) shares the time grid
# across the batch and maps over the leading axis of the parameters
params = jnp.array([[0.9, 0.1, 0.1]])
params_multi = jnp.repeat(params, 10, axis=0)
h_multi = eqx.filter_jit(eqx.filter_vmap(model.get_waveform, in_axes=(None, 0)))(
    time, params_multi
)
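A more realistic batch varies the parameters instead of repeating one point. Here is a hedged sketch that keeps the same 3-parameter ordering as the NRHybSur3dq8 examples above; the illustrative values are assumptions and should stay within the model's valid parameter range:

# vary the first parameter over the batch, keep the other two fixed
first = jnp.linspace(0.5, 0.95, 500)
params_grid = jnp.stack([first, jnp.full(500, 0.1), jnp.full(500, 0.1)], axis=-1)
h_grid = eqx.filter_jit(eqx.filter_vmap(model.get_waveform, in_axes=(None, 0)))(
    time, params_grid
)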


Local data cache

If the data is not already downloaded, this package will fetch it from Zenodo and store it in $HOME/.jaxNRSur. If the data is already downloaded, the package will reuse the cached copy.
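To see what has been cached locally, a small sketch using only the cache path mentioned above:

from pathlib import Path

cache_dir = Path.home() / ".jaxNRSur"
print(sorted(p.name for p in cache_dir.glob("*")))  # list the downloaded data files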

Attribution

Coming soon. For now, give us a star and keep an eye out!



Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

jaxnrsur-1.0.0.tar.gz (1.2 MB)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

jaxnrsur-1.0.0-py3-none-any.whl (30.1 kB)

Uploaded Python 3

File details

Details for the file jaxnrsur-1.0.0.tar.gz.

File metadata

  • Download URL: jaxnrsur-1.0.0.tar.gz
  • Upload date:
  • Size: 1.2 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.9.23

File hashes

Hashes for jaxnrsur-1.0.0.tar.gz

  • SHA256: 63141efaf76fc2ebfe123f4aad686563bca99cf03c9cd3e2a23e848e47a2dd7a
  • MD5: 8e24f226dc679ea9ecd36c9285f2cfb0
  • BLAKE2b-256: dc6f6e3f368897744bb27d8da9492fad81f9c193c63cdfb30f567a5262982837

See more details on using hashes here.

File details

Details for the file jaxnrsur-1.0.0-py3-none-any.whl.

File metadata

  • Download URL: jaxnrsur-1.0.0-py3-none-any.whl
  • Upload date:
  • Size: 30.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.9.23

File hashes

Hashes for jaxnrsur-1.0.0-py3-none-any.whl

  • SHA256: a78667ab7da28eac620f68cd5225d208d63fea538547c49bacbab185d7e4fccb
  • MD5: 5be805a21c88d3833abca03f0d18abfd
  • BLAKE2b-256: 96741a5547e1db279b4741cbfc19582583cc4eeada324b7b85d815b01873e9fc

See more details on using hashes here.
