Neural Processes

A framework for composing Neural Processes in Python.

Installation

pip install neuralprocesses tensorflow tensorflow-probability  # For use with TensorFlow
pip install neuralprocesses torch                              # For use with PyTorch
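Since the package is installed alongside either TensorFlow or PyTorch, it can help to check which backend is importable before choosing an import. The following is a minimal sketch using only the standard library; `installed_backends` is a hypothetical helper for illustration and is not part of `neuralprocesses`.

```python
from importlib.util import find_spec


def installed_backends(candidates=("torch", "tensorflow")):
    """Return the candidate packages that can be imported in this environment.

    Hypothetical helper for illustration; not part of `neuralprocesses`.
    """
    return [name for name in candidates if find_spec(name) is not None]


# Lists whichever of the two backends is installed (possibly neither).
print(installed_backends())
```

If `"torch"` appears in the result, the `import neuralprocesses.torch as nps` line used in the quick-start applies.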

If something is not working or unclear, please feel free to open an issue.

Documentation

See here.

TL;DR! Just Get Me Started!

Here you go:

import torch

import neuralprocesses.torch as nps

# Construct a ConvCNP.
convcnp = nps.construct_convgnp(dim_x=1, dim_y=2, likelihood="het")

# Construct optimiser.
opt = torch.optim.Adam(convcnp.parameters(), 1e-3)

# Training: optimise the model for 32 batches.
for _ in range(32):
    # Sample a batch of new context and target sets. Replace this with your data. The
    # shapes are `(batch_size, dimensionality, num_data)`.
    xc = torch.randn(16, 1, 10)  # Context inputs
    yc = torch.randn(16, 2, 10)  # Context outputs
    xt = torch.randn(16, 1, 15)  # Target inputs
    yt = torch.randn(16, 2, 15)  # Target outputs

    # Compute the loss and update the model parameters.
    loss = -torch.mean(nps.loglik(convcnp, xc, yc, xt, yt, normalise=True))
    opt.zero_grad(set_to_none=True)
    loss.backward()
    opt.step()

# Testing: make some predictions.
mean, var, noiseless_samples, noisy_samples = nps.predict(
    convcnp,
    torch.randn(16, 1, 10),  # Context inputs
    torch.randn(16, 2, 10),  # Context outputs
    torch.randn(16, 1, 15),  # Target inputs
)
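The example above draws random tensors where your own data would go. As a rough sketch of one way to produce context and target sets in the `(batch_size, dimensionality, num_data)` layout, the snippet below randomly splits a batch of observations along the last axis. The helper `split_context_target` and the sine/cosine toy data are assumptions made for illustration; they are not part of `neuralprocesses`. For simplicity, the same random split is shared by every element of the batch.

```python
import torch


def split_context_target(x, y, num_context):
    """Randomly split observations along the last axis into context and target sets.

    Hypothetical helper for illustration; not part of `neuralprocesses`.
    `x` has shape `(batch_size, dim_x, num_data)` and `y` has shape
    `(batch_size, dim_y, num_data)`.
    """
    num_data = x.shape[-1]
    if not 0 < num_context < num_data:
        raise ValueError("num_context must be between 1 and num_data - 1")
    # One shared permutation of the data indices for the whole batch.
    perm = torch.randperm(num_data)
    ctx, tgt = perm[:num_context], perm[num_context:]
    return x[..., ctx], y[..., ctx], x[..., tgt], y[..., tgt]


# Toy data: 16 tasks, 1-dimensional inputs, 2-dimensional outputs, 25 points each.
x = torch.linspace(-2, 2, 25).repeat(16, 1, 1)      # Shape (16, 1, 25)
y = torch.cat([torch.sin(x), torch.cos(x)], dim=1)  # Shape (16, 2, 25)

xc, yc, xt, yt = split_context_target(x, y, num_context=10)
```

With `num_context=10`, the four tensors have the same shapes as `xc`, `yc`, `xt`, and `yt` in the training loop above, so they could be passed to `nps.loglik` in place of the random tensors.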

