Fermat path tracing with JAX

fpt-jax is a standalone library for differentiable path-tracing using the Fermat principle, implemented with JAX.

Installation

You can install this package from PyPI:

pip install fpt-jax

Usage

This library implements a single function, trace_rays, which traces rays undergoing specular reflections and diffractions on planar objects defined by origins and basis vectors:

> from fpt_jax import trace_rays; help(trace_rays)

trace_rays
   (tx: jax.Array, rx: jax.Array,
    object_origins: jax.Array, object_vectors: jax.Array, *,
    num_iters: int, unroll: int | bool = 1,
    num_iters_linesearch: int = 1, unroll_linesearch: int | bool = 1) -> jax.Array:

Compute the points of interaction of rays with objects using Fermat's principle.

Each ray is obtained by minimizing the total travel distance from transmitter to receiver, using a quasi-Newton optimization algorithm (BFGS). At each iteration, a line search is performed to find the optimal step size along the descent direction.

This function accepts batched inputs, where the leading dimensions must be broadcast-compatible.

Args:
    tx: Transmitter positions of shape (..., 3).
    rx: Receiver positions of shape (..., 3).
    object_origins: Origins of the objects of shape (..., num_interactions, 3).
    object_vectors: Vectors defining the objects of shape (..., num_interactions, num_dims, 3).
    num_iters: Number of iterations for the optimization algorithm.
    unroll: If an integer, the number of optimization iterations to unroll in the JAX scan.
        If True, unroll all iterations. If False, do not unroll.
    num_iters_linesearch: Number of iterations for the line search fixed-point iteration.
    unroll_linesearch: If an integer, the number of fixed-point iterations to unroll in the JAX scan.
        If True, unroll all iterations. If False, do not unroll.
    implicit_diff: Whether to use implicit differentiation for computing the gradient.
        If True, assumes that the solution has converged and applies the implicit function theorem to differentiate the optimization problem with respect to the input parameters: tx, rx, object_origins, and object_vectors.
        If False, the gradient is computed by backpropagating through all iterations of the optimization algorithm.

        Using implicit differentiation is more memory- and computationally efficient, as it does not require storing intermediate values from all iterations, but it may be less accurate if the optimization has not fully converged. Moreover, implicit differentiation is not compatible with forward-mode autodiff in JAX.

Returns:
    The points of interaction of shape (..., num_interactions, 3).
    To include the transmitter and receiver positions, concatenate tx and rx to the result.
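The returned array holds only the interaction points. Below is a minimal sketch of building the full transmitter-to-receiver polyline and measuring its length with plain jax.numpy, for arrays shaped as documented above. `full_path` and `path_length` are hypothetical helpers, not part of fpt-jax. The hard-coded `points` array stands in for a `trace_rays` result, and the broadcasting of `tx` and `rx` follows the broadcast-compatibility rule stated in the docstring.

```python
import jax.numpy as jnp


def full_path(tx, points, rx):
    """Prepend tx and append rx to an array of interaction points.

    tx, rx: shape (..., 3); points: shape (..., num_interactions, 3).
    Leading (batch) dimensions only need to be broadcast-compatible.
    Returns an array of shape (..., num_interactions + 2, 3).
    """
    batch = jnp.broadcast_shapes(tx.shape[:-1], rx.shape[:-1], points.shape[:-2])
    tx = jnp.broadcast_to(tx, (*batch, 3))[..., None, :]
    rx = jnp.broadcast_to(rx, (*batch, 3))[..., None, :]
    points = jnp.broadcast_to(points, (*batch, *points.shape[-2:]))
    return jnp.concatenate([tx, points, rx], axis=-2)


def path_length(path):
    """Total length of polylines of shape (..., num_vertices, 3)."""
    segments = jnp.diff(path, axis=-2)
    return jnp.linalg.norm(segments, axis=-1).sum(axis=-1)


# Stand-in for a trace_rays result: one reflection point on the plane z = 0.
tx = jnp.array([0.0, 0.0, 1.0])
rx = jnp.array([2.0, 0.0, 1.0])
points = jnp.array([[1.0, 0.0, 0.0]])

path = full_path(tx, points, rx)  # shape (3, 3): tx, reflection point, rx
print(path_length(path))  # 2 * sqrt(2), about 2.828
```

Broadcasting explicitly before `jnp.concatenate` is needed because concatenation, unlike arithmetic, does not broadcast the non-concatenated dimensions.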

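To make the underlying idea concrete, here is a small, unbatched sketch of the Fermat principle in JAX. Each interaction point is parametrized as origin + sum_d s_d * vector_d over the object's basis vectors, and the coordinates s_d are chosen to minimize the total distance from transmitter to receiver. This is not the fpt-jax implementation. It swaps in JAX's generic `jax.scipy.optimize.minimize(..., method="BFGS")`, assuming it is available in your JAX version, in place of the library's fixed-iteration BFGS and fixed-point line search. It also assumes that num_dims = 2 describes a reflecting plane and num_dims = 1 a diffracting edge. The function name `fermat_points` is hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.optimize import minimize


def fermat_points(tx, rx, object_origins, object_vectors):
    """Hypothetical, unbatched sketch of Fermat-principle path finding.

    tx, rx: shape (3,); object_origins: (num_interactions, 3);
    object_vectors: (num_interactions, num_dims, 3).
    Each interaction point is origin + sum_d s_d * vector_d; the coordinates
    s_d are chosen to minimize the total tx-to-rx path length.
    """
    n, d, _ = object_vectors.shape

    def points_from(params):
        coords = params.reshape(n, d)
        return object_origins + jnp.einsum("nd,ndk->nk", coords, object_vectors)

    def total_length(params):
        path = jnp.concatenate([tx[None, :], points_from(params), rx[None, :]])
        return jnp.linalg.norm(jnp.diff(path, axis=0), axis=-1).sum()

    # Start at the object origins (all coordinates zero) and run JAX's
    # generic BFGS, which performs its own line search at each iteration.
    result = minimize(total_length, jnp.zeros(n * d), method="BFGS")
    return points_from(result.x)


tx = jnp.array([0.0, 0.0, 1.0])
rx = jnp.array([2.0, 0.0, 1.0])

# Specular reflection on the plane z = 0 (two basis vectors).
plane_points = fermat_points(
    tx, rx,
    jnp.array([[0.0, 0.0, 0.0]]),
    jnp.array([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]]),
)
# Diffraction on the edge x = 1, z = 0 (one basis vector along y).
edge_points = fermat_points(
    tx, rx,
    jnp.array([[1.0, -3.0, 0.0]]),
    jnp.array([[[0.0, 1.0, 0.0]]]),
)
print(plane_points, edge_points)  # both approximately [[1, 0, 0]]
```

In both cases the expected point is (1, 0, 0): the mirror-image reflection point on the plane and, by symmetry, the shortest-path point on the edge. Since the total length is a sum of norms of affine functions of the coordinates, the objective is convex, which is one reason a quasi-Newton method is a natural fit.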

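The implicit_diff option relies on the implicit function theorem. The sketch below applies that recipe to a single reflection using `jax.hessian` and `jax.jacobian`. At a converged minimizer, the gradient of the path length with respect to the coordinates vanishes, so the sensitivity of the solution to `tx` follows from one linear solve instead of backpropagating through every iteration. The helper names are hypothetical, the converged solution is taken from the mirror-image construction rather than computed, and fpt-jax's own implementation may differ.

```python
import jax
import jax.numpy as jnp


def reflection_length(params, tx, rx, origin, vectors):
    """Length of tx -> point -> rx, with point = origin + params @ vectors."""
    point = origin + params @ vectors
    return jnp.linalg.norm(point - tx) + jnp.linalg.norm(rx - point)


def implicit_jacobian(params_star, tx, rx, origin, vectors):
    """d(params*)/d(tx) from the implicit function theorem.

    At a minimizer, grad_params L(params*, tx) = 0. Differentiating this
    identity gives H @ d(params*)/d(tx) = -d(grad_params L)/d(tx), where H
    is the Hessian of L with respect to params.
    """
    args = (tx, rx, origin, vectors)
    hess = jax.hessian(reflection_length, argnums=0)(params_star, *args)
    mixed = jax.jacobian(jax.grad(reflection_length, argnums=0), argnums=1)(
        params_star, *args
    )
    return -jnp.linalg.solve(hess, mixed)


tx = jnp.array([0.0, 0.0, 1.0])
rx = jnp.array([2.0, 0.0, 1.0])
origin = jnp.zeros(3)
vectors = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])  # plane z = 0
params_star = jnp.array([1.0, 0.0])  # mirror-image solution, i.e. point (1, 0, 0)

jac = implicit_jacobian(params_star, tx, rx, origin, vectors)  # shape (2, 3)
print(jac)
```

For this symmetric geometry the result can be checked by hand: moving `tx` by a small δ along x, y, or z shifts the reflection point by about δ/2 along x, y, and x, respectively, which gives the Jacobian [[0.5, 0, 0.5], [0, 0.5, 0]]. If `params_star` were not a true minimizer, the identity the solve relies on would not hold, which is why implicit differentiation can be inaccurate before convergence.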
This algorithm is also available within DiffeRT, our differentiable ray tracing library for radio propagation.

Getting help

For any question about the method or its implementation, please read the related paper first.

If you want to report a bug in this library or the underlying algorithm, please open an issue on this GitHub repository. If you want to request a new feature, please consider opening an issue on DiffeRT's GitHub repository instead.

Citing

If you use this library in your research, please cite our paper:

@misc{eertmans2025fpt,
  title         = {Fast, Differentiable, GPU-Accelerated Ray Tracing for Multiple Diffraction and Reflection Paths},
  author        = {Jérome Eertmans and Sophie Lequeu and Benoît Legat and Laurent Jacques and Claude Oestges},
  year          = 2025,
  url           = {https://arxiv.org/abs/2510.16172},
  eprint        = {2510.16172},
  archiveprefix = {arXiv},
  primaryclass  = {eess.SP}
}

Download files

Source Distribution

fpt_jax-0.1.0.tar.gz (62.4 kB)

Uploaded Source

Built Distribution

fpt_jax-0.1.0-py3-none-any.whl (7.6 kB)

Uploaded Python 3

File details

Details for the file fpt_jax-0.1.0.tar.gz.

File metadata

  • Download URL: fpt_jax-0.1.0.tar.gz
  • Upload date:
  • Size: 62.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for fpt_jax-0.1.0.tar.gz

  Algorithm     Hash digest
  SHA256        fb9a9746f6449c656099be54616c2bb0729c14707c6022de9c160e6e8938f144
  MD5           2faac6fd5fe9af3924deadeeb2dca3bf
  BLAKE2b-256   7bb904291d40d7a0fb6c6d6298e90eb4cf8326207193caa4fd0aa345a1f8df1c

Provenance

The following attestation bundles were made for fpt_jax-0.1.0.tar.gz:

Publisher: publish.yml on jeertmans/fpt-jax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file fpt_jax-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: fpt_jax-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 7.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for fpt_jax-0.1.0-py3-none-any.whl

  Algorithm     Hash digest
  SHA256        44e2bae17e07685cf7a9a84b8176be54e6619cf1ee5bd51bf351c9021c4b885d
  MD5           c782d54e76643cfd030b8de1fd81019c
  BLAKE2b-256   2f750cc154ad12e910dea3af76e8505f2ded96683f1a578923607d88da802ed0

Provenance

The following attestation bundles were made for fpt_jax-0.1.0-py3-none-any.whl:

Publisher: publish.yml on jeertmans/fpt-jax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
