
A JAX-based package for calculating supernova bandfluxes


JAX Bandflux for Supernovae


This repository implements supernova light curve modelling in JAX. It provides a differentiable version of core SNCosmo functionality.

Installation

To install from source, clone the repository, create a virtual environment, and install the requirements:

git clone git@github.com:samleeney/JAX-bandflux.git && cd JAX-bandflux && python -m venv venv && source venv/bin/activate && pip install -r requirements.txt

For nested sampling functionality, install the required packages with:

pip install git+https://github.com/handley-lab/blackjax@nested_sampling distrax

Usage

Running the Code

This repository follows a structure similar to the "Using a custom fitter" example provided by SNCosmo. You may define the objective function as illustrated below:

import jax.numpy as jnp

# Assumes `data`, `band_dict` and `salt3_bandflux` are already defined in scope.
def objective(parameters):
    # Map the flat parameter vector onto named model parameters
    params = {
        'z': parameters[0],
        't0': parameters[1],
        'x0': parameters[2],
        'x1': parameters[3],
        'c': parameters[4]
    }

    # Compute model fluxes for all observations
    model_flux = []
    for band_name, t, zp, zpsys in zip(data['band'], data['time'], data['zp'], data['zpsys']):
        flux = salt3_bandflux(t, band_dict[band_name], params, zp=zp, zpsys=zpsys)
        # Extract the scalar value from the array. float() needs a concrete
        # value, so this version cannot be traced by jax.jit or jax.grad.
        model_flux.append(float(flux.ravel()[0]))

    # Convert to a JAX array and calculate the chi-squared statistic
    model_flux = jnp.array(model_flux)
    chi2 = jnp.sum(((data['flux'] - model_flux) / data['fluxerr'])**2)

    # Display the total chi-squared for debugging purposes
    print(f"\nTotal chi-squared: {float(chi2):.2f}\n")

    return chi2

Pass this function to your sampler or optimiser of choice. A complete example, analogous to the SNCosmo case, is provided in examples/fmin_bfgs.py. A nested sampling implementation is also available in ns.py.
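As a rough illustration of why a JAX objective is useful, the sketch below evaluates a chi-squared statistic and its gradient in one call. It does not use this package: toy_model_flux is a hypothetical Gaussian stand-in for the real bandflux model, and the parameter names and synthetic data are assumptions made for the example.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a bandflux model: a Gaussian pulse in time.
def toy_model_flux(params, times):
    t0, amplitude, width = params
    return amplitude * jnp.exp(-0.5 * ((times - t0) / width) ** 2)

def chi2(params, times, flux, fluxerr):
    # No float() conversion or Python loop here, so JAX can trace the function.
    model = toy_model_flux(params, times)
    return jnp.sum(((flux - model) / fluxerr) ** 2)

# Noiseless synthetic observations generated from known parameters.
times = jnp.linspace(-10.0, 30.0, 9)
true_params = jnp.array([5.0, 2.0, 8.0])
flux = toy_model_flux(true_params, times)
fluxerr = jnp.full_like(times, 0.1)

# Compile a function that returns the statistic and its gradient together.
chi2_and_grad = jax.jit(jax.value_and_grad(chi2))

guess = jnp.array([4.0, 1.5, 7.0])
value, grad = chi2_and_grad(guess, times, flux, fluxerr)
value_at_truth, _ = chi2_and_grad(true_params, times, flux, fluxerr)
```

The gradient has one entry per parameter and can be handed to a gradient-based optimiser or sampler instead of relying on finite differences.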

To execute an example, run:

python examples/fmin_bfgs.py

Data Loading

The repository provides routines for loading supernova light curve data, optimised for the HSF DR1 format. Two entry points are described below.

To load data for a specific supernova:

from jax_supernovae.data import load_hsf_data

# Load data for a specific supernova
data = load_hsf_data('19agl', base_dir='data')

The data is returned as an Astropy Table that includes:

  • time: Observation times (MJD)
  • band: Filter or band names
  • flux: Flux measurements
  • fluxerr: Errors associated with flux measurements
  • zp: Zero points (defaults to 27.5 when not provided)
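For orientation, a zero point ties a flux to a magnitude through m = -2.5 log10(flux) + zp, so fluxes recorded with different zero points can be put on a common scale. The sketch below shows that conversion under this standard convention. The helper rescale_flux and the sample values are hypothetical and are not part of this package.

```python
import jax.numpy as jnp

DEFAULT_ZP = 27.5  # the default zero point quoted in the list above

def rescale_flux(flux, zp_from, zp_to=DEFAULT_ZP):
    # Hypothetical helper: keep the magnitude fixed while changing zero point.
    return flux * 10.0 ** (0.4 * (zp_to - zp_from))

flux = jnp.array([100.0, 250.0])   # illustrative flux values
zp = jnp.array([25.0, 27.5])       # per-observation zero points
flux_common = rescale_flux(flux, zp)
```

An observation already at the target zero point is unchanged, while one recorded 2.5 magnitudes lower is scaled up by a factor of ten.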

For analysis-ready JAX arrays and automatic bandpass registration, use:

from jax_supernovae.data import load_and_process_data

# Load and process data with automatic bandpass registration
times, fluxes, fluxerrs, zps, band_indices, bridges = load_and_process_data(
    sn_name='19agl',
    data_dir='data'  # Optional, the default is 'data'
)

This function performs the following steps:

  1. Loads raw data from the specified directory (default: 'data').
  2. Registers the required bandpasses.
  3. Converts data into JAX arrays.
  4. Generates band indices for efficient processing.
  5. Precomputes bridge data for each band, required for JAX optimisation.
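Steps 3 and 4 can be pictured with the minimal sketch below, which maps band names to integer indices so that array code never has to handle strings. This is an assumption about the general technique, not the implementation inside load_and_process_data, and the band names and flux values are illustrative.

```python
import numpy as np
import jax.numpy as jnp

# Illustrative band names, one per observation.
bands = np.array(['ztfg', 'ztfr', 'ztfg', 'c', 'ztfr'])

# Strings are handled in NumPy; JAX arrays only hold numeric data.
unique_bands, inverse = np.unique(bands, return_inverse=True)
band_indices = jnp.asarray(inverse)

# Example use: accumulate a per-band total with a scatter-add.
fluxes = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
per_band_total = jnp.zeros(len(unique_bands)).at[band_indices].add(fluxes)
```

Each entry of band_indices points into unique_bands, so per-band quantities can be gathered or scattered with ordinary array indexing.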

Testing

This repository implements a JAX version of the SNCosmo bandflux function. The two implementations are nearly identical, but JAX lacks a specific interpolation function used by SNCosmo, which leads to a discrepancy of approximately 0.001% in bandflux results. Tests are provided to confirm that key functions in the JAX implementation agree with their SNCosmo counterparts. Run these tests after making any modifications.
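A discrepancy of about 0.001% corresponds to a relative difference of roughly 1e-5, so a comparison test needs a relative tolerance above that level. The sketch below shows one way to write such a check. The helper assert_bandflux_close, the tolerance of 1e-4, and the sample values are assumptions for illustration and are not taken from this repository's test suite.

```python
import numpy as np

def assert_bandflux_close(jax_flux, reference_flux, rtol=1e-4):
    # Hypothetical helper: fail if any value differs by more than rtol (relative).
    np.testing.assert_allclose(
        np.asarray(jax_flux), np.asarray(reference_flux), rtol=rtol
    )

# Illustrative reference values and a candidate that is off by 0.001%.
reference = np.array([10.0, 250.0, 1234.5])
candidate = reference * (1.0 + 1e-5)

assert_bandflux_close(candidate, reference)  # passes: within tolerance
```

A tolerance set too tightly would flag the known interpolation difference as a failure, while one set too loosely could hide a real regression.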

What is the .airules file?

Large Language Models are frequently used to optimise research and development. The .airules file provides context to help LLMs understand and work with this codebase. This is particularly important for new code that will not have been included in model training datasets. The file contains detailed information about data structures, core functions, critical implementation rules, and testing requirements. If you are using Cursor, rename this file to .cursorrules and it will be automatically interpreted as context.
