Neural Processes implementations in JAX and PyTorch

Neural Process Family

This library is in the early stages of development.

PyTorch versions are not yet supported.

Installation

You can choose between the following installation methods:

NPF as a library

# From PyPI (recommended)
pip install np-family

# Latest version (from current branch; dev-v4)
pip install git+https://github.com/yuneg11/Neural-Process-Family@dev-v4

# Specific release (from tag; 0.0.1.dev0)
pip install git+https://github.com/yuneg11/Neural-Process-Family@0.0.1.dev0

Then, you can use the library as follows:

from npf.jax.models import CNP

cnp = CNP(y_dim=1)

The library provides only the models; you need to implement the rest of the pipeline (training, evaluation, etc.) yourself.

NPF as an experiment framework

# Dependencies
pip install rich nxcl==0.0.3.dev3

# Plus an ML framework (JAX, PyTorch)
# e.g., pip install jax

# Clone the repository
git clone https://github.com/yuneg11/Neural-Process-Family npf
cd npf

Then, you can run the experiment, for example:

python scripts/jax/train.py -f configs/gp/rbf/inf/anp.yaml -lr 0.0001 --model.train_kwargs.num_latents 30

The output will be saved under the outs/ directory. Details will be added in the future.

Download or build datasets

python -m npf.jax.data.save \
    --root <dataset-root> \
    --dataset <dataset-name>
  • <dataset-root>: The root path where the dataset is saved. Default: ./datasets/
  • <dataset-name>: The name of the dataset to save. See the sections below for the available datasets.

Image datasets

You need to install torch and torchvision to download the datasets. See the PyTorch download page for installation details.

For example,

# CUDA 11.3
conda install pytorch torchvision cudatoolkit=11.3 -c pytorch

Available datasets

  • MNIST: mnist
  • CIFAR10: cifar10
  • CIFAR100: cifar100
  • CelebA: celeba
  • SVHN: svhn

Sim2Real datasets

You should install numba and wget to simulate or download the datasets.

pip install numba wget

Available datasets

  • Lotka Volterra: lotka_volterra

    TODO: See npf.jax.data.save:save_lotka_volterra for more detailed options.

Models

  • CNP: Conditional Neural Process
  • NP: Neural Process
  • CANP: Conditional Attentive Neural Process
  • ANP: Attentive Neural Process
  • BNP: Bootstrapping Neural Process
  • BANP: Bootstrapping Attentive Neural Process
  • NeuBNP: Neural Bootstrapping Neural Process
  • NeuBANP: Neural Bootstrapping Attentive Neural Process
  • ConvCNP: Convolutional Conditional Neural Process
  • ConvNP: Convolutional Neural Process

Scripts

Train

python scripts/jax/train.py -f <config-file> [additional-options]

You can use your own config file or one of the provided config files in the configs directory. For example, the following command trains an ANP model (the model defined by anp.yaml) with a learning rate of 0.0001 for 100 epochs:

python scripts/jax/train.py -f configs/gp/rbf/inf/anp.yaml \
    -lr 0.0001 \
    --train.num_epochs 100
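
When several runs differ only in their overrides, the command line can be assembled from Python. The following is a minimal sketch, not part of the repository: the helper name train_command is hypothetical, and only the flags shown above (-f, -lr, --train.num_epochs) are used. It builds the argument lists without running them.

```python
import shlex


def train_command(config_file, overrides=None, script="scripts/jax/train.py"):
    """Build the argv list for one training run (hypothetical helper)."""
    argv = ["python", script, "-f", config_file]
    for flag, value in (overrides or {}).items():
        argv += [flag, str(value)]
    return argv


# One command per learning rate; everything else stays the same.
commands = [
    train_command(
        "configs/gp/rbf/inf/anp.yaml",
        {"-lr": lr, "--train.num_epochs": 100},
    )
    for lr in (0.001, 0.0001)
]

for argv in commands:
    # Print the shell-quoted command for inspection.
    print(shlex.join(argv))
```

To actually launch the runs, each argv could be passed to subprocess.run(argv, check=True) from the repository root.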

To list the options available for a given config file, run:

python scripts/jax/train.py -f <config-file> --help

Test

# From a trained model directory
python scripts/jax/test.py -d <model-output-dir> [additional-options]

# From a new config file and a trained model checkpoint
python scripts/jax/test.py -f <config-file> -c <checkpoint-file-path> [additional-options]

You can directly test the trained model by specifying the output directory. For example:

python scripts/jax/test.py -d outs/CNP/Train/RBF/Inf/220704-181313-vweh

where outs/CNP/Train/RBF/Inf/220704-181313-vweh is the output directory of the trained model.

You can also override or add test-specific settings from a config file with the -tf / --test-config-file option. For example:

python scripts/jax/test.py -d outs/CNP/Train/RBF/Inf/220704-181313-vweh \
    -tf configs/gp/robust/matern.yaml

Test Bayesian optimization

# From a trained model directory
python scripts/jax/test_bo.py -d <model-output-dir> [additional-options]

# From a new config file and a trained model checkpoint
python scripts/jax/test_bo.py -f <config-file> -c <checkpoint-file-path> [additional-options]

As with the test script above, you can directly test the trained model by specifying the output directory. For example:

python scripts/jax/test_bo.py -d outs/CNP/Train/RBF/Inf/220704-181313-vweh

You can also override or add Bayesian-optimization-specific settings from a config file with the -bf / --bo-config-file option. For example:

python scripts/jax/test_bo.py -d outs/CNP/Train/RBF/Inf/220704-181313-vweh \
    -bf configs/gp/rbf/bo_config.yaml




Appendix

Datasets

  1. 1D regression (x: [B, P, 1], y: [B, P, 1], mask: [B, P])

    • Gaussian processes, etc...
  2. 2D Image (x: [B, P, P, 2], y: [B, P, P, (1 or 3)], mask: [B, P, P])

    • Image completion, super resolution, etc...
  3. Bayesian optimization (x: [B, P, D], y: [B, P, 1], mask: [B, P])
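
The shapes above can be illustrated by padding variable-length 1D tasks into one batch. This is a sketch under an assumption the text does not state: that mask marks valid (non-padded) points with True. The helper pad_batch is hypothetical and uses plain nested lists rather than JAX arrays.

```python
def pad_batch(xs, ys, pad_value=0.0):
    """Pad 1D tasks of different lengths to a common length P.

    Returns nested lists shaped x: [B, P, 1], y: [B, P, 1], mask: [B, P].
    Assumption: mask is True for real points and False for padding.
    """
    num_points = max(len(x) for x in xs)
    x_out, y_out, mask = [], [], []
    for x, y in zip(xs, ys):
        n_pad = num_points - len(x)
        # Each point becomes a length-1 feature vector, then padding is appended.
        x_out.append([[v] for v in x] + [[pad_value] for _ in range(n_pad)])
        y_out.append([[v] for v in y] + [[pad_value] for _ in range(n_pad)])
        mask.append([True] * len(x) + [False] * n_pad)
    return x_out, y_out, mask


# Two tasks with 3 and 1 observed points: B = 2, P = 3.
x, y, mask = pad_batch([[0.0, 1.0, 2.0], [0.5]], [[1.0, 2.0, 3.0], [4.0]])
```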

Dimension rule

  • x: [batch, *data_specific_dims, data_dim]
  • y: [batch, *data_specific_dims, data_dim]
  • mask: [batch, *data_specific_dims]
  • outs: [batch, *model_specific_dims, *data_specific_dims, data_dim]
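
The rule can be written down as a small function that returns the expected shape tuples. This is only an illustration of the convention above; expected_shapes and its parameter names are hypothetical and not part of the library. Following the examples in this appendix, the last dimension of outs is taken to match that of y.

```python
def expected_shapes(batch, data_dims, x_dim, y_dim, model_dims=()):
    """Return the expected shapes of x, y, mask and outs as tuples."""
    data_dims = tuple(data_dims)
    model_dims = tuple(model_dims)
    return {
        "x": (batch, *data_dims, x_dim),
        "y": (batch, *data_dims, y_dim),
        "mask": (batch, *data_dims),
        # Model-specific dims (e.g. latent samples) sit right after the batch dim.
        "outs": (batch, *model_dims, *data_dims, y_dim),
    }


# CNP on 1D regression: no model-specific dims.
cnp_1d = expected_shapes(batch=16, data_dims=(50,), x_dim=1, y_dim=1)

# NP on RGB images with 20 latent samples.
np_2d = expected_shapes(batch=16, data_dims=(32, 32), x_dim=2, y_dim=3, model_dims=(20,))
```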

Examples

  1. CNP on 1D regression:

    • x: [batch, point, 1]
    • y: [batch, point, 1]
    • mask: [batch, point]
    • outs: [batch, point, 1]
  2. NP on 1D regression:

    • x: [batch, point, 1]
    • y: [batch, point, 1]
    • mask: [batch, point]
    • outs: [batch, latent, point, 1]
  3. CNP on 2D image regression:

    • x: [batch, height, width, 2]
    • y: [batch, height, width, 1 or 3]
    • mask: [batch, height, width]
    • outs: [batch, height, width, 1 or 3]
  4. NP on 2D image regression:

    • x: [batch, height, width, 2]
    • y: [batch, height, width, 1 or 3]
    • mask: [batch, height, width]
    • outs: [batch, latent, height, width, 1 or 3]
  5. BNP on 1D regression:

    • x: [batch, point, 1]
    • y: [batch, point, 1]
    • mask: [batch, point]
    • outs: [batch, sample, point, 1]
  6. BNP on 2D image regression:

    • x: [batch, height, width, 2]
    • y: [batch, height, width, 1 or 3]
    • mask: [batch, height, width]
    • outs: [batch, sample, height, width, 1 or 3]
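
For models with a latent or sample dimension, outs carries one extra axis right after the batch axis. One possible way to get a single prediction per point is to average over that axis. The sketch below assumes this reduction, which the library may or may not use, and works on plain nested lists; shape_of and average_model_axis are hypothetical helpers.

```python
def shape_of(nested):
    """Return the shape of a rectangular nested list as a tuple."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0]
    return tuple(shape)


def _mean(items):
    """Element-wise mean of equally shaped nested lists (or plain numbers)."""
    if isinstance(items[0], list):
        return [_mean(list(group)) for group in zip(*items)]
    return sum(items) / len(items)


def average_model_axis(outs):
    """Collapse axis 1 (latent or sample) of [batch, S, ...] by averaging."""
    return [_mean(per_batch) for per_batch in outs]


# outs: [batch=2, latent=3, point=4, 1]; latent sample s predicts the value s.
outs = [[[[float(s)] for _ in range(4)] for s in range(3)] for _ in range(2)]
avg = average_model_axis(outs)  # shape [batch=2, point=4, 1]
```

With JAX arrays the same reduction would be a mean over axis 1.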
