
Probabilistic deep learning using JAX


Ramsey


About

Ramsey is a library for probabilistic deep learning using JAX, Flax and NumPyro. Its scope covers:

  • neural processes (vanilla, attentive, Markovian, convolutional, ...),
  • neural Laplace and Fourier operator models,
  • etc.

Example usage

You can, for instance, construct a simple neural process like this:

from flax import nnx

from ramsey import NP
from ramsey.nn import MLP  # just a flax.nnx module

def get_neural_process(in_features, out_features):
  dim = 128
  np = NP(
    latent_encoder=(
      MLP(in_features, [dim, dim], rngs=nnx.Rngs(0)),
      MLP(dim, [dim, dim * 2], rngs=nnx.Rngs(1))
    ),
    decoder=MLP(in_features, [dim, dim, out_features * 2], rngs=nnx.Rngs(2))
  )
  return np

neural_process = get_neural_process(1, 1)

The neural process above takes a decoder and a latent encoder as arguments. The latent encoder is itself a pair of modules. All of these are typically flax.nnx MLPs, but Ramsey is flexible enough that you can replace them with, for instance, CNNs or RNNs.
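
The sketch below shows, in plain JAX, what the two latent-encoder modules and the decoder typically do in a latent-variable neural process. It follows the usual recipe from the neural process literature: encode each context pair, mean-pool, sample a latent z, and decode it together with each target input. It is not Ramsey's implementation. The helpers init_mlp, mlp and latent_np_forward, the layer sizes, and the scale parameterisation are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    """Random (weight, bias) pairs for a stack of dense layers (hypothetical helper)."""
    keys = jax.random.split(key, len(sizes) - 1)
    return [
        (jax.random.normal(k, (n_in, n_out)) / jnp.sqrt(n_in), jnp.zeros((n_out,)))
        for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:])
    ]


def mlp(params, x):
    """Dense layers with ReLU in between and a linear final layer."""
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b


def latent_np_forward(params, key, x_context, y_context, x_target):
    """Forward pass of a latent-variable neural process (illustrative only)."""
    # 1. Encode each (x, y) context pair independently: (batch, n_context, dim).
    h = mlp(params["enc1"], jnp.concatenate([x_context, y_context], axis=-1))
    # 2. Mean-pool over context points to get an order-invariant summary (batch, 1, dim).
    r = jnp.mean(h, axis=1, keepdims=True)
    # 3. The second encoder outputs 2 * dim_z values: mean and raw scale of q(z | context).
    z_mean, z_raw = jnp.split(mlp(params["enc2"], r), 2, axis=-1)
    z_scale = 0.1 + 0.9 * jax.nn.sigmoid(z_raw)
    # 4. Sample z with the reparameterisation trick and copy it to every target point.
    z = z_mean + z_scale * jax.random.normal(key, z_mean.shape)
    z = jnp.broadcast_to(z, x_target.shape[:2] + z.shape[-1:])
    # 5. The decoder outputs 2 * dim_y values per target: predictive mean and raw scale.
    out = mlp(params["dec"], jnp.concatenate([z, x_target], axis=-1))
    y_mean, y_raw = jnp.split(out, 2, axis=-1)
    y_scale = 0.1 + 0.9 * jax.nn.softplus(y_raw)
    return y_mean, y_scale


# Toy usage: one sine curve, first 20 points as context, all 50 points as targets.
k_enc1, k_enc2, k_dec, k_z = jax.random.split(jax.random.key(0), 4)
dim = 16
params = {
    "enc1": init_mlp(k_enc1, [2, dim, dim]),        # input: concat(x, y)
    "enc2": init_mlp(k_enc2, [dim, dim, 2 * dim]),  # output: mean and scale of z
    "dec": init_mlp(k_dec, [dim + 1, dim, 2]),      # input: concat(z, x)
}
x = jnp.linspace(-1.0, 1.0, 50).reshape(1, 50, 1)
y = jnp.sin(3.0 * x)
mean, scale = latent_np_forward(params, k_z, x[:, :20], y[:, :20], x)
print(mean.shape, scale.shape)  # (1, 50, 1) (1, 50, 1)
```

The mean pooling in step 2 makes the latent summary independent of the order of the context points. The "* 2" output sizes carry a mean and a scale for a Gaussian.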

Ramsey provides a unified interface in which each model implements (at least) the __call__ and loss methods, which transform a set of inputs and compute a training loss, respectively:

from jax import random as jr
from ramsey.data import sample_from_sine_function

data = sample_from_sine_function(jr.key(0))
# use the first 20 points of each sampled function as context, all points as targets
x_context, y_context = data.x[:, :20, :], data.y[:, :20, :]
x_target, y_target = data.x, data.y

# make a prediction
pred = neural_process(
  x_context=x_context,
  y_context=y_context,
  x_target=x_target,
)

# compute the loss
loss = neural_process.loss(
  x_context=x_context,
  y_context=y_context,
  x_target=x_target,
  y_target=y_target
)
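
For intuition about what a loss method for a latent neural process usually computes, here is the standard objective from the neural process literature, written in plain JAX. It is a negative evidence lower bound (ELBO): the Gaussian log-likelihood of the targets, minus a KL term between a latent distribution conditioned on context and targets and one conditioned on the context alone. This is a generic sketch. Ramsey's own loss may estimate or weight these terms differently, and every function name here is hypothetical.

```python
import jax.numpy as jnp


def gaussian_log_prob(y, mean, scale):
    """Elementwise log-density of a diagonal Gaussian."""
    return -0.5 * (jnp.log(2.0 * jnp.pi) + 2.0 * jnp.log(scale) + ((y - mean) / scale) ** 2)


def kl_diag_gaussians(mean_q, scale_q, mean_p, scale_p):
    """KL(q || p) between diagonal Gaussians, summed over the last axis."""
    kl = (
        jnp.log(scale_p / scale_q)
        + (scale_q**2 + (mean_q - mean_p) ** 2) / (2.0 * scale_p**2)
        - 0.5
    )
    return jnp.sum(kl, axis=-1)


def np_negative_elbo(y_target, y_mean, y_scale, posterior, prior):
    """Negative ELBO averaged over the batch (illustrative sketch).

    y_*:       arrays of shape (batch, n_target, dim_y)
    posterior: (mean, scale) of q(z | context and targets), shape (batch, 1, dim_z)
    prior:     (mean, scale) of q(z | context only), same shape
    """
    log_lik = jnp.sum(gaussian_log_prob(y_target, y_mean, y_scale), axis=(1, 2))
    kl = jnp.sum(kl_diag_gaussians(*posterior, *prior), axis=1)
    return -jnp.mean(log_lik - kl)


# Sanity check: identical posterior and prior give zero KL, so the loss reduces
# to the negative Gaussian log-likelihood of 5 perfectly predicted points.
y = jnp.zeros((2, 5, 1))
z_stats = (jnp.zeros((2, 1, 4)), jnp.ones((2, 1, 4)))
loss = np_negative_elbo(y, y, jnp.ones_like(y), z_stats, z_stats)
print(loss)  # 5 * 0.5 * log(2 * pi), about 4.5947
```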

Installation

To install from PyPI, call:

pip install ramsey

To install the latest GitHub <RELEASE>, just call the following on the command line:

pip install git+https://github.com/ramsey-devs/ramsey@<RELEASE>

See also the installation instructions for JAX, if you plan to use Ramsey on GPU/TPU.
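
After installing JAX with accelerator support, you can quickly check whether JAX, and therefore Ramsey, will actually run on a GPU or TPU by asking JAX which backend and devices it sees. This sanity check uses standard JAX calls and is not part of Ramsey. The describe_accelerators helper is hypothetical.

```python
import jax


def describe_accelerators():
    """Summarise the default backend and the devices JAX can see (hypothetical helper)."""
    return {
        "backend": jax.default_backend(),  # e.g. "cpu", "gpu" or "tpu"
        "devices": [str(d) for d in jax.devices()],
    }


print(describe_accelerators())
```

If the backend reports only "cpu" on a machine with a GPU, the accelerator-enabled JAX build is most likely not installed.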

Contributing

Contributions in the form of pull requests are more than welcome. A good way to start is to check out issues labelled "good first issue".

In order to contribute:

  1. Clone Ramsey and install uv (see the uv installation instructions),
  2. create a new branch locally, e.g., git checkout -b feature/my-new-feature or git checkout -b issue/fixes-bug,
  3. install all dependencies via uv sync --all-extras,
  4. implement your contribution and ideally a test case,
  5. test it by calling make format, make lints and make tests on the (Unix) command line,
  6. submit a PR 🙂

Why Ramsey

Just as the names of other probabilistic languages are inspired by researchers in the field (e.g., Stan, Edward, Turing), Ramsey takes its name from one of my favourite philosophers/mathematicians, Frank Ramsey.

Download files

Source Distribution

ramsey-0.3.0.tar.gz (509.6 kB)

Built Distribution

ramsey-0.3.0-py3-none-any.whl (27.1 kB, Python 3)

File details

Details for the file ramsey-0.3.0.tar.gz.

File metadata

  • Download URL: ramsey-0.3.0.tar.gz
  • Size: 509.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.8

File hashes

Hashes for ramsey-0.3.0.tar.gz:

  • SHA256: c5bb024c60498ec397fa960fd73f9afa3f4d34c5030087fd1f4c71b343dedfc8
  • MD5: 6fc05472d71efc407bcf107da29a6a34
  • BLAKE2b-256: 7071d697d75e83f2c9fd506c2f3c812e765d0db43cd5798240bf3e8534e74862
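
To check a downloaded file against the published SHA256 digest, you can hash it locally with Python's standard hashlib module. This is a minimal sketch that assumes the source distribution was downloaded into the current directory. The sha256_of helper is hypothetical.

```python
import hashlib
from pathlib import Path


def sha256_of(path, chunk_size=1 << 20):
    """Hex SHA256 digest of a file, read in 1 MiB chunks to bound memory use."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# Published SHA256 digest of ramsey-0.3.0.tar.gz.
EXPECTED_SDIST_SHA256 = "c5bb024c60498ec397fa960fd73f9afa3f4d34c5030087fd1f4c71b343dedfc8"

sdist = Path("ramsey-0.3.0.tar.gz")  # assumed download location
if sdist.exists():
    print("OK" if sha256_of(sdist) == EXPECTED_SDIST_SHA256 else "MISMATCH")
```

pip can also enforce digests automatically in its hash-checking mode (--require-hashes).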

File details

Details for the file ramsey-0.3.0-py3-none-any.whl.

File metadata

  • Download URL: ramsey-0.3.0-py3-none-any.whl
  • Size: 27.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.8

File hashes

Hashes for ramsey-0.3.0-py3-none-any.whl:

  • SHA256: d875813c5339be37f010907d822b6b6439012e0ddfee342353edad60d8d436ac
  • MD5: cd45fc6779909a124e110f4197c1f322
  • BLAKE2b-256: 03f3f58d4e5a0935a37249d5371b9a09c96a92b1f63d823f9e809696ed65cb91
