Utility functions for JaxGaussianProcesses

This project has now been incorporated into GPJax.

JaxUtils

JaxUtils provides utility functions for the JaxGaussianProcesses ecosystem.

Contents

PyTree

Overview

jaxutils.PyTree is a mixin class that registers a Python class as a JAX PyTree. Define your Python class as follows.

class MyClass(jaxutils.PyTree):
    ...

Example

import jaxutils

from jaxtyping import Float, Array

class Line(jaxutils.PyTree):
    def __init__(self, gradient: Float[Array, "1"], intercept: Float[Array, "1"]) -> None:
        self.gradient = gradient
        self.intercept = intercept

    def y(self, x: Float[Array, "N"]) -> Float[Array, "N"]:
        return x * self.gradient + self.intercept
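
To illustrate what registering a class as a PyTree buys you, here is a minimal sketch of how such a mixin could be written with jax.tree_util.register_pytree_node. The PyTreeSketch class and its _flatten and _unflatten helpers are hypothetical names invented for this example; this is not the jaxutils implementation, only one plausible approach under the assumption that every instance attribute is a leaf.

```python
import jax
import jax.numpy as jnp


class PyTreeSketch:
    """Hypothetical mixin (not jaxutils.PyTree): registers each subclass as a JAX PyTree."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        jax.tree_util.register_pytree_node(
            cls,
            lambda obj: obj._flatten(),
            lambda aux, children: cls._unflatten(aux, children),
        )

    def _flatten(self):
        # Attribute values are the leaves; the sorted attribute names are static aux data.
        names = tuple(sorted(vars(self)))
        children = tuple(getattr(self, name) for name in names)
        return children, names

    @classmethod
    def _unflatten(cls, names, children):
        # Bypass __init__ so that arbitrary leaves (e.g. tracers under jit) can be restored.
        obj = object.__new__(cls)
        for name, child in zip(names, children):
            setattr(obj, name, child)
        return obj


class Line(PyTreeSketch):
    def __init__(self, gradient, intercept):
        self.gradient = gradient
        self.intercept = intercept

    def y(self, x):
        return x * self.gradient + self.intercept


line = Line(gradient=jnp.array([2.0]), intercept=jnp.array([1.0]))

# Once registered, the object can be flattened, mapped over, and passed through jit.
leaves = jax.tree_util.tree_leaves(line)
doubled = jax.tree_util.tree_map(lambda leaf: 2.0 * leaf, line)
predict = jax.jit(lambda model, x: model.y(x))
y_hat = predict(line, jnp.array([0.0, 1.0, 2.0]))
```

The practical consequence is that a Line instance can be passed directly as an argument to transformed functions such as jax.jit or jax.grad, rather than unpacking its parameters by hand.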

Dataset

Overview

jaxutils.Dataset is a dataset abstraction. In the future, we wish to extend this to heterotopic and isotopic data abstractions.

Example

import jaxutils
import jax.numpy as jnp

# Inputs
X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

# Outputs
y = jnp.array([[7.0], [8.0], [9.0]])

# Dataset
D = jaxutils.Dataset(X=X, y=y)

print(f'The number of datapoints is {D.n}')
print(f'The input dimension is {D.in_dim}')
print(f'The output dimension is {D.out_dim}')
print(f'The input data is {D.X}')
print(f'The output data is {D.y}')
print(f'The data is supervised {D.is_supervised()}')
print(f'The data is unsupervised {D.is_unsupervised()}')

The number of datapoints is 3
The input dimension is 2
The output dimension is 1
The input data is [[1. 2.]
 [3. 4.]
 [5. 6.]]
The output data is [[7.]
 [8.]
 [9.]]
The data is supervised True
The data is unsupervised False
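
The printed sizes above are consistent with reading them straight off the array shapes. The short sketch below shows that relationship using plain jax.numpy arrays; the mapping of n, in_dim, and out_dim to particular axes is an assumption inferred from the output, not taken from the jaxutils source.

```python
import jax.numpy as jnp

X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([[7.0], [8.0], [9.0]])

# Assumed convention: rows index datapoints, columns index dimensions.
n = X.shape[0]        # number of datapoints
in_dim = X.shape[1]   # input dimension
out_dim = y.shape[1]  # output dimension

# Inputs and outputs must describe the same number of datapoints.
assert y.shape[0] == n
```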

You can also add datasets together to concatenate them.

# New inputs
X_new = jnp.array([[1.5, 2.5], [3.5, 4.5], [5.5, 6.5]])

# New outputs
y_new = jnp.array([[7.0], [8.0], [9.0]])

# New dataset
D_new = jaxutils.Dataset(X=X_new, y=y_new)

# Concatenate the two datasets
D = D + D_new

print(f'The number of datapoints is {D.n}')
print(f'The input dimension is {D.in_dim}')
print(f'The output dimension is {D.out_dim}')
print(f'The input data is {D.X}')
print(f'The output data is {D.y}')
print(f'The data is supervised {D.is_supervised()}')
print(f'The data is unsupervised {D.is_unsupervised()}')

The number of datapoints is 6
The input dimension is 2
The output dimension is 1
The input data is [[1.  2. ]
 [3.  4. ]
 [5.  6. ]
 [1.5 2.5]
 [3.5 4.5]
 [5.5 6.5]]
The output data is [[7.]
 [8.]
 [9.]
 [7.]
 [8.]
 [9.]]
The data is supervised True
The data is unsupervised False
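
The output shows that the new rows are appended after the existing ones. A minimal sketch of how that behaviour could be implemented is to overload + with a row-wise jnp.concatenate. DatasetSketch is a hypothetical stand-in written for this example; it is not jaxutils.Dataset and omits any validation the real class may perform.

```python
from dataclasses import dataclass

import jax.numpy as jnp


@dataclass
class DatasetSketch:
    """Hypothetical stand-in for a dataset holding inputs X and outputs y."""

    X: jnp.ndarray
    y: jnp.ndarray

    @property
    def n(self) -> int:
        # Number of datapoints, taken from the leading axis of X.
        return self.X.shape[0]

    def __add__(self, other: "DatasetSketch") -> "DatasetSketch":
        # Stack rows of both datasets; column counts must already match.
        return DatasetSketch(
            X=jnp.concatenate([self.X, other.X], axis=0),
            y=jnp.concatenate([self.y, other.y], axis=0),
        )


D = DatasetSketch(
    X=jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
    y=jnp.array([[7.0], [8.0], [9.0]]),
)
D_new = DatasetSketch(
    X=jnp.array([[1.5, 2.5], [3.5, 4.5], [5.5, 6.5]]),
    y=jnp.array([[7.0], [8.0], [9.0]]),
)

combined = D + D_new
```

Because the operation returns a new object rather than mutating either operand, the original datasets remain available after the concatenation.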
