A lightweight JAX-only version of redback for electromagnetic transient analysis
Redback-JAX
A lightweight JAX-only rewrite of redback for electromagnetic transient modeling and Bayesian inference, designed to run efficiently on GPUs and TPUs in float32.
Overview
Redback-JAX reimplements redback's analytical transient models in JAX, using log10-space arithmetic throughout to stay float32-safe on GPU hardware. All bolometric functions return log10(L) rather than linear luminosities, because typical transient luminosities (up to ~10⁴⁵ erg/s) exceed the float32 maximum of ~3.4×10³⁸. The spectra pipeline (photosphere and blackbody SED) also stays in log10 space; only the final bandflux integration works with linear flux densities.
Features
- Float32-safe physics: All models operate in log10 space; no overflow on GPU even for luminosities ~10⁴⁵ erg/s
- JIT-compiled and differentiable: Every model is decorated with `@jax.jit`; gradients flow through the full pipeline via `jax.grad`
- vmap-based diffusion integrals: Arnett-style diffusion uses `jax.vmap` over time points with log-mirror quadrature nodes
- Spectra pipeline: `make_spectra_model(bolometric_fn)` wraps any bolometric model to produce time × wavelength spectra for bandflux/magnitude comparison
- Clean inference API: `Prior`, `Likelihood`, `NestedSampler`, and `MCMCSampler` compose a full Bayesian fit in ~15 lines
- Multi-sampler support: BlackJAX NUTS (MCMC) and nested sampling via `blackjax.nss`
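The Arnett-style diffusion integral can be sketched in log10 space with `jax.vmap`, `jax.jit` and `jax.grad`. This is a minimal illustration, not redback-jax's implementation: the function names are hypothetical, it uses a plain midpoint rule in u = t′/t instead of the package's log-mirror quadrature nodes, and the input engine is a constant luminosity so the result can be checked against the closed form L(t) = L₀ [1 − exp(−(t/τ_m)²)]. Writing the Arnett kernel with the combined exponent exp((t′² − t²)/τ_m²) ≤ 1 keeps every term in range.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

LN10 = jnp.log(10.0)


def log10_constant_engine(t, log10_l0):
    """Hypothetical input luminosity: constant L_in = 10**log10_l0 erg/s."""
    return log10_l0 + jnp.zeros_like(t)


def log10_arnett_single(t, log10_l0, tau_m, n_nodes=512):
    """log10 of the Arnett diffusion integral at one epoch t > 0.

    L(t) = int_0^1 L_in(t u) (2 t^2 u / tau_m^2) exp(t^2 (u^2 - 1) / tau_m^2) du,
    evaluated with a midpoint rule so no node sits at u = 0, where log10(u) = -inf.
    """
    u = (jnp.arange(n_nodes) + 0.5) / n_nodes
    log10_integrand = (
        log10_constant_engine(t * u, log10_l0)
        + jnp.log10(2.0 * u)
        + 2.0 * jnp.log10(t / tau_m)
        + (t / tau_m) ** 2 * (u**2 - 1.0) / LN10  # log10 of exp((t'^2 - t^2) / tau_m^2) <= 0
    )
    # log10(mean_i 10**g_i): natural-log logsumexp, minus log(N) for the 1/N weights
    return (logsumexp(log10_integrand * LN10) - jnp.log(n_nodes)) / LN10


# vmap over epochs (log10_l0 and tau_m shared), then compile with jit
log10_arnett = jax.jit(jax.vmap(log10_arnett_single, in_axes=(0, None, None)))

t = jnp.array([1.0, 5.0, 10.0, 30.0])
log10_lbol = log10_arnett(t, 43.0, 10.0)

# Gradients flow through the quadrature, e.g. d(sum of log10 L)/d(tau_m)
dlogL_dtau = jax.grad(lambda tau: log10_arnett(t, 43.0, tau).sum())(10.0)
```

Swapping `log10_constant_engine` for a real heating term leaves the quadrature unchanged; the midpoint rule is used here only because it keeps every node away from u = 0.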
Models
Bolometric models (return log10_lbol in erg/s)
| Function | Physics | Reference |
|---|---|---|
| `arnett_bolometric` | Ni/Co decay + Arnett diffusion | Arnett 1982 |
| `magnetar_powered_bolometric` | Dipole spin-down + Arnett diffusion | Nicholl+ 2017 |
| `csm_interaction_bolometric` | Forward/reverse shocks + CSM diffusion | Chatzopoulos+ 2013 |
| `tde_analytical_bolometric` | t⁻⁵/³ fallback + Arnett diffusion | n/a |
| `shock_cooling_bolometric` | Shock-cooling envelope (n=10) | Piro 2021 |
| `shocked_cocoon_bolometric` | Shocked jet cocoon | Piro & Kollmeier 2018 |
| `metzger_kilonova_bolometric` | r-process ODE, 200 shells | Metzger 2017 |
| `magnetar_boosted_kilonova_bolometric` | r-process ODE + magnetar injection | Yu+ 2013 |
All bolometric functions return log10_lbol (log base-10 of luminosity in erg/s). This is the natural unit for GPU inference — float32 can represent log10 values for any physically realistic luminosity.
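A quick check of the float32 limit behind this design, using plain NumPy so it does not depend on JAX (JAX arrays default to the same float32 dtype unless 64-bit mode is enabled). The peak luminosity is an illustrative value.

```python
import numpy as np

FLOAT32_MAX = float(np.finfo(np.float32).max)  # ~3.4028235e38
l_peak = 1e45                                  # erg/s, a bright transient (illustrative)

with np.errstate(over="ignore"):
    linear_f32 = np.array(l_peak).astype(np.float32)  # overflows to inf
log10_f32 = np.float32(np.log10(l_peak))              # 45.0, well within range

print(FLOAT32_MAX, linear_f32, log10_f32)
```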
Spectra pipeline
`make_spectra_model(bolometric_fn)` wraps any bolometric model into a full SED pipeline:
- Calls `bolometric_fn(time, **kwargs)` → `log10_lbol`
- Computes photospheric temperature and radius in log10 space (with a temperature floor)
- Evaluates blackbody flux density in log10 space
- Returns `(time, lambdas, spectra)` in the observer frame
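The photosphere and blackbody steps listed above can be sketched in log10 space as follows. This is a hypothetical illustration, not the package's code: it assumes a homologous photosphere R = v t with a temperature floor (when the Stefan-Boltzmann temperature drops below the floor, T is held at the floor and R is recomputed from L = 4πR²σT⁴), and it assumes cgs units with `vej` in km/s and wavelength in cm. The Planck denominator uses ln(eˣ − 1) = x + ln(1 − e⁻ˣ), so eˣ is never formed and cannot overflow in float32.

```python
import jax.numpy as jnp

# cgs constants
SIGMA_SB = 5.670374419e-5   # Stefan-Boltzmann, erg cm^-2 s^-1 K^-4
H = 6.62607015e-27          # Planck, erg s
C = 2.99792458e10           # speed of light, cm/s
K_B = 1.380649e-16          # Boltzmann, erg/K
DAY = 86400.0               # s

LOG10_4PI_SIGMA = jnp.log10(4.0 * jnp.pi * SIGMA_SB)
LOG10_2PI_H_C2 = jnp.log10(2.0 * jnp.pi * H * C**2)
HC_OVER_K = H * C / K_B     # ~1.4388 cm K


def log10_photosphere(time_days, log10_lbol, vej_kms, temperature_floor):
    """Hypothetical photosphere: R = v t, with temperature held at temperature_floor."""
    log10_r = jnp.log10(vej_kms * 1e5) + jnp.log10(time_days * DAY)
    # Stefan-Boltzmann, L = 4 pi R^2 sigma T^4, solved for T in log10
    log10_t = 0.25 * (log10_lbol - LOG10_4PI_SIGMA - 2.0 * log10_r)
    log10_floor = jnp.log10(temperature_floor)
    below = log10_t < log10_floor
    # Below the floor, fix T and shrink R so that L is still conserved
    log10_r_floor = 0.5 * (log10_lbol - LOG10_4PI_SIGMA - 4.0 * log10_floor)
    return (jnp.where(below, log10_floor, log10_t),
            jnp.where(below, log10_r_floor, log10_r))


def log10_blackbody_flambda(lambda_cm, log10_temp, log10_radius, log10_dist):
    """log10 F_lambda (erg s^-1 cm^-2 cm^-1) of a blackbody sphere of radius R at distance d."""
    x = HC_OVER_K / (lambda_cm * 10.0 ** log10_temp)
    # log10(e^x - 1) = (x + ln(1 - e^-x)) / ln 10, which never forms e^x
    log10_expm1 = (x + jnp.log1p(-jnp.exp(-x))) / jnp.log(10.0)
    return (LOG10_2PI_H_C2 - 5.0 * jnp.log10(lambda_cm) - log10_expm1
            + 2.0 * (log10_radius - log10_dist))
```

On a time × wavelength grid, passing `log10_temp[:, None]` and `log10_radius[:, None]` with a 1-D wavelength array broadcasts to an `(n_times, n_lambdas)` array of log10 F_λ; subtracting 8 converts per-cm to per-Å.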
Fitting bolometric data
Since models return log10_lbol, fit observed bolometric luminosities in log10 space:
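```python
import numpy as np
import jax.numpy as jnp
from redback_jax.models.supernova_models import arnett_bolometric

# Observed data: `time`, `observed_lbol` and `sigma_lbol` (erg/s) are your own float64
# NumPy arrays. Take logs in NumPy: JAX defaults to float32 (unless jax_enable_x64 is set),
# so jnp.log10(observed_lbol) would first cast ~1e42 erg/s to inf.
log10_lbol_obs = jnp.asarray(np.log10(observed_lbol))                      # convert once
log10_lbol_err = jnp.asarray(sigma_lbol / (observed_lbol * np.log(10.0)))  # propagate errors

# Model prediction (log10 of erg/s)
log10_lbol_model = arnett_bolometric(time, f_nickel=0.5, mej=1.0,
                                     vej=10000.0, kappa=0.1, kappa_gamma=10.0)

# Gaussian log-likelihood in log10 space (up to an additive constant)
log_like = -0.5 * jnp.sum(((log10_lbol_obs - log10_lbol_model) / log10_lbol_err) ** 2)
```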
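The error-propagation line above is the first-order (delta-method) result σ_log10 = σ_L / (L ln 10), since d log10 L / dL = 1/(L ln 10). A short NumPy check with made-up illustrative numbers compares it with the exact half-width of the log10 interval; as a linearisation it is only accurate when σ_L / L is small.

```python
import numpy as np

lbol = np.array([1e42, 5e42, 2e43])  # illustrative luminosities in erg/s (float64)
sigma_lbol = 0.05 * lbol             # assumed 5% fractional errors

# First-order propagation: d log10(L) / dL = 1 / (L ln 10)
sigma_log10 = sigma_lbol / (lbol * np.log(10.0))

# Exact half-width of the log10 interval [L - sigma, L + sigma], for comparison
half_width = 0.5 * (np.log10(lbol + sigma_lbol) - np.log10(lbol - sigma_lbol))

print(sigma_log10)  # ~0.0217 dex for every point: only the fractional error matters
print(half_width)
```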
Bayesian inference — photometric fitting
The Prior / Likelihood / NestedSampler / MCMCSampler API handles the full pipeline: model evaluation, bandflux integration, and sampling.
```python
import jax
from redback_jax.inference import Prior, Uniform, Likelihood, NestedSampler, MCMCSampler
from redback_jax.utils import luminosity_distance_cm

REDSHIFT = 0.01
DL_CM = luminosity_distance_cm(REDSHIFT)  # ~1.37e26 cm

# Free parameters
prior = Prior([
    Uniform(58580, 58620, name='t0'),  # MJD explosion epoch
    Uniform(0.05, 0.30, name='f_nickel'),
    Uniform(0.5, 3.0, name='mej'),
    Uniform(3000, 12000, name='vej'),
])

# Likelihood: `transient` is your photometry object, providing
# transient.time (MJD), transient.y (AB mag), transient.y_err and transient.bands
likelihood = Likelihood(
    model='arnett_spectra',
    transient=transient,
    fixed_params={
        'redshift': REDSHIFT,
        'lum_dist': DL_CM,
        'temperature_floor': 5000.0,
        'kappa': 0.07,
        'kappa_gamma': 0.1,
    },
)

# Nested sampling (BlackJAX)
ns_result = NestedSampler(likelihood, prior, outdir='results/').run(jax.random.PRNGKey(0))
ns_result.summary()

# Or MCMC with NUTS (BlackJAX)
mcmc_result = MCMCSampler(likelihood, prior, n_warmup=500, n_samples=2000, n_chains=4).run(
    jax.random.PRNGKey(1)
)
mcmc_result.summary()
```
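A box of independent uniform priors like the one above needs two operations in a sampler: drawing parameter vectors and evaluating the log prior density (constant inside the box, −∞ outside). The generic JAX sketch below shows both; the helper names and bounds arrays are hypothetical and this is not how the redback_jax `Prior` class is implemented.

```python
import jax
import jax.numpy as jnp

# Hypothetical bounds mirroring the Prior above: t0, f_nickel, mej, vej
NAMES = ("t0", "f_nickel", "mej", "vej")
LOW = jnp.array([58580.0, 0.05, 0.5, 3000.0])
HIGH = jnp.array([58620.0, 0.30, 3.0, 12000.0])


def sample_uniform_prior(key, n):
    """Draw n parameter vectors, shape (n, 4), uniformly from the box."""
    u = jax.random.uniform(key, (n, LOW.size))
    return LOW + u * (HIGH - LOW)


@jax.jit
def log_prior(theta):
    """Log density of the box prior: -sum(log widths) inside, -inf outside."""
    inside = jnp.all((theta >= LOW) & (theta <= HIGH))
    return jnp.where(inside, -jnp.sum(jnp.log(HIGH - LOW)), -jnp.inf)


samples = sample_uniform_prior(jax.random.PRNGKey(0), 1000)
```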
Available models
Pass any key of `redback_jax.models.MODELS` as the `model` argument:
```python
from redback_jax.models import MODELS

print(list(MODELS.keys()))
# ['arnett_spectra', 'magnetar_spectra', 'csm_spectra', ...]
```
Direct spectra / magnitude evaluation
To compute magnitudes outside of inference (e.g. for plotting):
```python
import jax.numpy as jnp
from redback_jax.sources import PrecomputedSpectraSource

source = PrecomputedSpectraSource.from_arnett_model(
    f_nickel=0.15, mej=1.0, vej=8000.0,
    redshift=0.01,
    cosmo_H0=67.66, cosmo_Om0=0.3111,
)

# AB magnitude in ztfr at a set of phases
phases = jnp.linspace(-5, 40, 200)
mags = source.bandmag({'amplitude': 1.0}, 'ztfr', phases)
```
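A band magnitude involves integrating the SED over the filter. As a rough cross-check on the numbers returned, the AB magnitude of a single flux density can be computed from log10 F_λ without leaving log space, using f_ν = F_λ λ²/c and m_AB = −2.5 log10 f_ν − 48.60 (f_ν in erg s⁻¹ cm⁻² Hz⁻¹). This monochromatic helper is hypothetical and is a simplification, not what `bandmag` does internally.

```python
import jax.numpy as jnp

C_ANGSTROM_PER_S = 2.99792458e18  # speed of light in Angstrom/s


def ab_mag_from_log10_flambda(log10_flambda, lambda_angstrom):
    """Monochromatic AB magnitude from log10 F_lambda in erg s^-1 cm^-2 A^-1."""
    # f_nu = F_lambda * lambda^2 / c, kept in log10 so nothing is exponentiated
    log10_fnu = log10_flambda + 2.0 * jnp.log10(lambda_angstrom) - jnp.log10(C_ANGSTROM_PER_S)
    # m_AB = -2.5 log10(f_nu) - 48.60, with f_nu in erg s^-1 cm^-2 Hz^-1
    return -2.5 * log10_fnu - 48.60


mag_6200 = ab_mag_from_log10_flambda(-16.0, 6200.0)
```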
Parameter conventions
Some parameters changed from the original redback package for float32 safety:
| Model | Old parameter | New parameter | Reason |
|---|---|---|---|
| `tde_analytical_bolometric` | `l0` (erg/s, ~10⁴³) | `log10_l0` | Linear value overflows float32 |
| `shock_cooling_bolometric` | `mass` (Msun), `radius` (cm), `energy` (erg) | `log10_mass`, `log10_radius`, `log10_energy` | Intermediate products overflow float32 |
All other parameter names match redback exactly.
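When porting a redback call, the renamed parameters can be converted once on the host in float64, before anything becomes a float32 array; an energy of 10⁵¹ erg, for example, is already above the float32 maximum. The helper below is a hypothetical convenience, not part of redback-jax, intended for the two models in the table; it assumes the log10 parameters use the same units as the old linear ones, which the table does not state explicitly.

```python
import math

# Old redback parameter name -> new redback-jax parameter name (from the table above)
RENAMED = {
    "l0": "log10_l0",
    "mass": "log10_mass",
    "radius": "log10_radius",
    "energy": "log10_energy",
}


def to_log10_params(old_params):
    """Return a copy with renamed linear parameters replaced by their log10 values."""
    new_params = {}
    for name, value in old_params.items():
        if name in RENAMED:
            new_params[RENAMED[name]] = math.log10(value)  # float64 on the host
        else:
            new_params[name] = value  # unchanged name and value
    return new_params


print(to_log10_params({"mass": 0.01, "radius": 1e13, "energy": 1e51}))
```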
Float32 design
Physical luminosities of transients span ~10³⁸–10⁴⁵ erg/s, mostly above the float32 maximum (~3.4×10³⁸). Redback-JAX avoids overflow by:
- Storing all engine luminosities as `log10(L)` throughout
- Using log-sum-exp to combine decay terms (Ni/Co engine)
- Normalising ODE state variables by a scale factor (`E_scale`) in the kilonova scan
- Computing prefactors in log10 before any exponentiation
- Keeping the blackbody SED, temperature, and photospheric radius in log10 space
The only step that materialises linear values is the final bandflux integral over the SED — where the flux densities (~10⁻²⁰ erg/s/cm²/Å) are comfortably within float32 range.
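The log-sum-exp step for the Ni/Co engine can be sketched as below. The heating rate L(t) = M_Ni [(ε_Ni − ε_Co) e^(−t/τ_Ni) + ε_Co e^(−t/τ_Co)] is combined with `jax.scipy.special.logsumexp` and its `b` scale factors, while the nickel mass enters only through its logarithm, so the ~10⁴² erg/s product is never formed in float32. The constants are the standard approximate values (ε_Ni ≈ 3.9×10¹⁰ and ε_Co ≈ 6.78×10⁹ erg s⁻¹ g⁻¹, τ_Ni ≈ 8.8 d, τ_Co ≈ 111.3 d); the function name is hypothetical and redback-jax's own engine may differ in detail.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

MSUN_G = 1.989e33   # solar mass, g
EPS_NI = 3.9e10     # 56Ni specific heating rate, erg s^-1 g^-1
EPS_CO = 6.78e9     # 56Co specific heating rate, erg s^-1 g^-1
TAU_NI = 8.8        # 56Ni e-folding time, days
TAU_CO = 111.3      # 56Co e-folding time, days


@jax.jit
def log10_nickel_cobalt_engine(time_days, f_nickel, mej):
    """log10 of the Ni/Co decay luminosity (erg/s) for a nickel mass of f_nickel * mej Msun."""
    t = jnp.atleast_1d(time_days)
    ln_mass_g = jnp.log(f_nickel * mej) + jnp.log(MSUN_G)
    exponents = jnp.stack([-t / TAU_NI, -t / TAU_CO])       # shape (2, n_times)
    coeffs = jnp.array([EPS_NI - EPS_CO, EPS_CO])[:, None]  # both positive
    # ln(sum_k b_k exp(a_k)); the mass is added in log space, never multiplied out
    ln_lum = ln_mass_g + logsumexp(exponents, axis=0, b=coeffs)
    return ln_lum / jnp.log(10.0)


t = jnp.array([0.0, 10.0, 50.0, 200.0])
log10_lum = log10_nickel_cobalt_engine(t, 0.1, 1.0)  # ~42.89 at t = 0

# For contrast: the linear product M_Ni * eps_Ni in float32 overflows to inf
naive_peak = jnp.float32(0.1 * MSUN_G) * jnp.float32(EPS_NI)
```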
Installation
```bash
git clone https://github.com/nikhil-sarin/redback-jax.git
cd redback-jax
pip install -e .
```
Dependencies
Python 3.12+ required.
Core: jax, numpy, scipy, pandas, matplotlib, astropy, wcosmo
Optional (inference): blackjax, flowmc, optax
Optional (bandflux): jax-bandflux (jax_supernovae)
Related Projects
- redback — the original full-featured package
- fiestaEM — similar JAX-based transient inference framework
- JAX — the underlying numerical computing library
License
GNU General Public License v3.0 — see LICENSE.
Acknowledgments
Based on the original redback package. Please cite the redback paper if you use this software.
Download files
File details
Details for the file redback_jax-0.3.0.tar.gz.
File metadata
- Download URL: redback_jax-0.3.0.tar.gz
- Upload date:
- Size: 99.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | ac2d1b3744db1fac2f02180b86b833120b7a178b06060c15d7db84437f53fb1b |
| MD5 | 191de93854cc94e6dcc9848f21734fc1 |
| BLAKE2b-256 | c49027cef8d38c8eb74209b8e2b16ac45d15e718ba0b90c1fb6c451ff014950a |
Provenance
The following attestation bundles were made for redback_jax-0.3.0.tar.gz:
Publisher: publish.yml on nikhil-sarin/redback-jax
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: redback_jax-0.3.0.tar.gz
- Subject digest: ac2d1b3744db1fac2f02180b86b833120b7a178b06060c15d7db84437f53fb1b
- Sigstore transparency entry: 1201136903
- Sigstore integration time:
- Permalink: nikhil-sarin/redback-jax@e1c74fa4243cb834308b466d81201a9eeb30f309
- Branch / Tag: refs/tags/v0.3.0
- Owner: https://github.com/nikhil-sarin
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@e1c74fa4243cb834308b466d81201a9eeb30f309
- Trigger Event: release
File details
Details for the file redback_jax-0.3.0-py3-none-any.whl.
File metadata
- Download URL: redback_jax-0.3.0-py3-none-any.whl
- Upload date:
- Size: 76.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 8872e4764a0235b868366c658edf5865087a1d262b58d883c79b8cb8466d92a9 |
| MD5 | 266e9ac6c5b4aab7f57494e0e408a01a |
| BLAKE2b-256 | 8eb7f2686a641e01c4a6aa8587823bc7d242f42148763714ef9014ba0ebf0a7a |
Provenance
The following attestation bundles were made for redback_jax-0.3.0-py3-none-any.whl:
Publisher: publish.yml on nikhil-sarin/redback-jax
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: redback_jax-0.3.0-py3-none-any.whl
- Subject digest: 8872e4764a0235b868366c658edf5865087a1d262b58d883c79b8cb8466d92a9
- Sigstore transparency entry: 1201136915
- Sigstore integration time:
- Permalink: nikhil-sarin/redback-jax@e1c74fa4243cb834308b466d81201a9eeb30f309
- Branch / Tag: refs/tags/v0.3.0
- Owner: https://github.com/nikhil-sarin
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@e1c74fa4243cb834308b466d81201a9eeb30f309
- Trigger Event: release