Skip to main content

Add your description here

Project description

prosail-jax

A JAX port of PROSPECT-D + 4SAIL (PROSAIL), the canonical leaf+canopy radiative transfer model used in vegetation remote sensing.

The port is a single-file, self-contained module: differentiable, JIT-able, and vmap-able. It is bit-equivalent (within ~5×10⁻¹³ absolute reflectance) to the reference jgomezdans/prosail NumPy implementation across the full 400–2500 nm spectrum.

This port was written by Claude (Anthropic, Opus 4.7) in a single working session against the jgomezdans/prosail source as ground truth. Numerical equivalence, gradient correctness, and the various edge cases were tested as the port was developed.

Why?

The reference PROSAIL implementation is a NumPy + numba.jit translation of the original Fortran. It's fast, but it's not differentiable, not vectorisable across batches in a friction-free way, and not GPU-portable. This port gets you:

  • jax.grad through every biophysical parameter (N, Cab, Car, Cbrown, Cw, Cm, Ant, LAI, LIDFa, hspot) and every geometry angle (sun zenith, view zenith, relative azimuth) - verified against finite differences to ~10⁻⁹ relative.
  • jax.vmap for batched forward calls - train neural surrogates without Python loops, run Monte-Carlo parameter sweeps in one JIT'd pass.
  • jax.jit with XLA - same module runs on CPU, GPU, or TPU. On CPU the port is ~2.4× faster per call than the reference numba implementation; on GPU the gap will widen dramatically because the forward is dense arithmetic with no special-function calls or scatter/gather.
  • Composable: drop straight into a Flax/Equinox model as a differentiable decoder, plug into jax.scipy.optimize or optax for inversion, hand to NumPyro/BlackJAX for HMC-based parameter inference.

What's in the port

  • PROSPECT-D: leaf-level model. Inputs: N (mesophyll structure), Cab (chlorophyll), Car (carotenoids), Cbrown (brown pigment), Cw (water), Cm (dry matter), Ant (anthocyanins). Outputs: leaf reflectance and transmittance, 400–2500 nm at 1 nm.
  • 4SAIL: canopy-level model with the Verhoef (2007) four-stream formulation. Inputs: leaf rho/tau (from PROSPECT), LAI, leaf inclination distribution (Verhoef bimodal or Campbell ellipsoidal), hotspot, sun zenith, view zenith, relative azimuth, soil reflectance. Outputs: SDR / BHR / DHR / HDR canopy reflectance factors.
  • Linear soil mixing model (rsoil = rsoil * (psoil*dry + (1-psoil)*wet)).
  • A fast custom E1 / expi(-x) implementation, since jax.scipy.special.expi is roughly 10,000× slower than jnp.sin on CPU and dominates total runtime. The replacement combines a 30-term Abramowitz-Stegun series for k ≤ 1.5 with a fixed-iteration Lentz continued fraction for k > 1.5, giving ~5×10⁻¹¹ relative accuracy on E1.

What's not (yet) included:

  • PROSPECT-5 and PROSPECT-PRO: only PROSPECT-D is wired up. Adding 5 or PRO is a small change - load the appropriate spectral coefficient file and call prospect_d with the unused absorption coefficients zeroed (the same pattern the reference uses).

Setup

The repository is configured to be installed and run with uv.

Running the tests

Two test scripts ship with the port:

1. Numerical regression vs the reference

uv run python test_compare.py

Expected output:

=== PROSPECT-D ===
  prospect: max|d_refl|=8.91e-13  max|d_tran|=9.48e-13
  prospect: max|d_refl|=6.75e-13  max|d_tran|=3.12e-13
  ...

=== Full PROSAIL ===
  prosail: max|d|=2.31e-13  median|d|=2.78e-17  (lai=4.69, mla=61.4)
  prosail: max|d|=3.16e-13  median|d|=4.16e-17  (lai=5.05, mla=55.3)
  ...

Look for max absolute differences below ~10⁻¹². The medians sit around 10⁻¹⁷ (machine precision) - the spectral peaks where kall is small are where the custom E1 approximation gives up its last few digits, and those are exactly where the ~10⁻¹³ numbers come from.

2. JIT, vmap, gradients, edge cases

