Derivative Informed Neural Operators in JAX and Equinox

Project description

dinox

An implementation of Derivative Informed Neural Operators in JAX. It is built for fast performance in single-GPU environments, specifically those where all of the data fits in GPU memory. In the future, the code will be generalized to take advantage of multiple GPUs. It will also be generalized to handle big data (where not all samples fit in GPU or CPU memory), probably via memory mapping.
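To illustrate what the all-data-in-GPU-memory setting allows, here is a minimal sketch of minibatching when the whole dataset already lives on the device. It is not dinox's data loader: the helper name `device_minibatches` is hypothetical, and it uses plain JAX for the permutation rather than the cupy-based approach mentioned in the package notes.

```python
import jax
import jax.numpy as jnp


def device_minibatches(key, arrays, batch_size):
    """Yield shuffled minibatches from arrays that already live on the default device.

    Hypothetical helper for illustration only; not part of dinox.
    """
    n = arrays[0].shape[0]
    # The permutation is a JAX array, so indexing below stays on the device.
    perm = jax.random.permutation(key, n)
    # Stop before the last partial batch so every batch has the same shape.
    for start in range(0, n - batch_size + 1, batch_size):
        idx = perm[start:start + batch_size]
        yield tuple(a[idx] for a in arrays)


# Toy data: row i of X is [2i, 2i + 1] and Y[i] = i.
X = jax.device_put(jnp.arange(20.0).reshape(10, 2))
Y = jax.device_put(jnp.arange(10.0))

batches = list(device_minibatches(jax.random.PRNGKey(0), (X, Y), batch_size=4))
```

Because every array is indexed with the same permutation, inputs and outputs stay paired, and no host-to-device copy happens per batch.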

Installation

Create a brand new environment. The assumption is that conda is already installed on your machine. Use mamba in place of conda if you can (i.e. run the first line below to install it).

If you have access to an NVIDIA GPU, use gpu_environment.yml. Otherwise use cpu_environment.yml, which installs the dependencies, but the code will not be as performant, since this is a GPU-forward library.

conda install -c conda-forge mamba

mamba env create -f <gpu, cpu>_environment.yml
poetry install

Running dinox

python -m dinox -network_name "<name_to_save_network_as>" -data_dir "<location_of_jacobian_enriched_training_data>"
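As a rough illustration of what "Jacobian-enriched" training data means, the sketch below builds the triple ($X, Y, dY/dX$) for a toy map using JAX automatic differentiation. The map `forward_map` and the helper `make_jacobian_enriched_data` are hypothetical stand-ins, and the sketch does not write any files: the on-disk layout expected by `-data_dir` is not described here, so nothing is assumed about it.

```python
import jax
import jax.numpy as jnp


def forward_map(x):
    """Hypothetical stand-in for the map being learned (3 inputs, 2 outputs)."""
    return jnp.array([jnp.sin(x[0]) * x[1], x[0] ** 2 + x[2]])


def make_jacobian_enriched_data(key, n_samples, dim_x=3):
    """Sample inputs and evaluate outputs and Jacobians for each sample."""
    X = jax.random.normal(key, (n_samples, dim_x))
    Y = jax.vmap(forward_map)(X)                  # shape (n_samples, 2)
    dYdX = jax.vmap(jax.jacfwd(forward_map))(X)   # shape (n_samples, 2, dim_x)
    return X, Y, dYdX


X, Y, dYdX = make_jacobian_enriched_data(jax.random.PRNGKey(0), n_samples=16)
```

Each sample then carries its input, its output, and the full Jacobian of the output with respect to the input.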

Examples

Note: the codebase still needs to be generalized to work on CPUs. It also does not fully work on Apple Silicon (jax-metal has limitations).

Notes on why we require these packages:

  • cupy - for rapid permuting of data on GPUs
  • kvikio - for interfacing with NVIDIA GPUDirect Storage (GDS) to load data directly to the GPU, skipping the CPU
  • equinox - dinox is primarily built on equinox and is therefore fully JAX compatible. Most of dinox is simply lightweight utilities for mean H1 loss training of neural networks with data that is enriched with Jacobians ($X, Y, dY/dX$)
  • optax - we use optax for optimization, though any neural network optimization library can be used. We make choices primarily for speed.
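For readers unfamiliar with the mean H1 loss mentioned above, here is a minimal sketch in plain JAX, under the assumption that the loss is the sample mean of the squared output error plus the squared Frobenius error of the Jacobian. The names `mlp` and `mean_h1_loss` are hypothetical, and dinox's own loss may weight or normalize the two terms differently.

```python
import jax
import jax.numpy as jnp


def mlp(params, x):
    """Tiny stand-in network (hypothetical; not dinox's architecture)."""
    for W, b in params[:-1]:
        x = jnp.tanh(W @ x + b)
    W, b = params[-1]
    return W @ x + b


def mean_h1_loss(params, X, Y, dYdX):
    """Mean over samples of ||f(x) - y||^2 + ||J_f(x) - dy/dx||_F^2."""
    def per_sample(x, y, J):
        y_hat = mlp(params, x)
        # Jacobian of the network output w.r.t. its input, shape (dim_y, dim_x).
        J_hat = jax.jacrev(mlp, argnums=1)(params, x)
        return jnp.sum((y_hat - y) ** 2) + jnp.sum((J_hat - J) ** 2)
    return jnp.mean(jax.vmap(per_sample)(X, Y, dYdX))


# Toy check: data from a linear map y = A x, whose Jacobian is A everywhere.
kA, kX = jax.random.split(jax.random.PRNGKey(0))
A = jax.random.normal(kA, (2, 3))
X = jax.random.normal(kX, (8, 3))
Y = X @ A.T
dYdX = jnp.broadcast_to(A, (8, 2, 3))

# A single linear layer with weights A reproduces both outputs and Jacobians.
exact_params = [(A, jnp.zeros(2))]
loss_at_exact = mean_h1_loss(exact_params, X, Y, dYdX)
grads = jax.grad(mean_h1_loss)(exact_params, X, Y, dYdX)
```

The loss is differentiable with respect to `params`, so the resulting gradients can be handed to any optimizer, optax included.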

The actual minimal requirements in terms of CUDA, jax, and kvikio versions still need to be worked out. The main tricky part is finding which versions of jax/kvikio/cudatoolkit/cuda-nvcc/cudnn work well together. For now, the only restriction is python>=3.10.

Let me know if you run into dependency resolution issues.

Download files

Source Distribution

dinox-0.4.1.tar.gz (16.9 kB)

Uploaded Source

Built Distribution

dinox-0.4.1-py3-none-any.whl (19.3 kB)

Uploaded Python 3

File details

Details for the file dinox-0.4.1.tar.gz.

File metadata

  • Download URL: dinox-0.4.1.tar.gz
  • Upload date:
  • Size: 16.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.9.19

File hashes

Hashes for dinox-0.4.1.tar.gz
Algorithm Hash digest
SHA256 3f39f3dcae41d12769f141b5da7a5e5599b7299de603a288543bb96a347ca0e1
MD5 e37745fc61e8809e97ab2a2e57d8648a
BLAKE2b-256 d488c3c6cc267867e6a1bd00f44752497bccf9310ab2b30077ac4362f686d346

File details

Details for the file dinox-0.4.1-py3-none-any.whl.

File metadata

  • Download URL: dinox-0.4.1-py3-none-any.whl
  • Upload date:
  • Size: 19.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.9.19

File hashes

Hashes for dinox-0.4.1-py3-none-any.whl
Algorithm Hash digest
SHA256 1487f02f31a61f24391badd83a80c5d61eefb1bf874a414fbc317436b0788675
MD5 e4fb3dc9917175a2611906b7bf34b04e
BLAKE2b-256 334584fdd3fd97b1bf7bac8974aeb439f3b3f0db8552bdd535d4e6c59d667091
