
JAX-accelerated FAST-PT for computing perturbation theory power spectra

Project description

JAX-PT

JAX-PT is a rewrite of the FAST-PT codebase that is compatible with JAX's automatic differentiation and JIT compilation tools. It can be integrated into full JAX computation pipelines or used on its own. Once compiled, the main JAX-PT functions (which mirror those of FAST-PT) run roughly 5-100x faster than FAST-PT 4.0, depending on the function. For more in-depth examples of JAX-PT's features and functionality, please see the examples.
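JAX compatibility means a computation built from jax.numpy operations can be compiled with jax.jit and differentiated with jax.grad in the same pipeline. The sketch below uses only core JAX. It does not call JAX-PT. toy_quadratic_term and the stand-in spectrum p_lin are hypothetical and are not part of JAX-PT's API. The only physics it borrows is that one-loop terms are quadratic in the linear power spectrum, so rescaling P_lin by an amplitude A rescales them by A**2.

```python
import jax
import jax.numpy as jnp

# Use double precision so the gradient can be compared tightly
jax.config.update("jax_enable_x64", True)

k = jnp.logspace(-3, 1, 1000)
dlnk = jnp.log(k[1] / k[0])
# Hypothetical smooth stand-in for a linear power spectrum (not physical)
p_lin = k / (1.0 + (k / 0.02) ** 2) ** 1.5

@jax.jit
def toy_quadratic_term(amplitude):
    """Toy quantity quadratic in P_lin, loosely mimicking a one-loop term."""
    p = amplitude * p_lin
    return jnp.sum(p**2) * dlnk

# jax.grad differentiates straight through the JIT-compiled function
dterm_dA = jax.grad(toy_quadratic_term)(1.5)
```

The analytic derivative is 2 * A * sum(P_lin**2) * dlnk, so the same check can be run against a real JAX-PT term by differentiating with respect to its input amplitude.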

FAST-PT

FAST-PT is a code to calculate quantities in cosmological perturbation theory at 1-loop (including, e.g., corrections to the matter power spectrum). The code utilizes Fourier methods combined with analytic expressions to reduce the computation time to scale as N log N, where N is the number of grid points in the input linear power spectrum.

arXiv:1603.04826 arXiv:1609.05978 arXiv:1708.09247
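The N log N scaling comes from decomposing the input spectrum on a log-spaced grid into complex power laws with an FFT, in the spirit of the FAST-PT papers cited above. The sketch below shows only that decomposition step. It omits the windowing/filtering and the analytic kernels that FAST-PT applies afterwards. The function names and the bias value nu = -2 are illustrative assumptions, not JAX-PT's API.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def power_law_coefficients(k, pk, nu=-2.0):
    """Hypothetical helper: write P(k_l) = sum_m c_m (k_l/k_0)**(nu + 1j*eta_m).

    Assumes k is log-spaced. The FFT makes this step O(N log N).
    """
    n = k.size
    dlnk = jnp.log(k[1] / k[0])
    # Divide out the power-law bias (k/k_0)**nu before transforming
    biased = pk * (k / k[0]) ** (-nu)
    c_m = jnp.fft.fft(biased) / n
    m = jnp.fft.fftfreq(n) * n          # integer frequencies 0, 1, ..., -1
    eta_m = 2 * jnp.pi * m / (n * dlnk)  # imaginary parts of the exponents
    return c_m, eta_m

def reconstruct(k, c_m, eta_m, nu=-2.0):
    """Sum the power laws back up (O(N^2); used here only as a check)."""
    x = jnp.log(k / k[0])
    modes = jnp.exp((nu + 1j * eta_m)[None, :] * x[:, None])
    return (modes @ c_m).real
```

Each term c_m (k/k_0)**(nu + 1j*eta_m) is a pure power law, which is what lets FAST-PT evaluate the one-loop integrals analytically term by term instead of by direct quadrature.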

Installation

Default Installation

pip install jax-pt

Dev Installation

pip install "jax-pt[dev]"

GPU Usage

JAX-PT lets you choose the device your computations run on. When initializing JAXPT, pass 'cpu', 'gpu', or any jax.Device to the device keyword argument:

import jax
import jax.numpy as jnp
from jaxpt import JAXPT

# Check available devices
print("Available devices:", jax.devices())

k = jnp.logspace(-3, 1, 1000)

# Create JAXPT instance (defaults to CPU)
jpt = JAXPT(k, warmup="moderate")

# Specify to use GPU
jpt = JAXPT(k, warmup="moderate", device="gpu")

# Or pass a specific jax.Device
devices = jax.devices()
jpt = JAXPT(k, warmup="moderate", device=devices[0]) # or any index from devices list

Remember to install the JAX CUDA build that matches your CUDA version. For example:

pip install -U "jax[cuda12]"
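jax.devices(backend) raises a RuntimeError when the requested backend is unavailable, so a small helper can prefer a GPU and fall back to the CPU on machines without one. pick_device is a hypothetical helper, not part of JAX-PT or JAX. It returns an ordinary jax.Device, which is the kind of object the device keyword accepts.

```python
import jax
import jax.numpy as jnp

def pick_device(preferred="gpu"):
    """Return the first device of the preferred backend, falling back to CPU."""
    try:
        return jax.devices(preferred)[0]
    except RuntimeError:
        # Raised when no devices of the preferred backend are present
        return jax.devices("cpu")[0]

device = pick_device()
# Place the input k grid on the chosen device
k = jax.device_put(jnp.logspace(-3, 1, 1000), device)
print("Using device:", device, "platform:", device.platform)
```

With jaxpt installed, JAXPT(k, warmup="moderate", device=pick_device()) would then run on a GPU when one is present and on the CPU otherwise.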


Download files

Source Distribution

jax_pt-1.0.0.tar.gz (23.0 kB)

Uploaded Source

Built Distribution

jax_pt-1.0.0-py3-none-any.whl (18.0 kB)

Uploaded Python 3

File details

Details for the file jax_pt-1.0.0.tar.gz.

File metadata

  • Download URL: jax_pt-1.0.0.tar.gz
  • Upload date:
  • Size: 23.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.1

File hashes

Hashes for jax_pt-1.0.0.tar.gz
Algorithm Hash digest
SHA256 4314eb2e090c1abef06d42a0cc8aebb1feb3081c06dd363e95d729ab730ce1c6
MD5 952032cc0ee425d0e2dff2a6cd6ac076
BLAKE2b-256 fa021f882cc85444756ce220e2e501b07a6b46e9fbef1347a1325a8eb5ea3467

File details

Details for the file jax_pt-1.0.0-py3-none-any.whl.

File metadata

  • Download URL: jax_pt-1.0.0-py3-none-any.whl
  • Upload date:
  • Size: 18.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.1

File hashes

Hashes for jax_pt-1.0.0-py3-none-any.whl
Algorithm Hash digest
SHA256 00c731307197359e93cb81733715ebb4a420b9441828152e122fdcb539cb91f3
MD5 56ad835f30272c725cb5c766c17f735d
BLAKE2b-256 bb3e11a6111718d439ad6377f9e383fd8eb6654d456a60b567c423b1bd2785db
