Ramsey
Probabilistic deep learning using JAX
About
Ramsey is a library for probabilistic modelling using JAX, Flax and NumPyro. It offers high-quality implementations of neural processes, Gaussian processes, Bayesian time series and state-space models, clustering processes, and everything else Bayesian.
Ramsey makes use of
- Flax's module system for models with trainable parameters (such as neural or Gaussian processes),
- NumPyro for models whose parameters are endowed with prior distributions (such as Gaussian processes, Bayesian neural networks, or ARMA models),
and is hence aimed at being fully compatible with both of them.
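The NumPyro side of this design means that a model's parameters are themselves random variables. As a rough illustration, here is a minimal, generic NumPyro sketch of that idea; it is not Ramsey-specific code, and the model and variable names are made up for the example:

import numpyro
import numpyro.distributions as dist
from jax import random as jr
from numpyro.infer import MCMC, NUTS

def model(y):
    # Parameters carry prior distributions instead of being plain trainable arrays.
    loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
    scale = numpyro.sample("scale", dist.HalfNormal(1.0))
    numpyro.sample("y", dist.Normal(loc, scale), obs=y)

# Infer the posterior over loc and scale with NUTS.
mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000)
mcmc.run(jr.PRNGKey(0), y=jr.normal(jr.PRNGKey(1), (100,)))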
Example usage
You can, for instance, construct a simple neural process like this:
from jax import random as jr
from ramsey import NP
from ramsey.nn import MLP
from ramsey.data import sample_from_sine_function
def get_neural_process():
    dim = 128
    np = NP(
        decoder=MLP([dim] * 3 + [2]),
        latent_encoder=(
            MLP([dim] * 3), MLP([dim, dim * 2])
        )
    )
    return np
key = jr.PRNGKey(23)
data = sample_from_sine_function(key)
neural_process = get_neural_process()
params = neural_process.init(key, x_context=data.x, y_context=data.y, x_target=data.x)
The neural process takes a decoder and a set of two latent encoders as arguments. All of these are typically MLPs, but Ramsey is flexible enough that you can change them, for instance, to CNNs or RNNs. Once the model is defined, you can initialize its parameters just like in Flax.
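Once initialized, a forward pass goes through Flax's apply mechanism. The call below is only a sketch: it assumes that the model's __call__ accepts the same keyword arguments as init and that latent sampling reads from an RNG collection named "sample", neither of which is confirmed here; see Ramsey's documentation and examples for the actual training and prediction utilities.

apply_key = jr.PRNGKey(42)
# Hypothetical forward pass; the keyword arguments and the "sample" RNG
# collection name are assumptions, not confirmed Ramsey API.
y_star = neural_process.apply(
    params,
    x_context=data.x,
    y_context=data.y,
    x_target=data.x,
    rngs={"sample": apply_key},
)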
Installation
To install from PyPI, call:
pip install ramsey
To install the latest GitHub release, just call the following on the command line:
pip install git+https://github.com/ramsey-devs/ramsey@<RELEASE>
See also the installation instructions for JAX, if you plan to use Ramsey on GPU/TPU.
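For example, on a machine with CUDA 12, a GPU-enabled JAX build can typically be installed with a command along these lines (the exact extra depends on your CUDA version and JAX release, so check the JAX installation guide first):

pip install -U "jax[cuda12]"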
Contributing
Contributions in the form of pull requests are more than welcome. A good way to start is to check out issues labelled "good first issue".
In order to contribute:
- install Ramsey and its dev dependencies via pip install -e '.[dev]',
- test your contribution/implementation by calling tox on the (Unix) command line before submitting a PR.
Why Ramsey
Just as the names of other probabilistic languages are inspired by researchers in the field (e.g., Stan, Edward, Turing), Ramsey takes its name from one of my favourite philosophers/mathematicians, Frank Ramsey.
File details
Details for the file ramsey-0.2.1.tar.gz.
File metadata
- Download URL: ramsey-0.2.1.tar.gz
- Upload date:
- Size: 31.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.6
File hashes
Algorithm | Hash digest
---|---
SHA256 | 425970e1d2f4dcd794bd7b062f96109bf59076b210426ca2d1c62ee372535f09
MD5 | ec3394d223a4f1aa6ff2e4658e0d8a53
BLAKE2b-256 | d6ae5af0d04a5056414d740af55c43062818c6e3bcf296f46af13d26f05ab7e7
File details
Details for the file ramsey-0.2.1-py3-none-any.whl.
File metadata
- Download URL: ramsey-0.2.1-py3-none-any.whl
- Upload date:
- Size: 43.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.6
File hashes
Algorithm | Hash digest
---|---
SHA256 | 8239eb87cea9f3f218fcdb9507256c66cce184082025188fc1ef94ca3d2c4d82
MD5 | 20815278656316ab9f151538598fc202
BLAKE2b-256 | 55bae98c1a2e591055eae554baa72df0b43d54b53801586034d3a43bbf0b9fe8