
A JAX-based research framework for writing differentiable numerical simulators with arbitrary discretizations


jaxdf - JAX-based Discretization Framework

License: LGPL v3

Overview | Example | Installation | Documentation | Support


Overview

jaxdf is a JAX-based package that provides a framework for writing differentiable numerical simulators with arbitrary discretizations.

The primary objective of jaxdf is to help construct numerical models of physical systems, such as wave propagation, or the numerical solution of partial differential equations, in a way that is easily tailored to the user's research needs. These models are pure functions that can be seamlessly integrated into arbitrary differentiable programs written in JAX: for example, they can be used as layers in neural networks, or to build a physics-based loss function.
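To make the "physics loss" idea concrete, here is a minimal plain-NumPy sketch (the function names are illustrative, not part of the jaxdf API): it penalizes the residual of the Poisson equation ∇²u = f on a periodic grid, using a simple finite-difference discretization. In jaxdf, such a loss would be built from differentiable operators instead, so JAX can take its gradient.

```python
import numpy as np

def laplacian_fd(u, dx):
    # Second-order central finite differences with periodic wrap-around
    return (np.roll(u, 1, 0) + np.roll(u, -1, 0)
            + np.roll(u, 1, 1) + np.roll(u, -1, 1) - 4 * u) / dx**2

def physics_loss(u, f, dx):
    # Mean squared residual of the PDE: lap(u) - f should vanish for a solution
    residual = laplacian_fd(u, dx) - f
    return np.mean(residual**2)

n = 64
dx = 2 * np.pi / n
x = np.arange(n) * dx
u = np.sin(x)[:, None] * np.ones((1, n))   # candidate field u(x, y) = sin(x)
f = -u                                      # since lap(sin(x)) = -sin(x)
print(physics_loss(u, f, dx))               # small: only finite-difference error remains
```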


Example

The script below constructs the non-linear operator (∇² + sin), applying a Fourier spectral discretization on a square 2D domain. It then uses this operator to define a loss function, whose gradient is computed with JAX's automatic differentiation.

from jaxdf import operators as jops
from jaxdf import FourierSeries, operator
from jaxdf.geometry import Domain
from jax import numpy as jnp
from jax import jit, grad


# Define the operator
@operator
def custom_op(u, *, params=None):
  grad_u = jops.gradient(u)
  diag_jacobian = jops.diag_jacobian(grad_u)
  laplacian = jops.sum_over_dims(diag_jacobian)
  sin_u = jops.compose(u)(jnp.sin)
  return laplacian + sin_u

# Define the discretization
domain = Domain((128, 128), (1., 1.))
parameters = jnp.ones((128,128,1))
u = FourierSeries(parameters, domain)

# Define a differentiable loss function
@jit
def loss(u):
  v = custom_op(u)
  return jnp.mean(jnp.abs(v.on_grid)**2)

gradient = grad(loss)(u) # gradient is a FourierSeries
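To see what the Fourier spectral discretization computes under the hood, here is a standalone plain-NumPy sketch of the same (∇² + sin) operator on a periodic grid (the function names are illustrative, not the jaxdf API): derivatives become multiplications by wavenumbers in frequency space.

```python
import numpy as np

def spectral_laplacian(u, dx=1.0):
    """Laplacian of a periodic 2D field via the FFT."""
    n = u.shape[0]
    k = 2 * np.pi * np.fft.fftfreq(n, d=dx)   # angular wavenumbers
    kx, ky = np.meshgrid(k, k, indexing="ij")
    u_hat = np.fft.fft2(u)
    lap_hat = -(kx**2 + ky**2) * u_hat        # d²/dx² becomes -k² in Fourier space
    return np.real(np.fft.ifft2(lap_hat))

def custom_op_ref(u, dx=1.0):
    # Reference version of the (laplacian + sin) operator from the example above
    return spectral_laplacian(u, dx) + np.sin(u)

# Check against an analytic case: for u = sin(x), the Laplacian is -sin(x)
n = 128
x = np.linspace(0, 2 * np.pi, n, endpoint=False)
u = np.sin(x)[:, None] * np.ones((1, n))
out = custom_op_ref(u, dx=x[1] - x[0])
expected = -np.sin(x)[:, None] + np.sin(u)
assert np.allclose(out, expected, atol=1e-8)
```

For periodic smooth fields like this one, the spectral derivative is exact up to floating-point round-off, which is why the Fourier discretization is a natural default on a rectangular domain.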

Installation

Before installing jaxdf, ensure that JAX is already installed on your system. If you intend to use jaxdf with NVIDIA GPU support, follow the JAX installation instructions accordingly.

To install jaxdf from PyPI, use the pip command:

pip install jaxdf

For development, install jaxdf by either cloning the repository or downloading and extracting the compressed archive. Then, from the root folder, run:

pip install --upgrade poetry
poetry install

This will install the dependencies and the package itself (in editable mode).

Support

If you encounter any issues with the code or wish to suggest new features, please feel free to open an issue. If you seek guidance, wish to discuss something, or simply want to say hi, don't hesitate to write a message in our Discord channel.


Contributing

Contributions are absolutely welcome! Most contributions start with an issue; please don't hesitate to open one to ask for features, give feedback on performance, or simply reach out.

To make a pull request, please see the detailed Contributing guide; fundamentally, keep in mind the following guidelines:

  • If you add a new feature or fix a bug:
    • Make sure it is covered by tests
    • Add a line in the changelog using kacl-cli
  • If you changed something in the documentation, make sure that the documentation site builds correctly using mkdocs serve


Citation


An initial version of this package was presented at the Differentiable Programming workshop at NeurIPS 2021.

@article{stanziola2021jaxdf,
    author={Stanziola, Antonio and Arridge, Simon and Cox, Ben T. and Treeby, Bradley E.},
    title={A research framework for writing differentiable PDE discretizations in JAX},
    year={2021},
    journal={Differentiable Programming workshop at Neural Information Processing Systems 2021}
}

Acknowledgements

Related projects

  1. odl: the Operator Discretization Library (ODL) is a Python library for fast prototyping, focused on (but not restricted to) inverse problems.
  2. deepXDE: a TensorFlow and PyTorch library for scientific machine learning.
  3. SciML: a NumFOCUS-sponsored open-source software organization created to unify the packages for scientific machine learning.
