prosail-jax
A JAX port of PROSPECT-D + 4SAIL (PROSAIL), the canonical leaf+canopy radiative transfer model used in vegetation remote sensing.
The port is a single-file, self-contained module: differentiable, JIT-able, and
vmap-able. Across the full 400–2500 nm spectrum it agrees with the reference
jgomezdans/prosail NumPy implementation to within ~10⁻¹² absolute reflectance.
The two are numerically equivalent, not bit-identical.
This port was written by Claude (Anthropic, Opus 4.7) in a single working session, using the
jgomezdans/prosail source as ground truth. Numerical equivalence, gradient correctness, and the various edge cases were tested as the port was developed.
Why?
The reference PROSAIL implementation is a NumPy + numba.jit translation of the
original Fortran. It's fast, but it's not differentiable, not vectorisable
across batches in a friction-free way, and not GPU-portable. This port gets you:
- `jax.grad` through every biophysical parameter (N, Cab, Car, Cbrown, Cw, Cm, Ant, LAI, LIDFa, hspot) and every geometry angle (sun zenith, view zenith, relative azimuth), verified against finite differences to ~10⁻⁹ relative error.
- `jax.vmap` for batched forward calls: train neural surrogates without Python loops, or run Monte-Carlo parameter sweeps in one JIT'd pass.
- `jax.jit` with XLA: the same module runs on CPU, GPU, or TPU. On CPU the port is ~2.4× faster per call than the reference numba implementation. On GPU the gap should widen considerably, because the forward is dense arithmetic with no special-function calls or scatter/gather.
- Composable: drop it straight into a Flax/Equinox model as a differentiable decoder, plug it into `jax.scipy.optimize` or `optax` for inversion, or hand it to NumPyro/BlackJAX for HMC-based parameter inference.
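A finite-difference check of the kind mentioned in the first bullet can be sketched as follows. This is a minimal illustration, not the port's test code: `toy_forward` is a hypothetical smooth scalar function standing in for one band of the PROSAIL output, and the step size `h` and tolerance are illustrative choices.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # finite differences need float64


def toy_forward(theta):
    # Hypothetical stand-in for a scalar PROSAIL output (e.g. one band)
    lai, cab, tts = theta
    return (1.0 - jnp.exp(-0.5 * lai)) * jnp.exp(-0.01 * cab) * jnp.cos(jnp.deg2rad(tts))


def fd_grad(f, theta, h=1e-6):
    # Central differences, one parameter at a time: (f(x + h e_i) - f(x - h e_i)) / 2h
    eye = jnp.eye(theta.shape[0], dtype=theta.dtype)
    return jnp.array([(f(theta + h * e) - f(theta - h * e)) / (2.0 * h) for e in eye])


def max_rel_error(f, theta, h=1e-6):
    # Largest relative disagreement between autodiff and finite differences
    g_ad = jax.grad(f)(theta)
    g_fd = fd_grad(f, theta, h)
    return jnp.max(jnp.abs(g_ad - g_fd) / jnp.maximum(jnp.abs(g_ad), 1e-12))


theta0 = jnp.array([3.0, 40.0, 30.0])  # lai, cab, tts (degrees)
print(max_rel_error(toy_forward, theta0))
```

Central differences in float64 are limited by round-off at roughly `eps / h`, which is why agreement of order 10⁻⁹ relative, rather than machine precision, is the realistic target.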
What's in the port
- PROSPECT-D: leaf-level model. Inputs: N (mesophyll structure), Cab (chlorophyll), Car (carotenoids), Cbrown (brown pigment), Cw (water), Cm (dry matter), Ant (anthocyanins). Outputs: leaf reflectance and transmittance, 400–2500 nm at 1 nm.
- 4SAIL: canopy-level model with the Verhoef (2007) four-stream formulation. Inputs: leaf rho/tau (from PROSPECT), LAI, leaf inclination distribution (Verhoef bimodal or Campbell ellipsoidal), hotspot, sun zenith, view zenith, relative azimuth, soil reflectance. Outputs: SDR / BHR / DHR / HDR canopy reflectance factors.
- Linear soil mixing model: soil reflectance = `rsoil * (psoil*dry + (1-psoil)*wet)`, where `rsoil` scales soil brightness and `psoil` is the dry-soil fraction.
- A fast custom E1 (E1(x) = -expi(-x)) implementation, since `jax.scipy.special.expi` is roughly 10,000× slower than `jnp.sin` on CPU and would otherwise dominate total runtime. The replacement combines a 30-term Abramowitz-Stegun series for k ≤ 1.5 with a fixed-iteration Lentz continued fraction for k > 1.5, giving ~5×10⁻¹¹ relative accuracy on E1.
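The two-branch E1 scheme in the last bullet can be sketched as below. This is an illustrative reimplementation, not the port's code. The split at 1.5 and the 30-term series follow the description. The 100-iteration continued-fraction loop, the `e1` name, and the clamping scheme are assumptions. The series is Abramowitz & Stegun 5.1.11. The continued fraction uses the modified Lentz recurrence in the form Numerical Recipes gives for `expint`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import exp1 as exp1_reference  # used only to check the sketch

jax.config.update("jax_enable_x64", True)

EULER_GAMMA = 0.5772156649015329
SPLIT = 1.5  # power series at or below, continued fraction above


def _e1_series(x, n_terms=30):
    # A&S 5.1.11: E1(x) = -gamma - ln(x) - sum_{k>=1} (-x)^k / (k * k!)
    def body(k, carry):
        term, total = carry
        term = term * (-x) / k  # now (-x)^k / k!
        return term, total + term / k
    _, total = jax.lax.fori_loop(1, n_terms + 1, body, (jnp.ones_like(x), jnp.zeros_like(x)))
    return -EULER_GAMMA - jnp.log(x) - total


def _e1_contfrac(x, n_iter=100):
    # Modified Lentz evaluation of
    # E1(x) = exp(-x) / (x+1 - 1/(x+3 - 4/(x+5 - 9/(x+7 - ...))))
    b = x + 1.0
    c = jnp.full_like(x, 1e300)  # "1/tiny" start value of Lentz's method
    d = 1.0 / b
    h = d

    def body(i, carry):
        b, c, d, h = carry
        a = -1.0 * i * i  # partial numerators -1, -4, -9, ...
        b = b + 2.0
        d = 1.0 / (a * d + b)
        c = b + a / c
        return b, c, d, h * c * d

    _, _, _, h = jax.lax.fori_loop(1, n_iter + 1, body, (b, c, d, h))
    return h * jnp.exp(-x)


@jax.jit
def e1(x):
    """E1(x) = -expi(-x) for x > 0."""
    x = jnp.asarray(x, dtype=jnp.float64)
    small = x <= SPLIT
    # Clamp each branch into its own range so the unselected branch never
    # overflows; inf/NaN there would still leak into gradients via jnp.where.
    xs = jnp.where(small, x, SPLIT)
    xl = jnp.where(small, SPLIT, x)
    return jnp.where(small, _e1_series(xs), _e1_contfrac(xl))


x_check = jnp.geomspace(1e-4, 50.0, 400)
ref_vals = exp1_reference(x_check)
max_rel_err = jnp.max(jnp.abs(e1(x_check) - ref_vals) / ref_vals)
```

Both loops have fixed trip counts, so `jax.grad(e1)` works through them (reverse-mode through `fori_loop` needs static bounds). The gradient should match the analytic derivative dE1/dx = -exp(-x)/x.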
What's not (yet) included:
- PROSPECT-5 and PROSPECT-PRO: only PROSPECT-D is wired up. Adding 5 or
PRO is a small change: load the appropriate spectral coefficient file and
call `prospect_d` with the unused absorption coefficients zeroed (the same pattern the reference uses).
Setup
The repository is configured to be installed and run with uv.
Running the tests
Two test scripts ship with the port:
1. Numerical regression vs the reference
```
uv run python test_compare.py
```
Expected output:
```
=== PROSPECT-D ===
prospect: max|d_refl|=8.91e-13 max|d_tran|=9.48e-13
prospect: max|d_refl|=6.75e-13 max|d_tran|=3.12e-13
...
=== Full PROSAIL ===
prosail: max|d|=2.31e-13 median|d|=2.78e-17 (lai=4.69, mla=61.4)
prosail: max|d|=3.16e-13 median|d|=4.16e-17 (lai=5.05, mla=55.3)
...
```
Look for maximum absolute differences below ~10⁻¹². The medians sit around 10⁻¹⁷,
i.e. at float64 round-off level. The ~10⁻¹³ maxima come from the spectral peaks
where kall is small: that is where the custom E1 approximation gives up its last
few digits.
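Each line of that output summarises the largest and the median absolute difference between the two implementations' spectra. A minimal sketch of such a summary, assuming both spectra are on the same 2101-band grid; `regression_report` is a hypothetical helper, not the code in `test_compare.py`, and the spectra are synthetic.

```python
import numpy as np


def regression_report(ref_spec, new_spec, atol=1e-12):
    """Return (max_abs, median_abs, passed) for two spectra on the same grid."""
    d = np.abs(np.asarray(new_spec, dtype=np.float64) - np.asarray(ref_spec, dtype=np.float64))
    max_abs, median_abs = float(d.max()), float(np.median(d))
    return max_abs, median_abs, max_abs < atol


wl = np.arange(400, 2501)                  # 2101 bands at 1 nm
ref_spec = 0.3 + 0.1 * np.sin(wl / 200.0)  # synthetic stand-in spectrum
new_spec = ref_spec.copy()
new_spec[440] += 2e-13                     # perturb a single band (840 nm)
print("max|d|=%.2e median|d|=%.2e ok=%s" % regression_report(ref_spec, new_spec))
```

A near-zero median with a small maximum is the pattern to expect when disagreement is confined to a few bands, as with the E1-sensitive peaks described above.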
2. JIT, vmap, gradients, edge cases
```
uv run python test_features.py
```
Expected output (on CPU):
```
=== JIT ===
jit forward (mean over 100): 0.72 ms
reference numpy/numba (mean over 10): 1.72 ms
=== VMAP (batch of 1000) ===
vmap forward (B=1000, mean over 10): ~1800 ms
per-sample: ~1800 us
=== GRAD ===
grad at solution (should be ~0): max|g|=4.33e-16
off-target grad finite=True nonzero=True
d/d_lai = 3.7340e-05
d/d_cab = -8.3169e-07
d/d_tts = -1.8036e-05
=== Edge cases ===
LAI=0: max|d|=0.00e+00
hspot=0: max|d|=2.57e-13
Verhoef: max|d|=9.35e-11
nadir: max|d|=2.82e-13
```
Sanity checks: every gradient should be finite (no NaN), the gradient at the target should be at machine zero, and all edge cases should match the reference to ~10⁻¹⁰ or better.
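JAX compiles on the first call and dispatches work asynchronously, so per-call timings like the ones above only mean something after a warm-up call and with `block_until_ready()` on each result. A minimal sketch of that pattern, assuming a hypothetical `toy_forward` in place of the PROSAIL forward; it is not the code in `test_features.py`.

```python
import time

import jax
import jax.numpy as jnp


def time_jitted(fn, *args, n_repeat=100):
    """Mean wall-clock seconds per call of a jitted fn, excluding compilation."""
    fn(*args).block_until_ready()      # warm-up: the first call compiles
    start = time.perf_counter()
    for _ in range(n_repeat):
        fn(*args).block_until_ready()  # wait for the async result each time
    return (time.perf_counter() - start) / n_repeat


@jax.jit
def toy_forward(lai, tts):
    # Hypothetical stand-in: a 2101-band "spectrum" from two scalars
    wl = jnp.linspace(400.0, 2500.0, 2101)
    return (1.0 - jnp.exp(-0.5 * lai)) * jnp.cos(jnp.deg2rad(tts)) * jnp.exp(-wl / 2500.0)


mean_s = time_jitted(toy_forward, 3.0, 30.0)
print(f"jit forward (mean over 100): {mean_s * 1e3:.3f} ms")
```

Without the warm-up, the first measurement would include XLA compilation. Without `block_until_ready()`, the loop would time only the dispatch of work, not its execution.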
On GPU, a single JIT forward call should take roughly the same wall-clock time (it is transfer-dominated), but vmap should scale much further: a batch of ~65k crop-trait samples in well under a second is a realistic expectation.
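If one `vmap` over a very large batch exhausts device memory, a common workaround is to split the batch into fixed-size chunks, `vmap` within each chunk, and iterate over the chunks with `jax.lax.map`. A minimal sketch, assuming the batch divides evenly into chunks and using a hypothetical `toy_forward` in place of PROSAIL. `chunked_vmap` is an illustrative helper and the chunk size is a value to tune per device.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def toy_forward(theta):
    # Hypothetical stand-in for one PROSAIL call: params (3,) -> spectrum (2101,)
    lai, cab, tts = theta
    wl = jnp.linspace(400.0, 2500.0, 2101)
    return (1.0 - jnp.exp(-0.5 * lai)) * jnp.exp(-0.01 * cab * (wl < 700)) * jnp.cos(jnp.deg2rad(tts))


def chunked_vmap(f, chunk_size):
    batched = jax.vmap(f)

    @jax.jit
    def run(thetas):
        n = thetas.shape[0]  # static under jit
        assert n % chunk_size == 0, "batch must divide evenly into chunks"
        chunks = thetas.reshape(n // chunk_size, chunk_size, *thetas.shape[1:])
        out = jax.lax.map(batched, chunks)  # chunks run one after another
        return out.reshape(n, *out.shape[2:])

    return run


thetas = jnp.array([[1.0 + i, 20.0 + 5.0 * i, 10.0 + i] for i in range(8)])
run = chunked_vmap(toy_forward, chunk_size=4)
spectra = run(thetas)  # (8, 2101)
```

`jax.lax.map` processes chunks sequentially, so peak memory scales with `chunk_size` rather than the full batch, at the cost of some parallelism.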
Quick start
```python
from pathlib import Path

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # keep float64 for accuracy; set before creating arrays

import prosail as ref  # only for its data files
import prosail_jax as pj

coeffs, soil = pj.load_coeffs(Path(ref.__file__).parent)

# Single forward call
spectrum = pj.run_prosail(
    n=1.5, cab=40.0, car=10.0, cbrown=0.1, cw=0.015, cm=0.009,
    lai=3.0, lidfa=60.0, hspot=0.1,
    tts=30.0, tto=10.0, psi=90.0,
    coeffs=coeffs, soil=soil,
    typelidf=2,    # 1=Verhoef bimodal, 2=Campbell ellipsoidal
    factor="SDR",  # SDR | BHR | DHR | HDR
)
# spectrum: (2101,) array, 400-2500 nm at 1 nm

# Batched forward via vmap
def forward(n, cab, lai, tts):
    # Positional order: n, cab, car, cbrown, cw, cm, lai, lidfa, hspot, tts, tto, psi
    return pj.run_prosail(
        n, cab, 10.0, 0.1, 0.015, 0.009, lai, 60.0, 0.1, tts, 10.0, 90.0,
        coeffs=coeffs, soil=soil, typelidf=2, factor="SDR",
    )

batched = jax.jit(jax.vmap(forward))
out = batched(jnp.array([1.5, 1.7]),
              jnp.array([40., 50.]),
              jnp.array([3.0, 4.5]),
              jnp.array([30., 40.]))  # shape (2, 2101)

# Gradient: for example, sensitivity of the NIR plateau (B8, ~840 nm) to LAI
def nir_response(lai):
    spec = forward(1.5, 40.0, lai, 30.0)
    return spec[840 - 400]  # index 440 = 840 nm

dnir_dlai = jax.grad(nir_response)(3.0)
```
For the typical Sentinel-2 surrogate-training use case, build the band-response matrix once and contract:
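```python
# (2101, n_bands) Sentinel-2 SRF matrix, normalised per column
srf = ...  # load from official ESA SRFs

def s2_forward(params):
    # params: run_prosail's positional arguments packed into one vector
    spec = pj.run_prosail(*params, coeffs=coeffs, soil=soil,
                          typelidf=2, factor="SDR")
    return spec @ srf  # (n_bands,) band reflectances

batched_s2 = jax.jit(jax.vmap(s2_forward))  # (B, n_params) -> (B, n_bands)
```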
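When the official ESA SRF tables are not at hand for quick experiments, a rough stand-in is a Gaussian response per band, with each column normalised to sum to one so that `spec @ srf` becomes a band-weighted average. A minimal sketch under that assumption. The Gaussian shape, the `gaussian_srf` helper, and the centre/FWHM values are illustrative approximations, not the ESA SRFs, and should be replaced by the official tables for real work.

```python
import jax.numpy as jnp


def gaussian_srf(wl, centres, fwhm):
    """(n_wl, n_bands) Gaussian response matrix, each column normalised to sum to 1."""
    sigma = jnp.asarray(fwhm) / (2.0 * jnp.sqrt(2.0 * jnp.log(2.0)))  # FWHM -> std dev
    r = jnp.exp(-0.5 * ((wl[:, None] - jnp.asarray(centres)[None, :]) / sigma[None, :]) ** 2)
    return r / r.sum(axis=0, keepdims=True)


wl = jnp.arange(400.0, 2501.0)  # 2101 wavelengths at 1 nm
# Illustrative blue/green/red/NIR bands (not official Sentinel-2 values)
srf = gaussian_srf(wl, centres=[490.0, 560.0, 665.0, 842.0], fwhm=[65.0, 35.0, 30.0, 115.0])

spec = jnp.full(wl.shape, 0.25)  # flat synthetic spectrum
bands = spec @ srf               # (4,) band reflectances
```

Because each column sums to one, a spectrally flat input returns the same value in every band, which is a quick sanity check on any SRF matrix built this way.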
Applications
Some things this port is well-suited to that the reference NumPy/Fortran implementation isn't:
Differentiable trait inversion
Given an observed canopy reflectance y_obs (Sentinel-2, MODIS, hyperspectral,
...), recover the underlying biophysical parameters by gradient descent on
||PROSAIL(θ) - y_obs||². With jax.grad and optax, this is a few-line
script. Adam typically converges in 200–500 iterations from a sensible
starting point. For comparison, traditional PROSAIL inversion uses lookup-table
search or genetic algorithms - much more expensive and less accurate.
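A minimal sketch of the inversion loop, assuming a hypothetical two-parameter `toy_forward` in place of PROSAIL and a synthetic, noise-free observation. To stay dependency-free it uses `jax.scipy.optimize.minimize` with BFGS rather than optax/Adam. With optax one would instead compute `jax.grad(loss)` and apply the optimiser's updates in a loop.

```python
import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize

jax.config.update("jax_enable_x64", True)

wl = jnp.linspace(0.0, 2.0, 50)  # hypothetical band coordinate


def toy_forward(theta):
    # Hypothetical stand-in for PROSAIL(theta): amplitude * exp(-rate * wl)
    amp, rate = theta
    return amp * jnp.exp(-rate * wl)


theta_true = jnp.array([0.5, 1.3])
y_obs = toy_forward(theta_true)  # synthetic, noise-free observation


def loss(theta):
    # ||PROSAIL(theta) - y_obs||^2
    r = toy_forward(theta) - y_obs
    return jnp.sum(r ** 2)


result = minimize(loss, jnp.array([0.3, 0.8]), method="BFGS")
print(result.x, result.success)
```

For real PROSAIL parameters, which have physical bounds, some form of bounding or reparameterisation would usually be needed on top of this unconstrained loop.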
Embedded differentiable decoder
PROSAIL becomes a fixed, physics-grounded decoder inside a larger neural model. A common pattern: an encoder maps a satellite time-series to latent trait trajectories, PROSAIL decodes those traits to reflectance, and the loss is reconstruction against the observed series. The whole pipeline is end-to-end differentiable. This is essentially the architecture the trait-MTL project is heading toward, and was the main motivation for this port.
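The encoder/decoder pattern can be sketched in plain JAX. Everything here is hypothetical. A linear encoder maps an observed spectrum to two positive "traits" via softplus, and a fixed `toy_decoder` stands in for PROSAIL. Only the encoder weights are trainable, and their gradients flow back through the decoder.

```python
import jax
import jax.numpy as jnp

N_BANDS, N_TRAITS = 16, 2
wl = jnp.linspace(0.0, 2.0, N_BANDS)


def toy_decoder(traits):
    # Hypothetical fixed physics decoder (stand-in for PROSAIL): traits -> spectrum
    amp, rate = traits
    return amp * jnp.exp(-rate * wl)


def encoder(params, y):
    # Linear map followed by softplus keeps the predicted traits positive
    W, b = params
    return jax.nn.softplus(W @ y + b)


def reconstruction_loss(params, y_batch):
    traits = jax.vmap(lambda y: encoder(params, y))(y_batch)  # (B, N_TRAITS)
    y_hat = jax.vmap(toy_decoder)(traits)                       # (B, N_BANDS)
    return jnp.mean((y_hat - y_batch) ** 2)


key_w, key_y = jax.random.split(jax.random.PRNGKey(0))
params = (0.01 * jax.random.normal(key_w, (N_TRAITS, N_BANDS)), jnp.zeros(N_TRAITS))
y_batch = jax.random.uniform(key_y, (8, N_BANDS))               # synthetic "observations"
grads = jax.grad(reconstruction_loss)(params, y_batch)          # same structure as params
```

In a real model the encoder would be a Flax or Equinox network over a time series, but the gradient path through the fixed decoder is the same.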
Bayesian parameter inference
Pair with NumPyro or BlackJAX for full HMC / NUTS sampling of biophysical posteriors. PROSAIL is non-linear and non-Gaussian in its outputs, so MCMC posterior shapes are non-trivial - having a JIT-compilable, differentiable likelihood opens the door to honest uncertainty quantification per pixel / parcel rather than the point estimates that LUT-based inversion produces.
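HMC samplers work from a differentiable log-density. A minimal sketch of one, assuming a hypothetical `toy_forward`, Gaussian observation noise with known σ, and uniform priors on a box. The priors are expressed through a scaled-sigmoid reparameterisation so the sampler moves in unconstrained space. None of these modelling choices come from the port.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

wl = jnp.linspace(0.0, 2.0, 30)
LO = jnp.array([0.0, 0.0])  # hypothetical prior box: amp in [0, 1], rate in [0, 5]
HI = jnp.array([1.0, 5.0])
SIGMA = 0.01                # assumed known observation noise


def toy_forward(theta):
    # Hypothetical stand-in for PROSAIL(theta)
    amp, rate = theta
    return amp * jnp.exp(-rate * wl)


def to_constrained(z):
    # Unconstrained z -> theta inside the prior box
    return LO + (HI - LO) * jax.nn.sigmoid(z)


def make_logdensity(y_obs):
    def logdensity(z):
        theta = to_constrained(z)
        loglik = -0.5 * jnp.sum((toy_forward(theta) - y_obs) ** 2) / SIGMA**2
        # log|d theta / d z| of the scaled sigmoid; with it, a uniform prior on the box
        log_jac = jnp.sum(jnp.log(HI - LO) + jax.nn.log_sigmoid(z) + jax.nn.log_sigmoid(-z))
        return loglik + log_jac
    return logdensity


y_obs = toy_forward(jnp.array([0.5, 1.3]))
logdensity = make_logdensity(y_obs)
value, grad_z = jax.value_and_grad(logdensity)(jnp.zeros(2))
```

BlackJAX samplers take a function of this shape as their log-density. NumPyro's `NUTS` can also accept a `potential_fn`, which would be the negative of it.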
Surrogate generation, but cheaper
The classical neural-network PROSAIL surrogate (train an MLP on millions of
PROSAIL evaluations to amortise inference) becomes cheaper to set up: vmap
generates the entire training set in one batched pass on a GPU, no Python
loops, no multiprocessing.Pool shenanigans. And once you have the
differentiable PROSAIL itself in JAX, the "do I even need a surrogate?"
question gets a different answer in many cases.
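Generating a surrogate training set then reduces to sampling parameters and one `vmap`ped forward call. A minimal sketch, assuming a hypothetical `toy_forward` and illustrative uniform ranges. For PROSAIL one would substitute `pj.run_prosail` and the trait ranges of interest.

```python
import jax
import jax.numpy as jnp

wl = jnp.linspace(400.0, 2500.0, 2101)
LOW = jnp.array([0.5, 10.0, 0.0])    # hypothetical ranges: lai, cab, tts (degrees)
HIGH = jnp.array([6.0, 80.0, 60.0])


def toy_forward(theta):
    # Hypothetical stand-in for one PROSAIL call: params (3,) -> spectrum (2101,)
    lai, cab, tts = theta
    return (1.0 - jnp.exp(-0.5 * lai)) * jnp.exp(-0.01 * cab * (wl < 700)) * jnp.cos(jnp.deg2rad(tts))


batched_forward = jax.jit(jax.vmap(toy_forward))


def make_training_set(key, n):
    # Uniform samples inside the parameter box, then one batched forward pass
    u = jax.random.uniform(key, (n, LOW.shape[0]))
    thetas = LOW + (HIGH - LOW) * u
    return thetas, batched_forward(thetas)


X, Y = make_training_set(jax.random.PRNGKey(0), 1024)  # (1024, 3), (1024, 2101)
```

Uniform sampling is the simplest choice. A quasi-random or prior-matched design would slot into the same place.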
Sensitivity analysis and uncertainty propagation
jax.jacrev gives you the full (n_bands × n_params) Jacobian for any pixel,
in one call. Combine with parameter covariance and you get observational
uncertainty by linear propagation. For Sobol or Saltelli global sensitivity,
vmap the forward over a quasi-Monte-Carlo sample.
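Linear propagation uses the first-order approximation Σ_y ≈ J Σ_θ Jᵀ, with J the (n_bands × n_params) Jacobian at the current estimate. A minimal sketch, assuming a hypothetical `toy_forward` that returns five band values and an illustrative diagonal parameter covariance.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

wl = jnp.linspace(0.0, 2.0, 5)  # five hypothetical bands


def toy_forward(theta):
    # Hypothetical stand-in for PROSAIL followed by band integration
    amp, rate = theta
    return amp * jnp.exp(-rate * wl)


theta_hat = jnp.array([0.5, 1.3])
cov_theta = jnp.diag(jnp.array([0.02, 0.1]) ** 2)  # assumed parameter covariance

J = jax.jacrev(toy_forward)(theta_hat)  # (n_bands, n_params) = (5, 2)
cov_y = J @ cov_theta @ J.T             # (5, 5) band covariance
sigma_y = jnp.sqrt(jnp.diag(cov_y))     # per-band standard deviation
```

`jax.jacrev` is efficient when outputs outnumber inputs less than… in practice, `jax.jacfwd` is usually preferable when there are many more outputs (bands) than parameters, and gives the same matrix.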
Hyperspectral and multi-sensor fusion
The forward returns the full 400–2500 nm spectrum at 1 nm. Multiplying by different sensor SRFs (Sentinel-2, Landsat 8/9 OLI, MODIS, EnMAP, PRISMA, hyperspectral airborne) gives sensor-specific predictions from the same underlying state. Useful for cross-sensor calibration and for building multi-sensor inversion problems where all observations share the same trait state but go through different SRFs.
References
- PROSPECT-D: Féret, J.-B., et al. (2017). PROSPECT-D: Towards modeling leaf optical properties through a complete lifecycle. Remote Sensing of Environment, 193, 204–215.
- 4SAIL: Verhoef, W., Jia, L., Xiao, Q., & Su, Z. (2007). Unified Optical-Thermal Four-Stream Radiative Transfer Theory for Homogeneous Vegetation Canopies. IEEE TGRS, 45(6), 1808–1822.
- Reference Python implementation: https://github.com/jgomezdans/prosail
- JAX: https://github.com/jax-ml/jax