
A JAX-based Gaussian process structure learning package for time series modelling.


$\texttt{gallifrey}$: Bayesian Time Series Structure Learning with Gaussian Processes


$\texttt{gallifrey}$ is a Python package designed for Bayesian structure learning, inference, and analysis with Gaussian Process (GP) models, focused on time series data. It is a JAX-based Python implementation of the Julia package AutoGP.jl by Feras Saad.

$\texttt{gallifrey}$ utilizes JAX for efficient numerical computation and Sequential Monte Carlo (SMC) methods for robust posterior approximation. Unlike most Gaussian Process packages, where a covariance function needs to be specified explicitly, $\texttt{gallifrey}$ infers the covariance structure from the time series.
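To make "inferring the covariance structure" more concrete: structure learning in the AutoGP family searches over covariance functions built by combining simple base kernels, with sums and products among the operators. The following is a minimal sketch in plain Python, not $\texttt{gallifrey}$'s API; the names `rbf`, `periodic`, `linear`, `add`, `mul`, and `gram` are hypothetical and chosen only for illustration.

```python
import math

# Base kernels: each returns a function mapping two scalar inputs to a covariance.
def rbf(lengthscale):
    return lambda a, b: math.exp(-0.5 * ((a - b) / lengthscale) ** 2)

def periodic(period, lengthscale):
    return lambda a, b: math.exp(
        -2.0 * math.sin(math.pi * abs(a - b) / period) ** 2 / lengthscale**2
    )

def linear(offset):
    return lambda a, b: (a - offset) * (b - offset)

# Composition operators: sums and products of valid kernels are valid kernels.
def add(k1, k2):
    return lambda a, b: k1(a, b) + k2(a, b)

def mul(k1, k2):
    return lambda a, b: k1(a, b) * k2(a, b)

def gram(kernel, xs):
    """Covariance matrix of `kernel` evaluated on all pairs of inputs."""
    return [[kernel(a, b) for b in xs] for a in xs]

# One candidate structure: a periodic signal whose amplitude grows linearly,
# plus a smooth trend. A structure learner compares many such candidates.
candidate = add(mul(linear(0.0), periodic(2.0, 1.0)), rbf(5.0))
K = gram(candidate, [0.0, 0.5, 1.0, 1.5])
```

With a conventional GP package, the user would have to write down an expression like `candidate` by hand; here the expression itself is treated as an unknown to be inferred from the data.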

$\texttt{gallifrey}$ was created with exoplanet transit light curves in mind, but is applicable to a wide variety of time series modelling, analysis, and forecasting tasks.

Core Functionality

  • Gaussian Process (GP) Modeling: Implements Gaussian Processes, leveraging JAX for efficient computation, with a particular focus on accurate uncertainty estimation.

  • Bayesian Structure Learning: Provides a probabilistic framework for identifying latent structure within time series data by dynamically learning the covariance structure of the Gaussian Process.

  • Sequential Monte Carlo (SMC): Employs SMC for robust and fast posterior approximations.
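As a rough illustration of the SMC idea, the sketch below shows the generic reweight-and-resample step that particle methods share: weights are normalized, the effective sample size (ESS) is checked, and particles are resampled when the weights have degenerated. It is plain Python with made-up numbers and hypothetical function names, and it is not a description of how $\texttt{gallifrey}$ implements this internally.

```python
import math
import random

def normalize_log_weights(log_weights):
    """Turn unnormalized log-weights into probabilities (shifted by the max for stability)."""
    top = max(log_weights)
    unnormalized = [math.exp(lw - top) for lw in log_weights]
    total = sum(unnormalized)
    return [u / total for u in unnormalized]

def effective_sample_size(weights):
    """ESS = 1 / sum(w_i^2); equals the particle count when weights are uniform."""
    return 1.0 / sum(w * w for w in weights)

def resample(particles, weights, rng):
    """Multinomial resampling: draw with replacement, proportional to weight."""
    return rng.choices(particles, weights=weights, k=len(particles))

rng = random.Random(0)
# Stand-ins for candidate GP structures; the log-weights are illustrative only.
particles = ["linear", "periodic", "linear*periodic", "rbf"]
log_weights = [-12.0, -9.5, -3.0, -11.0]

weights = normalize_log_weights(log_weights)
# Resample when fewer than half of the particles carry meaningful weight.
if effective_sample_size(weights) < 0.5 * len(particles):
    particles = resample(particles, weights, rng)
    weights = [1.0 / len(particles)] * len(particles)
```

In a full SMC sampler this step alternates with MCMC moves that rejuvenate the resampled particles, which is the role of the `n_mcmc` and `n_hmc` arguments in the Quick Start.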

Installation

$\texttt{gallifrey}$ requires Python 3.10 or later.

Option 1: Using pip (Recommended)

pip install gallifrey

Option 2: From source

git clone git@github.com:ChrisBoettner/gallifrey.git
cd gallifrey
pip install .

For development (editable) installation:

pip install -e .

Dependencies

$\texttt{gallifrey}$'s core functionality relies on the following packages:

  • blackjax (>=1.2.5,<2.0.0)
  • jax (>=0.5.0,<0.6.0)
  • flax (>=0.10.3,<0.11.0)
  • equinox (>=0.11.11,<0.12.0)
  • beartype (>=0.19.0,<0.20.0)
  • tensorflow-probability (>=0.25.0,<0.26.0)
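To check whether an existing environment already satisfies these pins, a small helper along the following lines may be useful. It is a sketch using only the standard library; `PINS`, `parse`, `satisfies`, and `check_installed` are illustrative names, not part of $\texttt{gallifrey}$, and the version parsing is deliberately simplified (pre-release suffixes are ignored; the `packaging` library handles them properly).

```python
from importlib import metadata
from itertools import takewhile

# Bounds copied from the dependency list above: (minimum, exclusive maximum).
PINS = {
    "blackjax": ("1.2.5", "2.0.0"),
    "jax": ("0.5.0", "0.6.0"),
    "flax": ("0.10.3", "0.11.0"),
    "equinox": ("0.11.11", "0.12.0"),
    "beartype": ("0.19.0", "0.20.0"),
    "tensorflow-probability": ("0.25.0", "0.26.0"),
}

def parse(version):
    """Turn '0.10.3rc1' into (0, 10, 3), keeping only leading digits of each part."""
    parts = []
    for piece in version.split(".")[:3]:
        digits = "".join(takewhile(str.isdigit, piece))
        if not digits:
            break
        parts.append(int(digits))
    while len(parts) < 3:  # pad '1.2' to (1, 2, 0)
        parts.append(0)
    return tuple(parts)

def satisfies(version, lower, upper):
    """True if lower <= version < upper, compared as integer tuples."""
    return parse(lower) <= parse(version) < parse(upper)

def check_installed(pins=PINS):
    """Map each package to True/False, or None if it is not installed."""
    report = {}
    for name, (lower, upper) in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = None
            continue
        report[name] = satisfies(installed, lower, upper)
    return report
```

Calling `check_installed()` returns a dictionary that can be printed or inspected before installing; `pip` performs the authoritative resolution in any case.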

Quick Start

This example demonstrates a basic workflow, from data generation to model fitting and prediction.

# Expose every CPU core to JAX as a separate device (must be set before JAX is imported)
import multiprocessing
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(
    multiprocessing.cpu_count()
)

# Import necessary packages
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt  # For plotting

# Import core components from gallifrey
from gallifrey.model import GPConfig, GPModel
from gallifrey.schedule import LinearSchedule

# Example Data Generation
rng_key = jr.PRNGKey(0)
key, data_key = jr.split(rng_key)
n = 120
noise_var = 9.0
x = jnp.linspace(0, 15, n)
y = (x + 0.01) * jnp.sin(x * 3.2) + jnp.sqrt(noise_var) * jr.normal(data_key, (n,))

# Train on x < 10; the remaining points are held out and only shown in the final plot
xtrain = x[(x < 10)]
ytrain = y[(x < 10)]

# Model Initialization
config = GPConfig()  # Use default configuration (can be customized)
key, model_key = jr.split(key)
gpmodel = GPModel(
    model_key,
    x=xtrain,
    y=ytrain,
    num_particles=8,  # Number of particles for SMC
    config=config,
)

# Model Fitting (SMC)
key, smc_key = jr.split(key)
# Generate an annealing schedule (important for SMC)
annealing_schedule = LinearSchedule().generate(len(xtrain), 10)

final_smc_state, history = gpmodel.fit_smc(
    smc_key,
    annealing_schedule=annealing_schedule,
    n_mcmc=50,    # Number of MCMC steps per SMC iteration
    n_hmc=10,     # Number of HMC steps within each MCMC step
    verbosity=1,  # Control verbosity
)

# Update the model with the final SMC state
gpmodel = gpmodel.update_state(final_smc_state)

# Prediction
xtest = gpmodel.x_transform(jnp.linspace(0, 18, 60))  # Prediction grid, mapped to the model's input scale
dist = gpmodel.get_mixture_distribution(xtest)  # Get the predictive distribution

predictive_mean = dist.mean()
predictive_std = dist.stddev()

# Visualization
plt.figure(figsize=(12, 6))
plt.plot(xtest, predictive_mean, label="Predictive Mean", color="C0")
plt.fill_between(
    xtest,
    predictive_mean - predictive_std,
    predictive_mean + predictive_std,
    alpha=0.3,
    label="Predictive Std. Dev.",
    color="C0"
)
plt.scatter(gpmodel.x_transformed, gpmodel.y_transformed, label="Training Data", color="C1", s=20)
plt.scatter(gpmodel.x_transform(x), gpmodel.y_transform(y), label="All Data", color="C2", s=10, alpha=0.5)
plt.legend()  # Show the labels passed to plot/fill_between/scatter
plt.show()
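The `dist.mean()` and `dist.stddev()` calls above summarize a mixture distribution. Assuming, as the name `get_mixture_distribution` suggests, that each particle contributes one Gaussian component with its own weight, the mixture moments at a single test point follow from the law of total variance. The sketch below is plain Python with illustrative numbers; `mixture_mean_std` is a hypothetical helper, not a $\texttt{gallifrey}$ function.

```python
import math

def mixture_mean_std(weights, means, stds):
    """Mean and standard deviation of a weighted mixture of 1-D Gaussians."""
    total = sum(weights)
    w = [wi / total for wi in weights]  # normalize defensively
    mean = sum(wi * mi for wi, mi in zip(w, means))
    # Law of total variance: average within-component variance
    # plus the spread of the component means around the mixture mean.
    variance = sum(
        wi * (si**2 + (mi - mean) ** 2) for wi, mi, si in zip(w, means, stds)
    )
    return mean, math.sqrt(variance)

# Two equally weighted components that disagree about the mean.
mean, std = mixture_mean_std([0.5, 0.5], [1.0, 3.0], [1.0, 1.0])
```

Here the mixture standard deviation is larger than either component's, because disagreement between components adds to the predictive spread. Under the stated assumption, this is how uncertainty about the covariance structure would widen the shaded band in the plot.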

Documentation and further examples

More detailed examples can be found in the notebooks/ directory and the documentation.

Contributing

We welcome bug reports, feature requests, and pull requests.

Citation

If you use $\texttt{gallifrey}$ in your research, please cite it as:

@article{https://doi.org/10.1051/0004-6361/202554518,
  doi = {10.1051/0004-6361/202554518},
  author = {Boettner, Christopher},
  title = {gallifrey: JAX-based Gaussian Process Structure Learning for Astronomical Time Series},
  year = {2025},
  journal = {A\&A},
  publisher = {EDP Sciences},
  issn = {0004-6361, 1432-0746},
  eprint = {2505.20394},
  archiveprefix = {arXiv},
  primaryclass = {astro-ph},
  keywords = {Astrophysics - Earth and Planetary Astrophysics,Astrophysics - Instrumentation and Methods for Astrophysics},
  copyright = {{\copyright} 2025, ESO},
}

Please also cite the original paper by Saad et al.:

@article{https://doi.org/10.48550/arxiv.2307.09607,
  doi = {10.48550/ARXIV.2307.09607},
  url = {https://arxiv.org/abs/2307.09607},
  author = {Saad,  Feras A. and Patton,  Brian J. and Hoffman,  Matthew D. and Saurous,  Rif A. and Mansinghka,  Vikash K.},
  keywords = {Machine Learning (cs.LG),  Artificial Intelligence (cs.AI),  Methodology (stat.ME),  Machine Learning (stat.ML),  FOS: Computer and information sciences,  FOS: Computer and information sciences},
  title = {Sequential Monte Carlo Learning for Time Series Structure Discovery},
  publisher = {arXiv},
  year = {2023},
  copyright = {arXiv.org perpetual,  non-exclusive license}
}

Acknowledgements

This package is a direct re-implementation of AutoGP.jl and would not be possible without it. The Gaussian Process implementation is strongly inspired by the fantastic packages GPJax and tinygp.
