JAX-accelerated FAST-PT for computing perturbation theory power spectra
JAX-PT
JAX-PT is a rewrite of the FAST-PT codebase that is compatible with JAX's automatic differentiation and JIT compilation tools. It can be integrated into full JAX computation pipelines or used on its own. When compiled, the main JAX-PT functions (the same functions FAST-PT provides) can run 5-100x faster than in FAST-PT 4.0, depending on the function. For more in-depth examples of JAX-PT's features and functionality, please see the examples.
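Because JAX-PT is written in JAX, its outputs can sit inside larger differentiable pipelines. The sketch below does not call JAX-PT itself (its method names are not listed on this page). Instead, `toy_loop_term` is a hypothetical stand-in that, like a real one-loop term, is quadratic in the linear power spectrum, and `total_power` rescales P_lin by a growth factor D. It only illustrates that `jax.jit` and `jax.jacfwd` compose through such a computation.

```python
import jax
import jax.numpy as jnp

def toy_loop_term(pk_lin):
    # Hypothetical stand-in for a one-loop term: a smoothed P_lin**2.
    # Any term quadratic in P_lin scales as D**4 when P_lin -> D**2 * P_lin.
    kernel = jnp.array([0.25, 0.5, 0.25])
    return jnp.convolve(pk_lin**2, kernel, mode="same")

@jax.jit
def total_power(growth, pk_lin):
    """Linear term plus the loop-like term, both evaluated on D^2 * P_lin."""
    scaled = growth**2 * pk_lin
    return scaled + toy_loop_term(scaled)

# Forward-mode derivative of the whole P(k) array with respect to D, compiled
dP_dD = jax.jit(jax.jacfwd(total_power, argnums=0))
```

Since the stand-in loop term scales as D^4, the derivative should equal 2D P_lin + 4D^3 times the loop term of P_lin, which gives a quick check that gradients flow through the whole pipeline.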
FAST-PT
FAST-PT is a code that calculates quantities in cosmological perturbation theory at one-loop order (including, e.g., corrections to the matter power spectrum). It combines Fourier methods with analytic expressions so that the computation time scales as N log N, where N is the number of grid points in the input linear power spectrum.
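The N log N scaling comes from expanding the log-spaced input spectrum as a sum of complex power laws, P(k) = sum_m c_m k^(nu + i eta_m), with all coefficients obtained from a single FFT (the FFTLog-style approach FAST-PT builds on). Below is a minimal sketch of that decomposition, assuming a logarithmically spaced k grid and a bias exponent nu = -2 chosen for illustration; the windowing and padding a production code would apply are omitted, and `power_law_coefficients` and `reconstruct` are hypothetical names, not part of the JAX-PT API.

```python
import jax
import jax.numpy as jnp

# Double precision: multiplying by k**nu amplifies round-off at small k.
jax.config.update("jax_enable_x64", True)

def power_law_coefficients(k, pk, nu=-2.0):
    """Hypothetical helper: expand P(k) as sum_m c_m k^(nu + i*eta_m).

    Assumes k is logarithmically spaced. The cost is one real FFT, O(N log N).
    """
    n = k.shape[0]
    dlnk = jnp.log(k[1] / k[0])
    # Divide out the bias power law k^nu before transforming
    biased = pk * k ** (-nu)
    c_m = jnp.fft.rfft(biased) / n
    eta_m = 2.0 * jnp.pi * jnp.arange(c_m.shape[0]) / (n * dlnk)
    # Reference the phases to k[0] so each term is a pure power law in k
    c_m = c_m * jnp.exp(-1j * eta_m * jnp.log(k[0]))
    return c_m, eta_m

def reconstruct(k, c_m, eta_m, nu=-2.0):
    """Direct O(N^2) sum of the power laws, used only to check the expansion."""
    n = k.shape[0]
    # rfft keeps only non-negative frequencies; interior modes count twice
    w = jnp.full(c_m.shape[0], 2.0).at[0].set(1.0)
    if n % 2 == 0:
        w = w.at[-1].set(1.0)  # the Nyquist mode appears only once
    phases = jnp.exp(1j * eta_m[:, None] * jnp.log(k)[None, :])
    series = jnp.real(jnp.sum(w[:, None] * c_m[:, None] * phases, axis=0))
    return k ** nu * series
```

Roughly speaking, the loop integrals have closed forms for power-law inputs, so each term of the expansion can be handled analytically; the direct-sum `reconstruct` exists here only to confirm that the expansion returns the input spectrum on the grid.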
Installation
Default Installation
pip install jax-pt
Dev Installation
pip install "jax-pt[dev]"
GPU Usage
JAX-PT lets you specify the device your computations run on. During initialization, pass 'cpu', 'gpu', or any other jax.Device to the device keyword argument:
import jax
import jax.numpy as jnp
from jaxpt import JAXPT
# Check available devices
print("Available devices:", jax.devices())
k = jnp.logspace(-3, 1, 1000)
# Create JAXPT instance (defaults to CPU)
jpt = JAXPT(k, warmup="moderate")
# Run on the GPU instead
jpt = JAXPT(k, warmup="moderate", device="gpu")
# Or pass a specific jax.Device object
devices = jax.devices()
jpt = JAXPT(k, warmup="moderate", device=devices[0]) # or any index from devices list
Remember to install the JAX CUDA libraries that match your CUDA version. For example:
pip install "jax[cuda12]"
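As an illustration of the general pattern (an assumption, not JAX-PT's actual implementation), a `device` argument that accepts either a platform string or a `jax.Device` can be resolved and used to place arrays with standard JAX calls. `resolve_device` and `place_on_device` are hypothetical names.

```python
import jax
import jax.numpy as jnp

def resolve_device(device="cpu"):
    """Hypothetical helper: map 'cpu', 'gpu', or a jax.Device to a jax.Device."""
    if isinstance(device, jax.Device):
        return device
    if isinstance(device, str):
        # jax.devices(backend) raises RuntimeError if that backend is unavailable
        return jax.devices(device)[0]
    raise TypeError(f"device must be a str or jax.Device, got {type(device)!r}")

def place_on_device(k, device="cpu"):
    """Hypothetical helper: commit the k grid to the chosen device."""
    return jax.device_put(jnp.asarray(k), resolve_device(device))

# Example: place a log-spaced k grid on the first CPU device
k = place_on_device(jnp.logspace(-3, 1, 1000), "cpu")
```

Computations on an array committed to a device run on that device, so resolving the device once at initialization is enough to route later work there. Asking for "gpu" without the CUDA-enabled JAX wheels installed fails at the `jax.devices` call.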
Download files
File details
Details for the file jax_pt-1.0.0.tar.gz.
File metadata
- Download URL: jax_pt-1.0.0.tar.gz
- Upload date:
- Size: 23.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.1
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 4314eb2e090c1abef06d42a0cc8aebb1feb3081c06dd363e95d729ab730ce1c6 |
| MD5 | 952032cc0ee425d0e2dff2a6cd6ac076 |
| BLAKE2b-256 | fa021f882cc85444756ce220e2e501b07a6b46e9fbef1347a1325a8eb5ea3467 |
File details
Details for the file jax_pt-1.0.0-py3-none-any.whl.
File metadata
- Download URL: jax_pt-1.0.0-py3-none-any.whl
- Upload date:
- Size: 18.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.1
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 00c731307197359e93cb81733715ebb4a420b9441828152e122fdcb539cb91f3 |
| MD5 | 56ad835f30272c725cb5c766c17f735d |
| BLAKE2b-256 | bb3e11a6111718d439ad6377f9e383fd8eb6654d456a60b567c423b1bd2785db |