uv run python test_features.py

Expected output (on CPU):

=== JIT ===
  jit forward (mean over 100): 0.72 ms
  reference numpy/numba (mean over 10): 1.72 ms

=== VMAP (batch of 1000) ===
  vmap forward (B=1000, mean over 10): ~1800 ms
  per-sample: ~1800 us

=== GRAD ===
  grad at solution (should be ~0): max|g|=4.33e-16
  off-target grad finite=True  nonzero=True
  d/d_lai = 3.7340e-05
  d/d_cab = -8.3169e-07
  d/d_tts = -1.8036e-05

=== Edge cases ===
  LAI=0:    max|d|=0.00e+00
  hspot=0:  max|d|=2.57e-13
  Verhoef:  max|d|=9.35e-11
  nadir:    max|d|=2.82e-13

Sanity checks: every gradient should be finite (no NaN), the gradient at the target should be at machine zero, and all edge cases should match the reference to ~10⁻¹⁰ or better.

On GPU, expect the JIT forward to be roughly the same wall-clock time (transfer-dominated) but vmap to scale much further: B=65k crop trait samples in well under a second is realistic.

Quick start

from pathlib import Path
import jax
import jax.numpy as jnp
import prosail as ref               # only for its data files
import prosail_jax as pj

jax.config.update("jax_enable_x64", True)  # keep float64 for accuracy
coeffs, soil = pj.load_coeffs(Path(ref.__file__).parent)

# Single forward call
spectrum = pj.run_prosail(
    n=1.5, cab=40.0, car=10.0, cbrown=0.1, cw=0.015, cm=0.009,
    lai=3.0, lidfa=60.0, hspot=0.1,
    tts=30.0, tto=10.0, psi=90.0,
    coeffs=coeffs, soil=soil,
    typelidf=2,        # 1=Verhoef bimodal, 2=Campbell ellipsoidal
    factor="SDR",      # SDR | BHR | DHR | HDR
)
# spectrum: (2101,) array, 400-2500 nm at 1 nm

# Batched forward via vmap
def forward(n, cab, lai, tts):
    return pj.run_prosail(
        n, cab, 10.0, 0.1, 0.015, 0.009, lai, 60.0, 0.1, tts, 10.0, 90.0,
        coeffs=coeffs, soil=soil, typelidf=2, factor="SDR",
    )

batched = jax.jit(jax.vmap(forward))
out = batched(jnp.array([1.5, 1.7]),
              jnp.array([40., 50.]),
              jnp.array([3.0, 4.5]),
              jnp.array([30., 40.]))   # shape (2, 2101)

# Gradient - for example, sensitivity of NIR plateau (B8 ~840 nm) to LAI
def nir_response(lai):
    spec = forward(1.5, 40.0, lai, 30.0)
    return spec[840 - 400]   # 440th index = 840 nm

dnir_dlai = jax.grad(nir_response)(3.0)

For the typical Sentinel-2 surrogate-training use case, build the band-response matrix once and contract:

# (2101, n_bands) Sentinel-2 SRF matrix, normalised per column
srf = ...  # load from official ESA SRFs

def s2_forward(params):
    spec = pj.run_prosail(*params, coeffs=coeffs, soil=soil,
                          typelidf=2, factor="SDR")
    return spec @ srf

batched_s2 = jax.jit(jax.vmap(s2_forward))

Applications

Some things this port is well-suited to that the reference NumPy/Fortran implementation isn't:

Differentiable trait inversion

Given an observed canopy reflectance y_obs (Sentinel-2, MODIS, hyperspectral, ...), recover the underlying biophysical parameters by gradient descent on ||PROSAIL(θ) - y_obs||². With jax.grad and optax, this is a few-line script. Adam typically converges in 200–500 iterations from a sensible starting point. For comparison, traditional PROSAIL inversion uses lookup-table search or genetic algorithms - much more expensive and less accurate.

Embedded differentiable decoder

PROSAIL becomes a fixed, physics-grounded decoder inside a larger neural model. A common pattern: an encoder maps a satellite time-series to latent trait trajectories, PROSAIL decodes those traits to reflectance, and the loss is reconstruction against the observed series. The whole pipeline is end-to-end differentiable. This is essentially the architecture the trait-MTL project is heading toward, and was the main motivation for this port.

