Probabilistic deep learning using JAX
Ramsey
About
Ramsey is a library for probabilistic deep learning using JAX, Flax and NumPyro. Its scope covers
- neural processes (vanilla, attentive, Markovian, convolutional, ...),
- neural Laplace and Fourier operator models,
- etc.
Example usage
You can, for instance, construct a simple neural process like this:
from flax import nnx
from ramsey import NP
from ramsey.nn import MLP  # just a flax.nnx module

def get_neural_process(in_features, out_features):
    dim = 128
    np = NP(
        latent_encoder=(
            MLP(in_features, [dim, dim], rngs=nnx.Rngs(0)),
            MLP(dim, [dim, dim * 2], rngs=nnx.Rngs(1)),
        ),
        decoder=MLP(in_features, [dim, dim, out_features * 2], rngs=nnx.Rngs(2)),
    )
    return np

neural_process = get_neural_process(1, 1)
The neural process above takes a decoder and a set of two latent encoders as arguments. All of these are typically flax.nnx MLPs, but
Ramsey is flexible enough that you can change them, for instance, to CNNs or RNNs.
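The decoder in the example has a final layer of width out_features * 2. One common reading of such an output, assumed here and not confirmed by this README, is that the two halves parameterise the mean and the scale of a Gaussian over the targets. The sketch below illustrates that reading in plain JAX; the helper names split_mean_and_scale and gaussian_log_likelihood and the min_scale offset are hypothetical and not part of Ramsey.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def split_mean_and_scale(decoder_output, min_scale=0.1):
    """Split a (..., 2 * out_features) array into a Gaussian mean and scale."""
    # Assumption of this sketch: the first half of the last axis is the mean,
    # the second half an unconstrained scale.
    mean, raw_scale = jnp.split(decoder_output, 2, axis=-1)
    # softplus keeps the scale positive; min_scale bounds it away from zero.
    scale = min_scale + jax.nn.softplus(raw_scale)
    return mean, scale


def gaussian_log_likelihood(y, mean, scale):
    """Sum independent Gaussian log-densities over points and features."""
    log_prob = norm.logpdf(y, loc=mean, scale=scale)
    return log_prob.sum(axis=(1, 2))


# Dummy decoder output: 4 functions, 10 target points, out_features = 1.
decoder_output = jnp.zeros((4, 10, 2))
mean, scale = split_mean_and_scale(decoder_output)
y = jnp.zeros((4, 10, 1))
log_lik = gaussian_log_likelihood(y, mean, scale)  # one value per function
```

If a model follows this convention, the per-function log-likelihood is the kind of quantity a training loss would build on.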
Ramsey provides a unified interface where each method implements (at least) __call__ and loss
functions to transform a set of inputs and compute a training loss, respectively:
from jax import random as jr
from ramsey.data import sample_from_sine_function

data = sample_from_sine_function(jr.key(0))
x_context, y_context = data.x[:, :20, :], data.y[:, :20, :]
x_target, y_target = data.x, data.y

# make a prediction
pred = neural_process(
    x_context=x_context,
    y_context=y_context,
    x_target=x_target,
)

# compute the loss
loss = neural_process.loss(
    x_context=x_context,
    y_context=y_context,
    x_target=x_target,
    y_target=y_target,
)
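The example above takes the first 20 points of each function as context. During training it is common to draw the context points at random instead. The following is a minimal sketch of such a split in plain JAX, assuming arrays laid out as (batch, points, features) as the slicing above suggests. The helper random_context_split and the stand-in data shapes are hypothetical and not part of Ramsey.

```python
from jax import random as jr


def random_context_split(key, x, y, n_context):
    """Use n_context random points (axis 1) as context and all points as targets."""
    n_points = x.shape[1]
    # Sample indices without replacement so no point appears twice in the context.
    idxs = jr.choice(key, n_points, shape=(n_context,), replace=False)
    return x[:, idxs, :], y[:, idxs, :], x, y


# Stand-in data with an assumed (batch, points, features) layout.
x = jr.normal(jr.key(1), (10, 200, 1))
y = jr.normal(jr.key(2), (10, 200, 1))
x_context, y_context, x_target, y_target = random_context_split(
    jr.key(3), x, y, n_context=20
)
```

The four returned arrays can be passed to the model and its loss under the same keyword names used above.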
Installation
To install from PyPI, call:
pip install ramsey
To install the latest GitHub release, call the following on the command line, replacing <RELEASE> with the release tag:
pip install git+https://github.com/ramsey-devs/ramsey@<RELEASE>
See also the installation instructions for JAX, if you plan to use Ramsey on GPU/TPU.
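After installing JAX for an accelerator, a quick way to check what it picked up is to query the default backend and the visible devices. This is a small sketch using JAX's own jax.default_backend and jax.devices functions; nothing in it is specific to Ramsey.

```python
import jax

# Name of the backend JAX selected by default, for example "cpu" or "gpu".
backend = jax.default_backend()
# All devices visible to that backend.
devices = jax.devices()
print(f"backend: {backend}, devices: {devices}")
```

If this reports only a CPU on a machine with a GPU or TPU, the accelerator build of JAX is most likely not installed.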
Contributing
Contributions in the form of pull requests are more than welcome. A good way to start is to check out issues labelled "good first issue".
In order to contribute:
- clone Ramsey and install uv,
- create a new branch locally via git checkout -b feature/my-new-feature or git checkout -b issue/fixes-bug,
- install all dependencies via uv sync --all-extras,
- implement your contribution and ideally a test case,
- test it by calling make format, make lints and make tests on the (Unix) command line,
- submit a PR 🙂
Why Ramsey
Just as the names of other probabilistic languages are inspired by researchers in the field (e.g., Stan, Edward, Turing), Ramsey takes its name from one of my favourite philosophers/mathematicians, Frank Ramsey.