lenstronomy, but in JAX

JAX port of lenstronomy, for parallelized, GPU-accelerated, and differentiable gravitational lensing and image simulations.

The goal of this library is to reimplement lenstronomy functionalities in pure JAX to allow for automatic differentiation, GPU acceleration, and batched computations.

Guiding Principles:

  • Strive to be a drop-in replacement for lenstronomy, i.e. provide a close match to the lenstronomy API.

  • Each function/feature will be tested against the reference lenstronomy implementation.

  • This package will aim to be a subset of lenstronomy (i.e. only contains functions with a reference lenstronomy implementation).

  • Implementations should be easy to read and understand.

  • Code should be pip installable on any machine, no compilation required.

  • Any notable differences between the JAX and reference implementations will be clearly documented.

Installation and Usage

GPU

To use JAXtronomy with an NVIDIA GPU on Linux, first install JAX with

pip install -U "jax[cuda13]"

For other GPUs or operating systems, installation is more complicated. See the JAX installation instructions for GPU for more details. Then, install JAXtronomy with

pip install jaxtronomy

By default, JAX will still use CPU for computations. To change this, run the following line of code immediately after importing JAX

jax.config.update("jax_platform_name", "gpu")
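
To confirm that the setting took effect, it can help to ask JAX which backend and devices it sees. The following is a minimal sketch rather than part of JAXtronomy: describe_devices is a hypothetical helper name, and the import is guarded only so that the snippet also runs on a machine where JAX is not installed.

```python
try:
    import jax
except ImportError:  # guard so the sketch still runs without JAX
    jax = None


def describe_devices():
    """Hypothetical helper: summarize the backend and devices JAX will use."""
    if jax is None:
        return "JAX is not installed"
    # default_backend() returns a platform name such as "cpu" or "gpu";
    # devices() lists the devices available on that backend.
    return f"backend={jax.default_backend()}, devices={jax.devices()}"


print(describe_devices())
```

If the output lists only CPU devices, JAX is not using the GPU.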

CPU

For standard CPU-only usage, JAXtronomy and JAX can both be installed with

pip install jaxtronomy

For computations parallelized across CPU cores, an environment variable indicating the number of CPU devices to use needs to be set before importing JAX. For example, to use 16 CPU cores, this can be done with

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=16"
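
The ordering matters here: the variable has to be in place before JAX is first imported. A minimal sketch of that ordering is given below. request_cpu_devices is a hypothetical helper, not part of JAXtronomy or JAX; it assumes that any flags already present in XLA_FLAGS should be kept, and the JAX import is guarded only so that the snippet also runs where JAX is not installed.

```python
import os

DEVICE_COUNT_FLAG = "--xla_force_host_platform_device_count"


def request_cpu_devices(n_devices):
    """Hypothetical helper: ask XLA to expose n_devices CPU devices."""
    existing = os.environ.get("XLA_FLAGS", "").split()
    # Keep unrelated flags, replace any earlier device-count setting.
    kept = [flag for flag in existing if not flag.startswith(DEVICE_COUNT_FLAG)]
    os.environ["XLA_FLAGS"] = " ".join(kept + [f"{DEVICE_COUNT_FLAG}={n_devices}"])
    return os.environ["XLA_FLAGS"]


request_cpu_devices(16)

# JAX must be imported only after the variable is set.
try:
    import jax
    print(jax.device_count())  # number of devices on the default backend
except ImportError:  # guard so the sketch still runs without JAX
    print("JAX is not installed")
```

If JAX was already imported earlier in the same process, changing the variable afterwards is not expected to have any effect.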

Example notebook: An example notebook has been made available, which showcases the features and improvements in JAXtronomy.

Performance comparison between JAXtronomy and lenstronomy

We compare the runtimes between JAXtronomy and lenstronomy by timing 1,000 function executions. While lenstronomy is always run using CPU, JAXtronomy can be run using either CPU or GPU. These tests were run using an Intel(R) Xeon(R) Gold 6338 CPU @ 2.00GHz, an NVIDIA A100 GPU, and JAX version 0.7.0. A performance comparison notebook has been made available for reproducibility.
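
As a rough illustration of this kind of measurement, the sketch below times repeated calls of two functions and reports the ratio of their runtimes. It is not taken from the performance comparison notebook: time_calls, block and speedup are hypothetical helpers, and the untimed warm-up call (so that JAX compilation is not counted) and the wait on block_until_ready (because JAX dispatches work asynchronously) are assumptions about how such a benchmark would reasonably be set up.

```python
import time


def block(result):
    """Wait for asynchronously computed results, if the object supports it."""
    if isinstance(result, (tuple, list)):
        for item in result:
            block(item)
        return result
    # JAX arrays expose block_until_ready(); plain Python/NumPy values do not.
    wait = getattr(result, "block_until_ready", None)
    if wait is not None:
        wait()
    return result


def time_calls(fn, *args, n_calls=1000):
    """Total wall-clock seconds for n_calls executions of fn(*args)."""
    block(fn(*args))  # warm-up call, not timed (assumption)
    start = time.perf_counter()
    for _ in range(n_calls):
        block(fn(*args))
    return time.perf_counter() - start


def speedup(t_reference, t_candidate):
    """Runtime ratio: values above 1 mean the candidate is faster."""
    return t_reference / t_candidate


# Stand-in workloads; in practice these would be the lenstronomy and
# JAXtronomy versions of the same function.
def reference(n):
    return sum(i * i for i in range(n))


def candidate(n):
    return (n - 1) * n * (2 * n - 1) // 6


t_ref = time_calls(reference, 1000, n_calls=1000)
t_new = time_calls(candidate, 1000, n_calls=1000)
print(f"speedup: {speedup(t_ref, t_new):.1f}x")
```

Under this convention a value such as 0.5x means the candidate took twice as long as the reference.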

LensModel ray-shooting

The table below lists the speedup of JAXtronomy relative to lenstronomy for different deflector profiles and grid sizes; values below 1x mean that JAXtronomy is slower. For some profiles the speedup varies significantly with the values of the function arguments, in which case a range is given.

Deflector Profile              60x60 grid (JAX w/ CPU)  180x180 grid (JAX w/ CPU)  180x180 grid (JAX w/ GPU)
-----------------------------  -----------------------  -------------------------  -------------------------
CONVERGENCE                    0.4x                     1.1x                       0.5x
CSE                            1.6x                     2.6x                       2.6x
EPL                            5.1x - 15x               8.4x - 18x                 36x - 120x
EPL (jax) vs EPL_NUMBA         1.4x                     3.0x                       13x
EPL_MULTIPOLE_M1M3M4           2.1x - 7x                6.4x - 13x                 42x - 108x
EPL_MULTIPOLE_M1M3M4_ELL       1.8x - 3.3x              2.5x - 3.3x                120x - 140x
GAUSSIAN                       1.4x                     2.8x                       2.3x
GAUSSIAN_POTENTIAL             1.2x                     2.6x                       1.9x
HERNQUIST                      2.0x                     3.4x                       5.8x
HERNQUIST_ELLIPSE_CSE          3.8x                     5.4x                       40x
MULTIPOLE                      0.9x                     1.0x                       8.3x - 14x
MULTIPOLE_ELL (m=1, m=3, m=4)  1.5x, 1.5x, 2.1x         2.0x, 2.0x, 2.8x           75x, 75x, 70x
NFW                            1.6x                     3.3x                       4.5x
NFW_ELLIPSE_CSE                4.1x                     6.7x                       31x
NIE                            0.5x                     0.5x                       2.0x
PJAFFE                         1.0x                     1.2x                       2.8x
PJAFFE_ELLIPSE_POTENTIAL       1.4x                     1.6x                       3.1x
SHEAR                          0.7x                     2.0x                       0.9x
SIE                            0.5x                     0.5x                       2.0x
SIS                            1.4x                     3.3x                       2.0x
TNFW                           2.4x                     5.8x                       7.5x

LightModel surface brightness

The table below shows how much faster JAXtronomy is compared to lenstronomy for different source profiles and different grid sizes.

Source Profile                         60x60 grid (JAX w/ CPU)  180x180 grid (JAX w/ CPU)  180x180 grid (JAX w/ GPU)
-------------------------------------  -----------------------  -------------------------  -------------------------
CORE_SERSIC                            2.0x                     6.7x                       4.2x
GAUSSIAN                               1.0x                     2.5x                       1.3x
GAUSSIAN_ELLIPSE                       1.5x                     3.6x                       2.0x
MULTI_GAUSSIAN (5 components)          3.7x                     11x                        7.8x
MULTI_GAUSSIAN_ELLIPSE (5 components)  4.0x                     13x                        6.9x
SERSIC                                 1.0x                     1.7x                       3.9x
SERSIC_ELLIPSE                         1.9x                     5.7x                       3.2x
SERSIC_ELLIPSE_Q_PHI                   1.7x                     5.5x                       3.3x
SHAPELETS (n_max=6)                    6.2x                     3.4x                       15x
SHAPELETS (n_max=10)                   6.0x                     4.5x                       17x

FFT Pixel Kernel Convolution

Convolution runtimes vary significantly, depending on both grid size and kernel size.

  • For a 60x60 grid and kernel sizes ranging from 3 to 45, JAXtronomy on CPU is about 1.1x to 2.9x as fast as lenstronomy, with no obvious correlation to kernel size.

  • For a 60x60 grid and kernel sizes ranging from 3 to 45, JAXtronomy on GPU is about 1.5x to 3.5x as fast as lenstronomy, with JAXtronomy performing better at larger kernel sizes.

  • For a 180x180 grid and kernel sizes ranging from 9 to 135, JAXtronomy on CPU is about 0.7x to 2.5x as fast as lenstronomy, with no obvious correlation to kernel size.

  • For a 180x180 grid and kernel sizes ranging from 9 to 135, JAXtronomy on GPU is about 10x to 20x as fast as lenstronomy, with JAXtronomy performing better at larger kernel sizes.

Community guidelines

Contributing to JAXtronomy

The guidelines for contributing to JAXtronomy are the same as those for lenstronomy, which can be found here. In short,

  • Fork the repository

  • Write clean, well-documented code, following conventions

  • Submit pull requests

Reporting issues, seeking support, and feature requests

  • Submit a GitHub issue

