
A minimal JAX library for connectivity modelling at scale




JAXScape is a minimal JAX library for connectivity analysis at scale. It provides key utilities for building your own connectivity analysis workflow, including:

  • differentiable and GPU-accelerated graph distance metrics
  • differentiable raster-to-graph and graph-to-raster mappings
  • moving window utilities for implementing large-scale connectivity analysis pipelines
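The raster-to-graph and graph-to-raster mappings listed above can be illustrated in plain JAX. This is a minimal sketch, not JAXScape's implementation: `grid_to_adjacency` and `node_values_to_raster` are hypothetical helpers, and the sketch assumes a 4-neighbour grid, row-major node numbering, and the same mean-of-endpoints edge weight used in the quick start below.

```python
import jax.numpy as jnp


def grid_to_adjacency(raster, fun=lambda x, y: (x + y) / 2):
    """Hypothetical helper: dense weighted adjacency matrix of a 4-neighbour grid graph.

    Node k corresponds to pixel (i, j) with k = i * ncols + j (row-major order).
    """
    nrows, ncols = raster.shape
    n = nrows * ncols
    idx = jnp.arange(n).reshape(nrows, ncols)
    values = raster.ravel()

    # Horizontal pairs (i, j) <-> (i, j + 1), then vertical pairs (i, j) <-> (i + 1, j)
    src = jnp.concatenate([idx[:, :-1].ravel(), idx[:-1, :].ravel()])
    dst = jnp.concatenate([idx[:, 1:].ravel(), idx[1:, :].ravel()])

    # Edge weight computed from the values of the two endpoint pixels
    w = fun(values[src], values[dst])

    # Scatter into a symmetric matrix; .at[].set keeps this differentiable w.r.t. raster
    A = jnp.zeros((n, n), dtype=raster.dtype)
    A = A.at[src, dst].set(w)
    A = A.at[dst, src].set(w)
    return A


def node_values_to_raster(node_values, shape):
    """Hypothetical helper: map per-node values back onto the raster grid."""
    return node_values.reshape(shape)


raster = jnp.array([[1.0, 2.0],
                    [3.0, 4.0]])
A = grid_to_adjacency(raster)
# Weighted degree of each node, displayed as a raster
print(node_values_to_raster(A.sum(axis=1), raster.shape))
```

A dense matrix keeps the example short; for large rasters a sparse representation would be required, since the dense matrix grows quadratically with the number of pixels.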

JAXScape leverages JAX to accelerate distance computations on CPUs, GPUs, and TPUs, while keeping all implemented classes and methods differentiable, which enables gradient-based sensitivity analysis and optimization.

Installation

uv add jaxscape

For GPU compatibility, install JAX following the official JAX installation guide. JAXScape will automatically use the JAX backend you have configured.
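To check which backend JAX (and therefore JAXScape) will use, you can query JAX directly. This relies only on JAX's own API, not on anything specific to JAXScape; setting the environment variable `JAX_PLATFORMS=cpu` before starting Python forces the CPU backend.

```python
import jax

# Default backend JAX will run on, e.g. "cpu", "gpu" or "tpu"
print(jax.default_backend())

# Devices visible to JAX on this machine
print(jax.devices())
```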

You may need to install optional linear solvers for large-scale resistance distance computations (see the documentation for details).
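Linear solvers matter here because the resistance distance between two nodes can be obtained from a single linear system in the graph Laplacian. The sketch below assumes a small dense adjacency matrix and uses JAX's built-in conjugate gradient, `jax.scipy.sparse.linalg.cg`; it is not one of the optional solvers the documentation refers to, and `resistance_distance_cg` is a hypothetical name.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg


def resistance_distance_cg(A, i, j):
    """Hypothetical helper: effective resistance between nodes i and j via one linear solve.

    Node j is grounded by deleting its row and column from the Laplacian; a unit
    current is injected at node i, and the resulting potential at node i equals
    the effective resistance R_ij. `i` and `j` must be distinct Python ints.
    """
    L = jnp.diag(A.sum(axis=1)) - A
    # Reduced Laplacian: symmetric positive definite if the graph is connected
    L_red = jnp.delete(jnp.delete(L, j, axis=0), j, axis=1)
    ii = i if i < j else i - 1  # position of node i after deleting node j
    b = jnp.zeros(L_red.shape[0]).at[ii].set(1.0)
    # Conjugate gradient only needs matrix-vector products with L_red
    x, _ = cg(L_red, b)
    return x[ii]


# Path graph 0 - 1 - 2 with conductances 1 (edge 0-1) and 2 (edge 1-2)
A = jnp.array([[0.0, 1.0, 0.0],
               [1.0, 0.0, 2.0],
               [0.0, 2.0, 0.0]])
# Series resistors: expected value is about 1/1 + 1/2 = 1.5
print(resistance_distance_cg(A, 0, 2))
```

Because `cg` also accepts a function computing the matrix-vector product instead of an array, the same approach extends to sparse Laplacians that are never stored densely, which is what makes iterative solvers attractive at scale.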

Quick start

import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp
from jaxscape import GridGraph
from jaxscape import LCPDistance, ResistanceDistance, RSPDistance

# Load a JAX array representing permeability; the small offset avoids zero permeability values
permeability = jnp.array(np.loadtxt("permeability.csv", delimiter=",")) + 0.001

# Create a grid graph where edge weights are the average permeability of the two nodes
grid = GridGraph(grid=permeability, fun=lambda x, y: (x + y) / 2)

# We set the source to the top left pixel, and compute distances to all other pixels with three different distance metrics
source = grid.coord_to_index(jnp.array([0]), jnp.array([0]))

distances = {
    "LCP distance": LCPDistance(),
    "Resistance distance": ResistanceDistance(),
    "RSP distance": RSPDistance(theta=0.01, cost=lambda x: -jnp.log(x)),
}

fig, axs = plt.subplots(1, 3, figsize=(10, 4))
for ax, (title, distance) in zip(axs, distances.items()):
    # Compute distances from all nodes to the source
    dist_to_node = distance(grid, source)

    # Convert from node values to 2D array and mask low-permeability areas
    dist_array = grid.node_values_to_array(dist_to_node.ravel())
    dist_array = dist_array * (permeability > 0.1)  # Mask barriers
    
    # Plotting
    im = ax.imshow(dist_array, cmap="magma")
    ax.axis("off")
    ax.set_title(title)
    fig.colorbar(im, ax=ax, shrink=0.2)

fig.suptitle("Distance to top left pixel")
plt.tight_layout()
plt.show()
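For intuition about the LCP metric in the quick start: least-cost path distances from a single source can be computed with Bellman-Ford relaxations, written here as a vectorised min-plus update. This is an illustrative choice, not necessarily the algorithm behind `LCPDistance`; `lcp_distances` is a hypothetical name, edge costs are assumed non-negative and given directly (with `jnp.inf` marking missing edges), and the dense formulation is only practical for small graphs.

```python
import jax
import jax.numpy as jnp


def lcp_distances(C, source):
    """Hypothetical helper: least-cost path distances from `source` (Bellman-Ford).

    C[u, v] is the cost of moving along edge u -> v, jnp.inf where no edge exists.
    """
    n = C.shape[0]
    dist = jnp.full(n, jnp.inf).at[source].set(0.0)

    def relax(_, dist):
        # dist[u] + C[u, v], minimised over u: best cost of reaching v in one more step
        return jnp.minimum(dist, jnp.min(dist[:, None] + C, axis=0))

    # n - 1 rounds suffice when there are no negative-cost cycles
    return jax.lax.fori_loop(0, n - 1, relax, dist)


inf = jnp.inf
C = jnp.array([[inf, 1.0, 5.0],
               [1.0, inf, 2.0],
               [5.0, 2.0, inf]])
print(lcp_distances(C, 0))  # [0. 1. 3.]
```

The detour 0 -> 1 -> 2 (cost 3) beats the direct edge 0 -> 2 (cost 5). Dijkstra's algorithm is the usual choice on CPUs; the min-plus form is used here only because it vectorises naturally in JAX.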

A key feature of JAXScape is that you can automatically differentiate through these distances with JAX. Check out the documentation to learn about applications and more.
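As a sketch of what differentiating through a distance looks like, the example below computes the resistance distance between the top-left and bottom-right pixels of a small random permeability raster and takes its gradient with `jax.grad`, yielding a per-pixel sensitivity map. It uses plain JAX rather than JAXScape's classes; `grid_laplacian` and `corner_resistance` are hypothetical helpers, edge weights follow the quick start's mean-of-endpoints rule, and the dense pseudoinverse limits it to small rasters.

```python
import jax
import jax.numpy as jnp


def grid_laplacian(raster):
    """Hypothetical helper: Laplacian of a 4-neighbour grid graph.

    Each edge weight is the mean of the two pixel values it connects.
    """
    nrows, ncols = raster.shape
    n = nrows * ncols
    idx = jnp.arange(n).reshape(nrows, ncols)
    # Horizontal and vertical neighbour pairs, using row-major node numbering
    src = jnp.concatenate([idx[:, :-1].ravel(), idx[:-1, :].ravel()])
    dst = jnp.concatenate([idx[:, 1:].ravel(), idx[1:, :].ravel()])
    values = raster.ravel()
    w = (values[src] + values[dst]) / 2
    A = jnp.zeros((n, n)).at[src, dst].set(w).at[dst, src].set(w)
    return jnp.diag(A.sum(axis=1)) - A


def corner_resistance(raster):
    """Hypothetical helper: resistance distance between top-left and bottom-right pixels."""
    Lp = jnp.linalg.pinv(grid_laplacian(raster))  # dense pseudoinverse: small rasters only
    i, j = 0, raster.size - 1
    return Lp[i, i] + Lp[j, j] - 2 * Lp[i, j]


key = jax.random.PRNGKey(0)
permeability = jax.random.uniform(key, (5, 5), minval=0.1, maxval=1.0)

# Gradient of the distance with respect to every pixel: a sensitivity map
sensitivity = jax.grad(corner_resistance)(permeability)
print(sensitivity.shape)  # (5, 5)
```

By Rayleigh's monotonicity law, increasing a conductance can never increase an effective resistance, so every entry of `sensitivity` should be non-positive, up to floating-point error.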

Documentation

Comprehensive documentation is available at https://vboussange.github.io/jaxscape

Features and roadmap 🚀

See the issue tracker; most notably:

  • Support for direct and iterative sparse linear solvers on GPU (cf. spineax)
  • Benchmarks against Circuitscape, ConScape.jl, and radish

License

jaxscape is distributed under the terms of the MIT license.

Related packages

  • gdistance
  • ConScape
  • Circuitscape
  • graphhab
  • conefor
  • resistanceGA
  • landscapemetrics
  • radish

Citation

If you use JAXScape in your research, please cite:

@software{jaxscape2024,
  author = {Boussange, Victor},
  title = {JAXScape: A minimal JAX library for connectivity modelling at scale},
  year = {2025},
  doi = {10.5281/zenodo.15267703},
  url = {https://github.com/vboussange/jaxscape}
}


Download files


Source Distribution

jaxscape-0.0.7.tar.gz (2.3 MB)

Uploaded Source

Built Distribution


jaxscape-0.0.7-py3-none-any.whl (30.0 kB)

Uploaded Python 3

File details

Details for the file jaxscape-0.0.7.tar.gz.

File metadata

  • Download URL: jaxscape-0.0.7.tar.gz
  • Upload date:
  • Size: 2.3 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.14

File hashes

Hashes for jaxscape-0.0.7.tar.gz
  • SHA256: 393dfda80937c44a3c4c10845e44f4c87e10bdecdc3ef1376cd83f2b5f604610
  • MD5: 7c0c7168c00345f8e671221ac7d6f920
  • BLAKE2b-256: bada5e99c036f81a66671de23f4742cd4622d00c43234ec82cdfa34343fae355


File details

Details for the file jaxscape-0.0.7-py3-none-any.whl.

File metadata

  • Download URL: jaxscape-0.0.7-py3-none-any.whl
  • Upload date:
  • Size: 30.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.14

File hashes

Hashes for jaxscape-0.0.7-py3-none-any.whl
  • SHA256: 1226bee7c20f389444d1a89f9159ac801c0c754b179f753a62b020eff3785642
  • MD5: b10577b57c797102a5a9a374333d568d
  • BLAKE2b-256: 34d505e73f82ef724ccc71b8db532189a6495b800f3217c3e872fee6998e051e

