Implementation of the Orthogonal Instantaneous Linear Mixing Model

OILMM

Implementation of the Orthogonal Instantaneous Linear Mixing Model

Citation:

@inproceedings{Bruinsma:2020:Scalable_Exact_Inference_in_Multi-Output,
    title = {Scalable Exact Inference in Multi-Output {Gaussian} Processes},
    year = {2020},
    author = {Wessel P. Bruinsma and Eric Perim and Will Tebbutt and J. Scott Hosking and Arno Solin and Richard E. Turner},
    booktitle = {Proceedings of the 37th International Conference on Machine Learning},
    series = {Proceedings of Machine Learning Research},
    publisher = {PMLR},
    volume = {119},
    eprint = {https://arxiv.org/abs/1911.06287},
}

Requirements and Installation

Follow the linked installation instructions, then run:

pip install oilmm

TLDR

import numpy as np
from stheno import EQ, GP

# Use TensorFlow as the backend for the OILMM.
import tensorflow as tf
from oilmm.tensorflow import OILMM


def build_latent_processes(ps):
    # Return models for latent processes, which are noise-contaminated GPs.
    return [
        (
            p.variance.positive(1) * GP(EQ().stretch(p.length_scale.positive(1))),
            p.noise.positive(1e-2),
        )
        for p, _ in zip(ps, range(3))
    ]

# Construct model.
prior = OILMM(tf.float32, build_latent_processes, num_outputs=6)

# Create some sample data.
x = np.linspace(0, 10, 100)
y = prior.sample(x)  # Sample from the prior.

# Fit the model to the data.
prior.fit(x, y, trace=True, jit=True)
prior.vs.print()  # Print all learned parameters.

# Make predictions.
posterior = prior.condition(x, y)  # Construct posterior model.
mean, var = posterior.predict(x)  # Predict with the posterior model.
lower = mean - 1.96 * np.sqrt(var)
upper = mean + 1.96 * np.sqrt(var)

Minimisation of "negative_log_marginal_likelihood":
    Iteration 1/1000:
        Time elapsed: 0.9 s
        Time left:  855.4 s
        Objective value: -0.1574
    Iteration 105/1000:
        Time elapsed: 1.0 s
        Time left:  15.5 s
        Objective value: -0.5402
    Done!
Termination message:
    CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH
latent_processes.processes[0].variance: 1.829
latent_processes.processes[0].length_scale: 1.078
latent_processes.processes[0].noise: 9.979e-03
latent_processes.processes[1].variance: 1.276
latent_processes.processes[1].length_scale: 0.9262
latent_processes.processes[1].noise: 0.03924
latent_processes.processes[2].variance: 1.497
latent_processes.processes[2].length_scale: 1.092
latent_processes.processes[2].noise: 0.04833
mixing_matrix.u:
    (6x3 array of data type float32)
    [[ 0.543 -0.237 -0.111]
     [ 0.578 -0.185 -0.357]
     [-0.204 -0.094 -0.567]
     [-0.554 -0.413 -0.081]
     [-0.12   0.571 -0.66 ]
     [-0.089 -0.636 -0.31 ]]
noise:      0.02245
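
The printed `mixing_matrix.u` can be sanity-checked directly. The sketch below copies the rounded values from the output above into NumPy and checks that the columns are orthonormal up to rounding error. The last few lines assume, following the cited paper, that the full mixing matrix has the form H = U S^(1/2) with S diagonal and positive, in which case the projection T = S^(-1/2) U^T satisfies T H = I. The values in `s` are hypothetical and do not come from the output above.

```python
import numpy as np

# Rounded values of `mixing_matrix.u`, copied from the printed output above.
u = np.array(
    [
        [0.543, -0.237, -0.111],
        [0.578, -0.185, -0.357],
        [-0.204, -0.094, -0.567],
        [-0.554, -0.413, -0.081],
        [-0.12, 0.571, -0.66],
        [-0.089, -0.636, -0.31],
    ]
)

# Columns should be orthonormal: U^T U is the 3x3 identity up to rounding.
gram = u.T @ u
print(np.round(gram, 2))

# Assumption: H = U S^(1/2) with a positive diagonal S (hypothetical values).
s = np.array([2.0, 1.0, 0.5])
h = u * np.sqrt(s)  # Scale column i of U by sqrt(s[i]).
t = (u / np.sqrt(s)).T  # T = S^(-1/2) U^T.
print(np.round(t @ h, 2))  # Approximately the identity.
```

Because T H = I, projecting the six outputs with T recovers the three latent processes plus noise, which is what allows each latent process to be handled separately.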

Basic Usage

Examples of Latent Process Models

Smooth Processes

from stheno import GP, EQ

def build_latent_processes(ps):
    return [
        (
            p.variance.positive(1) * GP(EQ().stretch(p.length_scale.positive(1))),
            p.noise.positive(1e-2),
        )
        for p, _ in zip(ps, range(3))
    ]
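
To make the role of `variance` and `length_scale` concrete, here is a minimal NumPy sketch of the covariance this latent process implies. It assumes that `EQ` is the standard exponentiated quadratic kernel and that `.stretch(l)` divides the inputs by `l`; the helper `eq_kernel` is written for illustration and is not part of Stheno or this package.

```python
import numpy as np


def eq_kernel(x1, x2, variance=1.0, length_scale=1.0):
    """variance * exp(-0.5 * (x1 - x2)^2 / length_scale^2) for 1-D inputs."""
    d = (x1[:, None] - x2[None, :]) / length_scale
    return variance * np.exp(-0.5 * d**2)


x = np.linspace(0, 10, 5)
k = eq_kernel(x, x, variance=1.0, length_scale=1.0)
print(np.round(k, 3))  # Correlation decays quickly with distance.
```

A larger `length_scale` makes distant inputs more strongly correlated, so samples vary more slowly; `variance` scales the overall amplitude.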

Smooth Processes With A Rational Quadratic Kernel

from stheno import GP, RQ

def build_latent_processes(ps):
    return [
        (
            p.variance.positive(1)
            * GP(RQ(p.alpha.positive(1e-2)).stretch(p.length_scale.positive(1))),
            p.noise.positive(1e-2),
        )
        for p, _ in zip(ps, range(3))
    ]
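
The extra parameter `alpha` controls how heavy the tails of the kernel are. The sketch below assumes the usual rational quadratic form, (1 + d^2 / (2 alpha l^2))^(-alpha), and compares it with the exponentiated quadratic kernel. Both helper functions are illustrative and are not taken from Stheno.

```python
import numpy as np


def rq_kernel(x1, x2, alpha, length_scale=1.0):
    """Rational quadratic kernel for 1-D inputs (assumed standard form)."""
    d2 = ((x1[:, None] - x2[None, :]) / length_scale) ** 2
    return (1 + 0.5 * d2 / alpha) ** (-alpha)


def eq_kernel(x1, x2, length_scale=1.0):
    """Exponentiated quadratic kernel for 1-D inputs."""
    d2 = ((x1[:, None] - x2[None, :]) / length_scale) ** 2
    return np.exp(-0.5 * d2)


x = np.linspace(0, 5, 6)
k_heavy = rq_kernel(x, x, alpha=1e-2)  # Small alpha: correlations decay slowly.
k_limit = rq_kernel(x, x, alpha=1e6)  # Large alpha: close to the EQ kernel.
k_eq = eq_kernel(x, x)
print(np.max(np.abs(k_limit - k_eq)))
```

With the initial value `alpha = 1e-2` used above, correlations at long range stay much larger than under the EQ kernel, which is one way to mix several length scales.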

Weakly Periodic Processes

from stheno import GP, EQ

def build_latent_processes(ps):
    return [
        (
            p.variance.positive(1)
            * GP(
                # Periodic component:
                EQ()
                .stretch(p.periodic.length_scale.positive(0.7))
                .periodic(p.periodic.period.positive(24))
                # Make the periodic component slowly change over time:
                * EQ().stretch(p.periodic.decay.positive(72))
            ),
            p.noise.positive(1e-2),
        )
        for p, _ in zip(ps, range(3))
    ]
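
The product structure of this kernel can be written out explicitly. The sketch below assumes that `.periodic(period)` maps an input x to (sin(2 pi x / period), cos(2 pi x / period)) before the stretched EQ kernel is applied, which gives the familiar form exp(-2 sin^2(pi d / period) / l^2). The function name and its defaults (taken from the initial values above) are for illustration only.

```python
import numpy as np


def weakly_periodic_kernel(x1, x2, length_scale=0.7, period=24.0, decay=72.0):
    """Periodic EQ kernel multiplied by a slowly decaying EQ envelope."""
    d = x1[:, None] - x2[None, :]
    # EQ on the (sin, cos) embedding: squared distance is 4 sin^2(pi d / period).
    periodic = np.exp(-2 * np.sin(np.pi * d / period) ** 2 / length_scale**2)
    # Envelope that lets the periodic pattern drift over long time spans.
    envelope = np.exp(-0.5 * (d / decay) ** 2)
    return periodic * envelope


x = np.array([0.0, 12.0, 24.0, 48.0])  # For example, time in hours.
k = weakly_periodic_kernel(x, x)
print(np.round(k[0], 4))  # Covariance with x = 0 at lags 0, 12, 24, and 48.
```

At lags that are whole multiples of the period, the periodic factor equals one and only the envelope reduces the covariance; half a period apart, the covariance is close to zero.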

Bayesian Linear Regression

from stheno import GP, Linear

num_features = 10


def build_latent_processes(ps):
    return [
        (
            GP(Linear().stretch(p.length_scales.positive(1, shape=(num_features,)))),
            p.noise.positive(1e-2),
        )
        for p, _ in zip(ps, range(3))
    ]
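
The name of this example comes from the equivalence between a GP with a linear kernel and a linear model with a Gaussian prior on its weights. The NumPy sketch below checks this numerically, assuming that `Linear` is the inner-product kernel and that `.stretch` divides each feature by its length scale. The random inputs and length scales are made up for the illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
num_features = 10
n = 4

x = rng.normal(size=(n, num_features))
length_scales = rng.uniform(0.5, 2.0, size=num_features)

# Stretched linear kernel: k(x, x') = sum_i x_i x'_i / length_scales_i^2.
x_scaled = x / length_scales
k = x_scaled @ x_scaled.T

# Same covariance from f(x) = w^T x with w ~ N(0, diag(1 / length_scales^2)).
w_cov = np.diag(1 / length_scales**2)
k_blr = x @ w_cov @ x.T

print(np.allclose(k, k_blr))
```

Under this reading, a large length scale for a feature corresponds to a tight prior around zero on that feature's weight.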

Advanced Usage

Use the OILMM Within Your Model

Kronecker-Structured Mixing Matrix

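from matrix import Kronecker
from oilmm.tensorflow import OILMM  # TensorFlow backend, as in the TLDR.

p_left, m_left = 10, 3  # Shape of left factor in Kronecker product
p_right, m_right = 5, 2  # Shape of right factor in Kronecker product


def build_mixing_matrix(ps, p, m):
    # The factor shapes are fixed above, so `p` and `m` are not used here.
    return Kronecker(
        ps.left.orthogonal(shape=(p_left, m_left)),
        ps.right.orthogonal(shape=(p_right, m_right)),
    )


# `dtype` and `build_latent_processes` are assumed to be defined as in the
# examples above, e.g. `dtype = tf.float32`. The mixing matrix here has
# m_left * m_right = 6 columns, so the builder should presumably yield six
# latent processes rather than three.
prior = OILMM(
    dtype,
    latent_processes=build_latent_processes,
    mixing_matrix=build_mixing_matrix,
    num_outputs=p_left * p_right,
)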
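
A Kronecker product of two matrices with orthonormal columns again has orthonormal columns, which is why this construction remains a valid orthogonal mixing matrix. The NumPy sketch below checks this for the shapes used above. It uses dense arrays and QR factorisations of random matrices purely for illustration, and it does not use the `matrix` package or the parametrisation behind `.orthogonal`.

```python
import numpy as np

rng = np.random.default_rng(0)
p_left, m_left = 10, 3
p_right, m_right = 5, 2


def random_orthonormal_columns(rng, p, m):
    """Return a (p, m) matrix with orthonormal columns via a reduced QR."""
    q, _ = np.linalg.qr(rng.normal(size=(p, m)))
    return q


u_left = random_orthonormal_columns(rng, p_left, m_left)
u_right = random_orthonormal_columns(rng, p_right, m_right)

# Dense Kronecker product: shape (p_left * p_right, m_left * m_right).
u = np.kron(u_left, u_right)
print(u.shape)

# (A kron B)^T (A kron B) = (A^T A) kron (B^T B), which is the identity here.
print(np.allclose(u.T @ u, np.eye(m_left * m_right)))
```

The structured version needs only 10 x 3 + 5 x 2 entries for the two factors instead of 50 x 6 for the dense matrix.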

Reproduce Experiments From the Paper

TODO: Install requirements.

Scripts to rerun individual experiments from the paper can be found in the experiments folder. A shell script is provided to rerun all experiments from the paper at once:

sh run_experiments.sh

The results can then be found in the generated _experiments folder.
