
This package provides tools to implement and run Implicit Likelihood Inference in JAX.

Project description

JaxILI


JaxILI is a package for Neural Density Estimation in JAX. Training is done with optax and the neural networks are built with flax; see the documentation of each library for details.

Its goal is to make it easy to train Normalizing Flows for Implicit Likelihood Inference.
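A normalizing flow models a density by pushing a simple base distribution through an invertible map and correcting for the change of volume. The sketch below is a hypothetical, minimal illustration of that idea using a single elementwise affine map in NumPy; it is not jaxili's implementation, and the names `affine_flow_sample` and `affine_flow_log_prob` are made up for this example.

```python
import numpy as np

def affine_flow_sample(mu, sigma, num_samples, rng):
    # Draw z from the standard normal base distribution, then map x = mu + sigma * z.
    z = rng.standard_normal((num_samples, mu.shape[0]))
    return mu + sigma * z

def affine_flow_log_prob(x, mu, sigma):
    # Invert the map: z = (x - mu) / sigma.
    z = (x - mu) / sigma
    # Log-density of z under the base distribution N(0, I), summed over dimensions.
    base_log_prob = np.sum(-0.5 * z**2 - 0.5 * np.log(2 * np.pi), axis=-1)
    # Change of variables: subtract log|det dx/dz| = sum(log sigma).
    return base_log_prob - np.sum(np.log(sigma))

rng = np.random.default_rng(0)
mu = np.array([0.5, -1.0])
sigma = np.array([2.0, 0.3])
x = affine_flow_sample(mu, sigma, num_samples=5, rng=rng)
log_p = affine_flow_log_prob(x, mu, sigma)
```

Because this map is affine, the result equals the log-density of a Gaussian with mean `mu` and diagonal standard deviations `sigma`. A Masked Autoregressive Flow replaces the fixed `mu` and `sigma` with outputs of autoregressive networks, so each dimension's shift and scale depend on the previous dimensions.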

Installation

Install jaxili from PyPI:

$ pip install jaxili

First example: performing Neural Posterior Estimation

import os

from jaxili.inference import NPE

First fetch the data you want to train on:

theta, x = ... #theta: the parameters to be inferred; x: the simulator outputs given theta.
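To try the pipeline end to end, a toy simulator is enough to produce `(theta, x)` pairs. The example below is hypothetical: the uniform prior, the simulator `simulate`, and the `[n_simulations, dim]` array layout are assumptions for illustration (the layout matches the `[1, data vector size]` shape the posterior expects for a single observation), not requirements documented by jaxili.

```python
import numpy as np

rng = np.random.default_rng(42)
num_simulations = 1000
dim_theta = 2
dim_obs = 5

def simulate(theta, rng):
    # Hypothetical simulator: each of the dim_obs outputs is a noisy linear
    # function of the parameters (theta[:, 0] is an offset, theta[:, 1] a slope).
    t = np.linspace(0.0, 1.0, dim_obs)
    mean = theta[:, :1] + theta[:, 1:2] * t
    return mean + 0.1 * rng.standard_normal((theta.shape[0], dim_obs))

# Draw parameters from a uniform prior on [-1, 1]^dim_theta.
theta = rng.uniform(-1.0, 1.0, size=(num_simulations, dim_theta))
x = simulate(theta, rng)
```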

Then create an inference object, add the simulations and train:

inference = NPE()
inference.append_simulations(theta, x)

learning_rate = ... #Choose your learning rate
num_epochs = ... #Choose the number of epochs
batch_size = ... #Choose the batch size
checkpoint_path = ... #Choose the checkpoint path
checkpoint_path = os.path.abspath(checkpoint_path) #Beware, this should be an absolute path.

metrics, density_estimator = inference.train(
    training_batch_size=batch_size,
    learning_rate=learning_rate,
    checkpoint_path=checkpoint_path,
    num_epochs=num_epochs
)

You can then build the posterior and sample from it:

posterior = inference.build_posterior()

observation = ... #The observation should have the shape [1, data vector size].
samples = posterior.sample(x=observation, num_samples=..., key=...) #You have to give a PRNGKey and specify the number of samples.
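The posterior expects an observation with a leading batch axis of size 1. If your observed data vector is 1-D, adding that axis is a one-line reshape; `x_obs` below is a placeholder for your own data.

```python
import numpy as np

# Placeholder for a single observed data vector of length 5 (shape (5,)).
x_obs = np.array([0.1, 0.3, 0.2, 0.6, 0.4])

# Add a leading axis so the shape becomes [1, data vector size].
observation = x_obs[None, :]  # equivalently: x_obs.reshape(1, -1)
```

The `key` argument is a JAX PRNG key, for example `jax.random.PRNGKey(0)`.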

Training a conditional MAF

If you want to control the architecture of the network, you can use the following code to train, e.g., a Masked Autoregressive Flow (MAF).
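A MAF stacks several autoregressive transforms, each parameterized by a MADE-style network whose weight matrices are masked so that the output for dimension i depends only on dimensions before i. The NumPy sketch below builds such masks and checks that property; it illustrates the technique and is not jaxili's code. It is only our assumption that `layers` in the hyperparameters further down sets the hidden-layer sizes of each such network and `n_layers` the number of stacked transforms.

```python
import numpy as np

def made_masks(n_in, hidden_sizes, rng):
    """Build MADE connectivity masks of shape (n_out, n_in) for each layer.

    Requires n_in >= 2 so that hidden degrees can be drawn from [1, n_in - 1].
    """
    # Input k (0-based) gets degree k + 1.
    degrees = [np.arange(1, n_in + 1)]
    for size in hidden_sizes:
        # Hidden units get random degrees in [1, n_in - 1].
        degrees.append(rng.integers(1, n_in, size=size))
    masks = []
    for d_in, d_out in zip(degrees[:-1], degrees[1:]):
        # A hidden unit may see any unit whose degree is <= its own.
        masks.append((d_out[:, None] >= d_in[None, :]).astype(float))
    # Output i (degree i + 1) may only see hidden units with degree <= i.
    out_degrees = np.arange(1, n_in + 1)
    masks.append((out_degrees[:, None] > degrees[-1][None, :]).astype(float))
    return masks

rng = np.random.default_rng(0)
masks = made_masks(n_in=4, hidden_sizes=[50, 50], rng=rng)

# Multiply the masks to get input-to-output connectivity (rows: outputs, cols: inputs).
connectivity = masks[0]
for mask in masks[1:]:
    connectivity = mask @ connectivity
```

The connectivity matrix is strictly lower triangular: output i never sees input i or any later input. Reversing the variable order between stacked transforms (which we assume is what `use_reverse` controls) lets every dimension be conditioned on every other dimension somewhere in the stack.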

import jax
import jax.numpy as jnp

from jaxili.utils import create_data_loader  #To create data loaders
from jaxili.train import TrainerModule #To perform the training
from jaxili.model import ConditionalMAF #The model used to learn the target distribution
from jaxili.loss import loss_nll_npe #Losses to train NFs with different configurations are provided

Given training, validation and test sets, create the corresponding data loaders:

train_loader, val_loader, test_loader = create_data_loader(
    train_set, val_set, test_set,
    train=[True, False, False],
    batch_size=128
)
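The code above assumes `train_set`, `val_set` and `test_set` already exist. One simple way to get them is a random split of your simulations. The sketch below splits NumPy arrays `theta` and `x` (random placeholders here) using arbitrary 80/10/10 fractions; `train_val_test_split` is a hypothetical helper, and we have not checked which dataset type `create_data_loader` expects, so you may need to wrap each `(theta, x)` pair accordingly.

```python
import numpy as np

def train_val_test_split(theta, x, fractions=(0.8, 0.1, 0.1), seed=0):
    """Randomly split paired arrays into (theta, x) tuples for train/val/test."""
    n = theta.shape[0]
    perm = np.random.default_rng(seed).permutation(n)
    n_train = int(fractions[0] * n)
    n_val = int(fractions[1] * n)
    idx_train = perm[:n_train]
    idx_val = perm[n_train:n_train + n_val]
    idx_test = perm[n_train + n_val:]  # the remainder goes to the test set
    return (
        (theta[idx_train], x[idx_train]),
        (theta[idx_val], x[idx_val]),
        (theta[idx_test], x[idx_test]),
    )

# Placeholder simulations: 1000 pairs with 2 parameters and 5 data points each.
theta = np.random.default_rng(1).normal(size=(1000, 2))
x = np.random.default_rng(2).normal(size=(1000, 5))
train_set, val_set, test_set = train_val_test_split(theta, x)
```

Shuffling before splitting avoids any ordering in how the simulations were generated leaking into the validation or test sets.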

You can then specify the hyperparameters for training:

CHECKPOINT_PATH = ... #Path to save the weights of your neural network

loss_fn = loss_nll_npe

model_hparams_maf = {
    'n_in': dim_theta, #Dimension of the parameters theta
    'n_cond': dim_obs, #Dimension of the conditioning data x
    'n_layers': 5,
    'layers': [50, 50],
    'activation': jax.nn.relu,
    'use_reverse': True,
    'seed': 42
}

optimizer_hparams = { #hyperparameters of the optimizer for training
    'lr': 5e-4,
    'optimizer_name': 'adam'
}

logger_params = {
    'base_log_dir': CHECKPOINT_PATH
}

check_val_every_epoch = 1

debug = False

nde_class = "NPE"
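Going by its name, `loss_nll_npe` is a negative log-likelihood loss; the standard NPE objective minimizes the average of -log q(theta | x) over the simulated pairs. The NumPy sketch below evaluates that objective for a hypothetical conditional Gaussian q(theta | x) with a linear mean, standing in for the flow; `npe_nll` and its parameters are made up for illustration and are not jaxili's loss function.

```python
import numpy as np

def gaussian_log_prob(theta, mean, std):
    # log N(theta; mean, diag(std^2)), summed over parameter dimensions.
    z = (theta - mean) / std
    return np.sum(-0.5 * z**2 - np.log(std) - 0.5 * np.log(2 * np.pi), axis=-1)

def npe_nll(theta, x, weight, bias, std):
    # Hypothetical conditional density q(theta | x) = N(x @ weight + bias, std^2).
    mean = x @ weight + bias
    # NPE loss: mean negative log-density of the true parameters given their data.
    return -np.mean(gaussian_log_prob(theta, mean, std))

rng = np.random.default_rng(0)
x = rng.normal(size=(256, 5))
weight = rng.normal(size=(5, 2))
bias = np.zeros(2)
theta = x @ weight + bias  # parameters that the conditional mean predicts exactly
loss = npe_nll(theta, x, weight, bias, std=np.ones(2))
```

With a perfect mean and unit standard deviation, each of the two parameter dimensions contributes 0.5 * log(2 * pi), so the loss is log(2 * pi). Training a flow lowers this quantity by making q(theta | x) concentrate on the parameters that generated each x.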

A TrainerModule object can then be created to train the neural network:

trainer_maf_npe = TrainerModule(
    model_class=ConditionalMAF,
    model_hparams=model_hparams_maf,
    optimizer_hparams=optimizer_hparams,
    loss_fn=loss_fn,
    exmp_input=next(iter(train_loader)),
    logger_params=logger_params,
    debug=debug,
    check_val_every_epoch=check_val_every_epoch,
    nde_class=nde_class
)

#Train the Neural Density Estimator
metrics_maf_npe = trainer_maf_npe.train_model(
    train_loader, val_loader, test_loader=test_loader, num_epochs=500, patience=20
)
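The `patience=20` argument suggests early stopping on the validation loss, which is evaluated every `check_val_every_epoch` epochs. The function below is a hypothetical, stdlib-only sketch of the usual patience rule, counted in validation checks; we have not checked jaxili's exact criterion (for example, whether it requires a minimum improvement).

```python
def early_stopping_epoch(val_losses, patience):
    """Return (stop_index, best_index) under a simple patience rule.

    val_losses holds one validation loss per check. Training stops at the
    first check where `patience` consecutive checks have failed to improve
    on the best loss seen so far; if that never happens, training runs to
    the last check.
    """
    best_loss = float("inf")
    best_index = -1
    checks_without_improvement = 0
    for index, loss in enumerate(val_losses):
        if loss < best_loss:
            # New best: remember it and reset the counter.
            best_loss = loss
            best_index = index
            checks_without_improvement = 0
        else:
            checks_without_improvement += 1
            if checks_without_improvement >= patience:
                return index, best_index
    return len(val_losses) - 1, best_index
```

For example, `early_stopping_epoch([1.0, 0.8, 0.9, 0.85, 0.95], patience=3)` returns `(4, 1)`: the best loss occurs at check 1, and the next three checks fail to improve on it.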

The trained model can then be used to sample from or compute the log-probability of the learned distribution:

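model_maf_npe = trainer_maf_npe.bind_model()

key = jax.random.PRNGKey(0)
#observation has the shape [1, data vector size], as in the first example.
samples_maf_npe = model_maf_npe.sample(
    observation, num_samples=10000, key=key
)
#Assuming bind_model returns a Flax-bound module (as the direct .sample call suggests),
#log_prob can be called directly without passing the parameters.
log_prob = model_maf_npe.log_prob(samples_maf_npe, observation)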



Download files


Source Distribution

jaxili-0.1.tar.gz (32.6 kB)

Uploaded Source

Built Distribution


jaxili-0.1-py3-none-any.whl (39.5 kB)

Uploaded Python 3

File details

Details for the file jaxili-0.1.tar.gz.

File metadata

  • Download URL: jaxili-0.1.tar.gz
  • Upload date:
  • Size: 32.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.8.0 colorama/0.4.4 importlib-metadata/6.8.0 keyring/24.2.0 pkginfo/1.8.2 readme-renderer/34.0 requests-toolbelt/0.9.1 requests/2.25.1 rfc3986/1.5.0 tqdm/4.57.0 urllib3/1.26.5 CPython/3.10.12

File hashes

Hashes for jaxili-0.1.tar.gz
Algorithm Hash digest
SHA256 1ced2148a42e6e1b43ae0bba0fcfe0e05c850a2afbc38419b0a93e7229efcc0c
MD5 aa1a6ada7d4a46c53fe00b2b36e2e32e
BLAKE2b-256 f13ba5e176feea66d9ffdd954b9c6ac2cf8da6e75da1e6bb9294b6e30e9a1e49
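To check a downloaded file against the SHA256 digests listed here, you can hash it with Python's standard `hashlib` module. `sha256_of_file` is a small helper written for this example; it reads the file in chunks so large archives do not have to fit in memory.

```python
import hashlib

def sha256_of_file(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # iter() with a sentinel keeps reading until f.read returns b"".
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
```

For an intact download, `sha256_of_file("jaxili-0.1.tar.gz") == "1ced2148a42e6e1b43ae0bba0fcfe0e05c850a2afbc38419b0a93e7229efcc0c"` evaluates to `True`. pip can also enforce digests at install time through its hash-checking mode (`--hash=sha256:...` entries in a requirements file).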


File details

Details for the file jaxili-0.1-py3-none-any.whl.

File metadata

  • Download URL: jaxili-0.1-py3-none-any.whl
  • Upload date:
  • Size: 39.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.8.0 colorama/0.4.4 importlib-metadata/6.8.0 keyring/24.2.0 pkginfo/1.8.2 readme-renderer/34.0 requests-toolbelt/0.9.1 requests/2.25.1 rfc3986/1.5.0 tqdm/4.57.0 urllib3/1.26.5 CPython/3.10.12

File hashes

Hashes for jaxili-0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 67e54c379a950136e4e6b162698f938dd56ef5dd4f770bd0bd61c3c2b7e0be5f
MD5 d3f80eb996a7cf864e9ab8700e6a5f9e
BLAKE2b-256 9f78697e6a1229381a20af0ddbd3ea38f971785dbbfbee0d6311efd053f4d9b9

