Derivative Informed Neural Operators in JAX and Equinox
dinox
Implementation of Derivative Informed Neural Operators in JAX. Built for fast performance in single-GPU environments, specifically where all data can fit in GPU memory. In the future, this code will be generalized to take advantage of multiple GPUs. It will also be generalized to handle big data (where not all samples fit in GPU or CPU memory), probably via memory mapping.
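Keeping the whole dataset in GPU memory means shuffling and minibatching can happen entirely on the device, with no host-to-device copies per epoch. The sketch below illustrates that idea in plain JAX, assuming the Jacobian-enriched arrays $X$, $Y$, $dY/dX$ are already device-resident. `shuffle_into_batches` is a hypothetical helper, not part of dinox's API (the package notes mention that dinox uses cupy for permuting data on the GPU).

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=4)
def shuffle_into_batches(key, X, Y, J, batch_size):
    """Hypothetical helper (not part of dinox): shuffle device-resident data and
    split it into minibatches without any host-device transfers.

    X: (N, dX), Y: (N, dY), J: (N, dY, dX). Trailing samples that do not fill a
    whole batch are dropped.
    """
    n = X.shape[0]
    perm = jax.random.permutation(key, n)  # one permutation shared by X, Y, J
    num_batches = n // batch_size

    def to_batches(A):
        A = A[perm][: num_batches * batch_size]
        return A.reshape((num_batches, batch_size) + A.shape[1:])

    return to_batches(X), to_batches(Y), to_batches(J)
```

Each epoch can call this with a fresh key (e.g. from `jax.random.split`) and then loop over the leading batch axis; because `batch_size` is static, changing it triggers a recompile.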
Installation
Create a brand new environment, using mamba in place of conda if you can (the first command below installs mamba). This assumes conda is already installed on your machine.
If you have access to an NVIDIA GPU, use gpu_environment.yml; otherwise use cpu_environment.yml. The CPU environment installs the dependencies, but the code will be less performant, since dinox is a GPU-first library.
conda install -c conda-forge mamba
mamba env create -f <gpu, cpu>_environment.yml
poetry install
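After installing, it is worth confirming which backend JAX actually picked up, since a CPU-only jaxlib will silently run everything on the CPU. A minimal check using standard JAX calls (`describe_backend` is just an illustrative wrapper name):

```python
import jax


def describe_backend():
    """Return JAX's default backend name and the platform name of each visible device."""
    backend = jax.default_backend()  # typically "gpu" for a working CUDA install, "cpu" otherwise
    platforms = [d.platform for d in jax.devices()]
    return backend, platforms


backend, platforms = describe_backend()
print(f"default backend: {backend}; devices: {platforms}")
```

If this reports `cpu` inside the GPU environment, the jax/CUDA versions most likely do not match.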
Running dinox
python -m dinox -network_name "<name_to_save_network_as>" -data_dir "<location_of_jacobian_enriched_training_data>"
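The `-data_dir` argument points at Jacobian-enriched training data, i.e. triples of inputs $X$, outputs $Y$, and Jacobians $dY/dX$. The sketch below shows what such data looks like for a toy map, using `jax.jacfwd` to compute per-sample Jacobians. The map `forward_map`, the matrix `W`, and `make_jacobian_enriched_data` are hypothetical; the on-disk format dinox expects is not described here, so nothing is written to disk.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy forward map standing in for an expensive solver: y = tanh(W x)
W = jnp.array([[1.0, -0.5, 0.2],
               [0.3, 0.8, -1.0]])  # maps dX = 3 inputs to dY = 2 outputs


def forward_map(x):
    return jnp.tanh(W @ x)


def make_jacobian_enriched_data(key, n_samples):
    """Sample inputs and return (X, Y, J) with J[i] = dY/dX evaluated at X[i]."""
    X = jax.random.normal(key, (n_samples, W.shape[1]))  # (N, dX)
    Y = jax.vmap(forward_map)(X)                         # (N, dY)
    J = jax.vmap(jax.jacfwd(forward_map))(X)             # (N, dY, dX)
    return X, Y, J


X, Y, J = make_jacobian_enriched_data(jax.random.PRNGKey(0), 16)
```

Forward mode (`jacfwd`) is a natural choice when the input dimension is small; for wide inputs and narrow outputs, `jax.jacrev` is usually cheaper.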
Examples
Note: the codebase still needs to be generalized to work on CPUs. It also does not fully work on Apple Silicon (jax-metal has limitations).
Notes on why we require these packages:
- cupy: for rapid permuting of data on GPUs.
- kvikio: for interfacing with NVIDIA GPU Direct Storage (GDS), which loads data directly to the GPU, bypassing the CPU.
- equinox: dinox is primarily built on equinox and is therefore fully JAX-compatible. Most of dinox consists of lightweight utilities for mean H1 loss training of neural networks on data enriched with Jacobians ($X, Y, dY/dX$).
- optax: we use optax for optimization, though any neural network optimization library can be used. Our choices are driven primarily by speed.
We still need to determine the actual minimal requirements in terms of CUDA, JAX, and kvikio versions. The tricky part is finding which versions of jax/kvikio/cudatoolkit/cuda-nvcc/cudnn work well together. For now, the only restriction we impose is python>=3.10.
Let me know if you run into dependency resolution issues.
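When reporting a dependency resolution issue, including the Python and package versions helps. A small sketch using only the standard library; `collect_versions` and `meets_python_requirement` are hypothetical helper names, and the distribution names are assumptions (pip builds of cupy and kvikio are often published under CUDA-suffixed names such as `cupy-cuda12x`, which would show up here as not installed).

```python
import sys
from importlib import metadata


def collect_versions(packages=("jax", "jaxlib", "equinox", "optax", "cupy", "kvikio")):
    """Return the Python version plus each distribution's version (None if not found)."""
    info = {"python": ".".join(map(str, sys.version_info[:3]))}
    for name in packages:
        try:
            info[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            info[name] = None  # not installed under this distribution name
    return info


def meets_python_requirement():
    """Check the python>=3.10 restriction."""
    return sys.version_info >= (3, 10)


for name, version in collect_versions().items():
    print(f"{name}: {version}")
```

Recent JAX releases also provide `jax.print_environment_info()`, which prints the jax, jaxlib, numpy, and Python versions along with `nvidia-smi` output when available.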
Download files
Source Distribution: dinox-0.4.1.tar.gz
Built Distribution: dinox-0.4.1-py3-none-any.whl
File details
Details for the file dinox-0.4.1.tar.gz.
File metadata
- Download URL: dinox-0.4.1.tar.gz
- Upload date:
- Size: 16.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.9.19
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 3f39f3dcae41d12769f141b5da7a5e5599b7299de603a288543bb96a347ca0e1 |
| MD5 | e37745fc61e8809e97ab2a2e57d8648a |
| BLAKE2b-256 | d488c3c6cc267867e6a1bd00f44752497bccf9310ab2b30077ac4362f686d346 |
File details
Details for the file dinox-0.4.1-py3-none-any.whl.
File metadata
- Download URL: dinox-0.4.1-py3-none-any.whl
- Upload date:
- Size: 19.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.9.19
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 1487f02f31a61f24391badd83a80c5d61eefb1bf874a414fbc317436b0788675 |
| MD5 | e4fb3dc9917175a2611906b7bf34b04e |
| BLAKE2b-256 | 334584fdd3fd97b1bf7bac8974aeb439f3b3f0db8552bdd535d4e6c59d667091 |