A minimal JAX library for connectivity modelling at scale
JAXScape is a minimal JAX library for connectivity analysis at scale. It provides key utilities to build your own connectivity analysis workflow, including:
- differentiable and GPU-accelerated graph distance metrics
- differentiable raster to graph and graph to raster mappings
- moving window utilities for implementing large-scale connectivity analysis pipelines
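To make the raster to graph mapping listed above concrete, here is a minimal sketch in plain JAX. It is not JAXScape's implementation: the helper names `raster_to_edges` and `node_values_to_raster` are hypothetical, and the 4-neighbour connectivity is an assumption chosen for brevity. Edge weights are the average of the two pixel values, mirroring the `fun` used in the quick start.

```python
import jax.numpy as jnp


def raster_to_edges(raster):
    """Return (sources, targets, weights) for a 4-neighbour grid graph.

    Hypothetical helper: nodes are pixels indexed in row-major order.
    """
    n_rows, n_cols = raster.shape
    index = jnp.arange(n_rows * n_cols).reshape(n_rows, n_cols)
    # Horizontal edges: each pixel and the pixel to its right
    src_h, dst_h = index[:, :-1].ravel(), index[:, 1:].ravel()
    # Vertical edges: each pixel and the pixel below it
    src_v, dst_v = index[:-1, :].ravel(), index[1:, :].ravel()
    sources = jnp.concatenate([src_h, src_v])
    targets = jnp.concatenate([dst_h, dst_v])
    values = raster.ravel()
    # Edge weight is the mean of the two endpoint values
    weights = (values[sources] + values[targets]) / 2
    return sources, targets, weights


def node_values_to_raster(node_values, shape):
    """Map a flat vector of node values back onto the 2D grid."""
    return node_values.reshape(shape)


raster = jnp.array([[1.0, 3.0], [5.0, 7.0]])
sources, targets, weights = raster_to_edges(raster)
roundtrip = node_values_to_raster(raster.ravel(), raster.shape)
```

Because both mappings consist only of indexing and arithmetic, gradients can flow from edge weights back to pixel values.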
JAXScape leverages JAX's capabilities to accelerate distance computations on CPUs/GPUs/TPUs, while ensuring differentiability of all implemented classes and methods, so that sensitivity analysis and optimization can rely on automatic differentiation.
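The moving window idea from the list above can be illustrated with a short sketch in plain JAX. This is an assumption about the general pattern, not JAXScape's utilities: the function `iterate_windows` and its parameters `window_size` and `buffer_size` are hypothetical, and the raster dimensions are assumed to be multiples of `window_size`. Each window is a core tile surrounded by a buffer, so a per-window analysis sees some context around the pixels it writes back.

```python
import jax
import jax.numpy as jnp


def iterate_windows(raster, window_size, buffer_size):
    """Yield ((row, col), window) pairs, each window being a core tile plus buffer.

    Hypothetical helper: assumes raster dimensions are multiples of window_size.
    """
    # Zero-pad so that border tiles also get a full buffer
    padded = jnp.pad(raster, buffer_size, constant_values=0.0)
    full = window_size + 2 * buffer_size
    n_rows, n_cols = raster.shape
    for row in range(0, n_rows, window_size):
        for col in range(0, n_cols, window_size):
            window = jax.lax.dynamic_slice(padded, (row, col), (full, full))
            yield (row, col), window


raster = jnp.arange(16.0).reshape(4, 4)
output = jnp.zeros_like(raster)
n_windows = 0
for (row, col), window in iterate_windows(raster, window_size=2, buffer_size=1):
    # Stand-in for a per-window connectivity analysis
    core_value = window.mean()
    # Only the core tile is written back; the buffer is discarded
    output = output.at[row:row + 2, col:col + 2].set(core_value)
    n_windows += 1
```

Processing windows independently keeps memory bounded by the window size rather than by the full raster.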
Installation
uv add jaxscape
For GPU compatibility, install JAX following the official JAX installation guide. JAXScape will automatically use the JAX backend you have configured.
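To check which backend JAX has picked up after installation, a quick sanity check along these lines may help. It uses only standard JAX calls and does not involve JAXScape itself.

```python
import jax

# Name of the platform JAX will use by default, e.g. "cpu" or "gpu"
backend = jax.default_backend()
# List of devices visible to JAX on that platform
devices = jax.devices()

print("Default backend:", backend)
print("Devices:", devices)
```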
You may be required to install optional linear solvers for large-scale resistance distance computations (see the documentation page for details).
Quick start
import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp
from jaxscape import GridGraph
from jaxscape import LCPDistance, ResistanceDistance, RSPDistance
# loading jax array representing permeability
permeability = jnp.array(np.loadtxt("permeability.csv", delimiter=",")) + 0.001
# Create a grid graph where edge weights are the average permeability of the two nodes
grid = GridGraph(grid=permeability, fun=lambda x, y: (x + y) / 2)
# We set the source to the top left pixel, and compute distances to all other pixels with three different distance metrics
source = grid.coord_to_index(jnp.array([0]), jnp.array([0]))
distances = {
    "LCP distance": LCPDistance(),
    "Resistance distance": ResistanceDistance(),
    "RSP distance": RSPDistance(theta=0.01, cost=lambda x: -jnp.log(x)),
}
fig, axs = plt.subplots(1, 3, figsize=(10, 4))
for ax, (title, distance) in zip(axs, distances.items()):
    # Compute distances from all nodes to the source
    dist_to_node = distance(grid, source)
    # Convert from node values to 2D array and mask low-permeability areas
    dist_array = grid.node_values_to_array(dist_to_node.ravel())
    dist_array = dist_array * (permeability > 0.1)  # Mask barriers
    # Plotting
    im = ax.imshow(dist_array, cmap="magma")
    ax.axis("off")
    ax.set_title(title)
    fig.colorbar(im, ax=ax, shrink=0.2)
fig.suptitle("Distance to top left pixel")
plt.tight_layout()
plt.show()
But what's really cool about JAXScape is that you can autodiff through those distances! Check out the documentation to learn about applications and more!
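As a rough illustration of what differentiating through a distance means, the sketch below computes a resistance distance on a tiny three-node path graph in plain JAX and takes its gradient with respect to the edge conductances. It deliberately avoids the JAXScape API: the functions `laplacian` and `resistance_distance` are hypothetical, and the dense Laplacian pseudoinverse used here is only practical for very small graphs.

```python
import jax
import jax.numpy as jnp


def laplacian(weights, edges, n_nodes):
    """Dense graph Laplacian built from undirected edge conductances."""
    i, j = edges[:, 0], edges[:, 1]
    adj = jnp.zeros((n_nodes, n_nodes))
    adj = adj.at[i, j].add(weights)
    adj = adj.at[j, i].add(weights)
    return jnp.diag(adj.sum(axis=1)) - adj


def resistance_distance(weights, edges, n_nodes, source, target):
    """Resistance distance between two nodes of a connected graph."""
    lap = laplacian(weights, edges, n_nodes)
    # For a connected graph, pinv(L) = inv(L + J/n) - J/n, with J the all-ones matrix
    ones = jnp.ones((n_nodes, n_nodes)) / n_nodes
    lap_pinv = jnp.linalg.inv(lap + ones) - ones
    return (
        lap_pinv[source, source]
        + lap_pinv[target, target]
        - 2 * lap_pinv[source, target]
    )


# Path graph 0 - 1 - 2 with conductances 2 and 4 (resistances 1/2 and 1/4 in series)
edges = jnp.array([[0, 1], [1, 2]])
weights = jnp.array([2.0, 4.0])

dist = resistance_distance(weights, edges, 3, 0, 2)
# Sensitivity of the distance to each edge conductance
sensitivity = jax.grad(resistance_distance)(weights, edges, 3, 0, 2)
```

For resistors in series the distance is `1/w1 + 1/w2`, so the gradient should be close to `-1/w1**2` and `-1/w2**2`, which gives a simple way to sanity-check the result.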
Documentation
Comprehensive documentation is available at https://vboussange.github.io/jaxscape
Features and roadmap 🚀
See issues; most notably:
- Support for direct and iterative linear sparse solvers on GPU (cf spineax)
- Benchmark against CircuitScape, ConScape.jl and radish.
License
jaxscape is distributed under the terms of the MIT license.
Related packages
- gdistance
- ConScape
- Circuitscape
- graphhab
- conefor
- resistanceGA
- landscapemetrics
- radish
Citation
If you use JAXScape in your research, please cite:
@software{jaxscape2024,
author = {Boussange, Victor},
title = {JAXScape: A minimal JAX library for connectivity modelling at scale},
year = {2025},
doi = {10.5281/zenodo.15267703},
url = {https://github.com/vboussange/jaxscape}
}