Bayesian parameter inference

Pair with NumPyro or BlackJAX for full HMC / NUTS sampling of biophysical posteriors. PROSAIL is non-linear and non-Gaussian in its outputs, so MCMC posterior shapes are non-trivial - having a JIT-compilable, differentiable likelihood opens the door to honest uncertainty quantification per pixel / parcel rather than the point estimates that LUT-based inversion produces.

Surrogate generation, but cheaper

The classical neural-network PROSAIL surrogate (train an MLP on millions of PROSAIL evaluations to amortise inference) becomes cheaper to set up: vmap generates the entire training set in one batched pass on a GPU, no Python loops, no multiprocessing.Pool shenanigans. And once you have the differentiable PROSAIL itself in JAX, the "do I even need a surrogate?" question gets a different answer in many cases.

Sensitivity analysis and uncertainty propagation

jax.jacrev gives you the full (n_bands × n_params) Jacobian for any pixel, in one call. Combine with parameter covariance and you get observational uncertainty by linear propagation. For Sobol or Saltelli global sensitivity, vmap the forward over a quasi-Monte-Carlo sample.

Hyperspectral and multi-sensor fusion

The forward returns the full 400–2500 nm spectrum at 1 nm. Multiplying by different sensor SRFs (Sentinel-2, Landsat 8/9 OLI, MODIS, EnMAP, PRISMA, hyperspectral airborne) gives sensor-specific predictions from the same underlying state. Useful for cross-sensor calibration and for building multi-sensor inversion problems where all observations share the same trait state but go through different SRFs.

References

  • PROSPECT-D: Féret, J.-B., et al. (2017). PROSPECT-D: Towards modeling leaf optical properties through a complete lifecycle. Remote Sensing of Environment, 193, 204–215.
  • 4SAIL: Verhoef, W., Jia, L., Xiao, Q., & Su, Z. (2007). Unified Optical-Thermal Four-Stream Radiative Transfer Theory for Homogeneous Vegetation Canopies. IEEE TGRS, 45(6), 1808–1822.
  • Reference Python implementation: https://github.com/jgomezdans/prosail
  • JAX: https://github.com/jax-ml/jax

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

prosail_jax-0.1.0.tar.gz (17.8 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

prosail_jax-0.1.0-py3-none-any.whl (15.9 kB view details)

Uploaded Python 3

File details

Details for the file prosail_jax-0.1.0.tar.gz.

File metadata

  • Download URL: prosail_jax-0.1.0.tar.gz
  • Upload date:
  • Size: 17.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.27 {"installer":{"name":"uv","version":"0.9.27","subcommand":["publish"]},"python":null,"implementation":{"name":null,"version":null},"distro":{"name":"Ubuntu","version":"24.04","id":"noble","libc":null},"system":{"name":null,"release":null},"cpu":null,"openssl_version":null,"setuptools_version":null,"rustc_version":null,"ci":null}

File hashes

Hashes for prosail_jax-0.1.0.tar.gz
Algorithm Hash digest
SHA256 94d9b332fe5d52e615a97491d3329be2c001f167e0f4ee685e858fdf9c8d37a3
MD5 ded2bdc12c780aae157a1d378e689f47
BLAKE2b-256 25e9f25271e0dc730f80e4d74c6fe5e87a343e09477530154fa323294e3d5406

See more details on using hashes here.

File details

Details for the file prosail_jax-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: prosail_jax-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 15.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.27 {"installer":{"name":"uv","version":"0.9.27","subcommand":["publish"]},"python":null,"implementation":{"name":null,"version":null},"distro":{"name":"Ubuntu","version":"24.04","id":"noble","libc":null},"system":{"name":null,"release":null},"cpu":null,"openssl_version":null,"setuptools_version":null,"rustc_version":null,"ci":null}

File hashes

Hashes for prosail_jax-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 760eea539828be14260c9c31eb0292222f8aa731f1302d3130a2ccd1e1818fdd
MD5 a20a7a796404e84aa240ccb98a8fc175
BLAKE2b-256 5730899994bbf024a169204c190d6696f4aec3ae52c2c0463d3e5904f49f6055

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page