
Derivative Informed Neural Operators in JAX and Equinox

Project description

dinox

Implementation of Derivative Informed Neural Operators in JAX. Built for fast performance in single-GPU environments, specifically where all of the data fits in GPU memory. In the future, this code will be generalized to take advantage of multiple GPUs, and to handle big data (where not all samples fit in GPU or CPU memory), likely via memory mapping.
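As a minimal sketch of the all-data-in-GPU-memory pattern (illustrative only, not dinox's actual implementation; `epoch_batches` is a hypothetical helper name): keep the full dataset as a device array and draw shuffled minibatches by permuting indices on device, so no host round-trips are needed per epoch.

```python
import jax
import jax.numpy as jnp


def epoch_batches(key, data, batch_size):
    """Yield shuffled minibatches from an array that already lives on device."""
    n = data.shape[0]
    perm = jax.random.permutation(key, n)  # permutation computed on device
    for start in range(0, n - batch_size + 1, batch_size):
        idx = perm[start:start + batch_size]
        yield data[idx]


key = jax.random.PRNGKey(0)
data = jnp.arange(20.0).reshape(10, 2)  # stand-in for a GPU-resident dataset
batches = list(epoch_batches(key, data, batch_size=4))
```

With real GPU-resident data, dinox leans on cupy/kvikio for faster permutation and loading; this plain-JAX version just shows the shape of the idea.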

Installation

Create a fresh environment; the instructions below assume conda is already installed on your machine. Use mamba in place of conda if you can (i.e., run the first command below to install it).

If you have access to an NVIDIA GPU, use gpu_environment.yml; otherwise use cpu_environment.yml, which installs the dependencies for the code. Note that the CPU install will be less performant, since this is a GPU-first library.

conda install -c conda-forge mamba

mamba env create -f <gpu, cpu>_environment.yml
poetry install

Running dinox

python -m dinox -network_name "<name_to_save_network_as>" -data_dir "<location_of_jacobian_enriched_training_data>"

Examples

Note: the codebase still needs to be generalized to run on CPUs. It also does not fully work on Apple Silicon (jax-metal has limitations).

Notes on why we require these packages:

  • cupy - for rapid permuting of data on GPUs
  • kvikio - for interfacing with NVIDIA GPU Direct Storage (GDS) for loading data directly to GPU, skipping the CPU
  • equinox - dinox is primarily built on equinox and is therefore fully JAX compatible. Most of dinox consists of lightweight utilities for mean H1 loss training of neural networks on data that is enriched with Jacobians ($X, Y, dY/dX$)
  • optax - we use optax for optimization, though any JAX-compatible optimization library can be used; our choices are made primarily for speed.
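As a rough sketch of what mean H1 loss training on Jacobian-enriched data $(X, Y, dY/dX)$ looks like (plain JAX for self-containment; the function and variable names here are illustrative, not dinox's API): the loss penalizes misfit in both the outputs and the Jacobians, with the model's Jacobian obtained by automatic differentiation.

```python
import jax
import jax.numpy as jnp


def model(params, x):
    # Minimal stand-in network: a single linear layer y = W x + b.
    W, b = params
    return W @ x + b


def h1_loss(params, x_batch, y_batch, jac_batch):
    """Mean H1-style loss: output misfit plus Jacobian misfit.

    x_batch:   (batch, d_in)          inputs X
    y_batch:   (batch, d_out)         targets Y
    jac_batch: (batch, d_out, d_in)   target Jacobians dY/dX
    """
    preds = jax.vmap(lambda x: model(params, x))(x_batch)
    jacs = jax.vmap(jax.jacfwd(lambda x: model(params, x)))(x_batch)
    l2_term = jnp.mean(jnp.sum((preds - y_batch) ** 2, axis=-1))
    jac_term = jnp.mean(jnp.sum((jacs - jac_batch) ** 2, axis=(-2, -1)))
    return l2_term + jac_term


key = jax.random.PRNGKey(0)
params = (jax.random.normal(key, (2, 3)), jnp.zeros(2))
x = jnp.ones((8, 3))
y = jnp.zeros((8, 2))
dydx = jnp.zeros((8, 2, 3))
loss, grads = jax.value_and_grad(h1_loss)(params, x, y, dydx)
```

In practice the model would be an equinox module and `grads` would be fed to an optax optimizer; the structure of the loss is the same.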

We still need to determine the actual minimal requirements in terms of CUDA, JAX, and kvikio versions. The main tricky part is identifying which versions of jax/kvikio/cudatoolkit/cuda-nvcc/cudnn work well together. For now, we only restrict to python>=3.10.

Let us know if you run into dependency resolution issues.

Download files

Download the file for your platform.

Source Distribution

dinox-0.4.0.tar.gz (16.6 kB view details)

Uploaded Source

Built Distribution


dinox-0.4.0-py3-none-any.whl (19.0 kB view details)

Uploaded Python 3

File details

Details for the file dinox-0.4.0.tar.gz.

File metadata

  • Download URL: dinox-0.4.0.tar.gz
  • Upload date:
  • Size: 16.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.9.19

File hashes

Hashes for dinox-0.4.0.tar.gz
  • SHA256: 77a3f3d54d502419d73d1dcfe3e0b9fb84024e1d9e099f1dcaa863bca49a02a8
  • MD5: 8e59edc3d84f4935e525d856097b5a42
  • BLAKE2b-256: c66a263832e9bfd7ef362284f7fee2db6a9fcd9532bcc77c1e389056ba0b1068


File details

Details for the file dinox-0.4.0-py3-none-any.whl.

File metadata

  • Download URL: dinox-0.4.0-py3-none-any.whl
  • Upload date:
  • Size: 19.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.9.19

File hashes

Hashes for dinox-0.4.0-py3-none-any.whl
  • SHA256: 860c6f36e31fd4508558842d761a7be4763b1a68ddf675e00e81caae943c1dc4
  • MD5: 06e2333345aacb467a830051ace06f0f
  • BLAKE2b-256: 7862fe49f909fdb4571721c4a9dd7de9ce0b65231f3d44bfc4b8ef9e5bb7ebd0

