Utility functions for JaxGaussianProcesses

This project has now been incorporated into GPJax.

JaxUtils

JaxUtils provides utility functions for the JaxGaussianProcesses ecosystem.

Contents

PyTree

Overview

jaxutils.PyTree is a mixin class that registers a Python class as a JAX PyTree. Define your class by inheriting from it, as follows.

class MyClass(jaxutils.PyTree):
    ...

Example

import jaxutils

from jaxtyping import Float, Array

class Line(jaxutils.PyTree):
    def __init__(self, gradient: Float[Array, "1"], intercept: Float[Array, "1"]) -> None:
        self.gradient = gradient
        self.intercept = intercept

    def y(self, x: Float[Array, "N"]) -> Float[Array, "N"]:
        return x * self.gradient + self.intercept
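
To illustrate what registering a class as a JAX PyTree buys you, here is a minimal sketch of how such a mixin could work. It is not the jaxutils implementation: the name SketchPyTree is hypothetical, and the choice to treat every instance attribute as a leaf is an assumption made for this example. It relies only on jax.tree_util.register_pytree_node_class.

```python
import jax
import jax.numpy as jnp


class SketchPyTree:
    """Hypothetical stand-in for jaxutils.PyTree; not the library's code."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Register each subclass with JAX so instances can be flattened.
        jax.tree_util.register_pytree_node_class(cls)

    def tree_flatten(self):
        # Assumption: every instance attribute is a leaf, in sorted name order.
        names = tuple(sorted(vars(self)))
        children = tuple(getattr(self, name) for name in names)
        return children, names

    @classmethod
    def tree_unflatten(cls, names, children):
        # Rebuild without calling __init__, since JAX may pass tracers here.
        obj = object.__new__(cls)
        for name, child in zip(names, children):
            setattr(obj, name, child)
        return obj


class Line(SketchPyTree):
    def __init__(self, gradient, intercept):
        self.gradient = gradient
        self.intercept = intercept

    def y(self, x):
        return x * self.gradient + self.intercept


line = Line(gradient=jnp.array([2.0]), intercept=jnp.array([1.0]))

# A registered PyTree can be passed directly into a jitted function.
predict = jax.jit(lambda model, x: model.y(x))
preds = predict(line, jnp.array([0.0, 1.0, 2.0]))

# tree_map applies a function to every leaf and returns a new Line.
doubled = jax.tree_util.tree_map(lambda leaf: 2.0 * leaf, line)
```

The same mechanism lets instances pass through jax.grad and jax.vmap, because JAX sees the object as a container of array leaves rather than an opaque Python value.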

Dataset

Overview

jaxutils.Dataset is a dataset abstraction. In the future, we wish to extend this to heterotopic and isotopic data abstractions.

Example

import jaxutils
import jax.numpy as jnp

# Inputs
X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

# Outputs
y = jnp.array([[7.0], [8.0], [9.0]])

# Dataset
D = jaxutils.Dataset(X=X, y=y)

print(f'The number of datapoints is {D.n}')
print(f'The input dimension is {D.in_dim}')
print(f'The output dimension is {D.out_dim}')
print(f'The input data is {D.X}')
print(f'The output data is {D.y}')
print(f'The data is supervised {D.is_supervised()}')
print(f'The data is unsupervised {D.is_unsupervised()}')
The number of datapoints is 3
The input dimension is 2
The output dimension is 1
The input data is [[1. 2.]
 [3. 4.]
 [5. 6.]]
The output data is [[7.]
 [8.]
 [9.]]
The data is supervised True
The data is unsupervised False
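
The properties printed above can all be derived from the array shapes. The following is a rough sketch of that idea, not the library's source: SketchDataset is a hypothetical name, and the definitions of is_supervised and is_unsupervised (both arrays present, versus outputs only) are assumptions that happen to agree with the output shown.

```python
from dataclasses import dataclass
from typing import Optional

import jax
import jax.numpy as jnp


@dataclass
class SketchDataset:
    """Hypothetical stand-in for jaxutils.Dataset, for illustration only."""

    X: Optional[jax.Array] = None
    y: Optional[jax.Array] = None

    @property
    def n(self) -> int:
        # Number of datapoints: one row of y per observation.
        return self.y.shape[0]

    @property
    def in_dim(self) -> int:
        # Input dimension: number of columns of X.
        return self.X.shape[1]

    @property
    def out_dim(self) -> int:
        # Output dimension: number of columns of y.
        return self.y.shape[1]

    def is_supervised(self) -> bool:
        # Assumption: supervised means both inputs and outputs are present.
        return self.X is not None and self.y is not None

    def is_unsupervised(self) -> bool:
        # Assumption: unsupervised means outputs only, with no inputs.
        return self.X is None and self.y is not None


X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([[7.0], [8.0], [9.0]])
D = SketchDataset(X=X, y=y)
```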

You can also add datasets together to concatenate them.

# New inputs
X_new = jnp.array([[1.5, 2.5], [3.5, 4.5], [5.5, 6.5]])

# New outputs
y_new = jnp.array([[7.0], [8.0], [9.0]])

# New dataset
D_new = jaxutils.Dataset(X=X_new, y=y_new)

# Concatenate the two datasets
D = D + D_new

print(f'The number of datapoints is {D.n}')
print(f'The input dimension is {D.in_dim}')
print(f'The output dimension is {D.out_dim}')
print(f'The input data is {D.X}')
print(f'The output data is {D.y}')
print(f'The data is supervised {D.is_supervised()}')
print(f'The data is unsupervised {D.is_unsupervised()}')
The number of datapoints is 6
The input dimension is 2
The output dimension is 1
The input data is [[1.  2. ]
 [3.  4. ]
 [5.  6. ]
 [1.5 2.5]
 [3.5 4.5]
 [5.5 6.5]]
The output data is [[7.]
 [8.]
 [9.]
 [7.]
 [8.]
 [9.]]
The data is supervised True
The data is unsupervised False
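
As the output shows, adding two datasets stacks their rows. One way this could be implemented is with an `__add__` method built on jnp.concatenate. The sketch below is an assumption about the mechanism rather than the jaxutils source, and both the SketchDataset name and the shape checks are hypothetical.

```python
import jax.numpy as jnp


class SketchDataset:
    """Hypothetical stand-in for jaxutils.Dataset; only `+` is sketched."""

    def __init__(self, X, y):
        if X.shape[0] != y.shape[0]:
            raise ValueError("X and y must have the same number of rows.")
        self.X = X
        self.y = y

    def __add__(self, other):
        # Row-wise stacking only makes sense when the column counts agree.
        if self.X.shape[1] != other.X.shape[1]:
            raise ValueError("Datasets must share the same input dimension.")
        if self.y.shape[1] != other.y.shape[1]:
            raise ValueError("Datasets must share the same output dimension.")
        X = jnp.concatenate((self.X, other.X), axis=0)
        y = jnp.concatenate((self.y, other.y), axis=0)
        return SketchDataset(X=X, y=y)


D_a = SketchDataset(X=jnp.array([[1.0, 2.0]]), y=jnp.array([[7.0]]))
D_b = SketchDataset(X=jnp.array([[1.5, 2.5]]), y=jnp.array([[8.0]]))

# The rows of D_b are appended after the rows of D_a.
D_ab = D_a + D_b
```

Returning a new object instead of mutating either operand keeps the operation compatible with JAX's functional style.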

Download files

Source Distribution

jaxutils-nightly-0.0.8.dev20240606.tar.gz (30.2 kB view details)

Uploaded Source

Built Distribution

File details

Details for the file jaxutils-nightly-0.0.8.dev20240606.tar.gz.

File metadata

File hashes

Hashes for jaxutils-nightly-0.0.8.dev20240606.tar.gz
Algorithm Hash digest
SHA256 649f4b92b78848a634fb1015da049b0b021fac6a23e426ef5964d3488281b7ea
MD5 00f405ede7103ce76b866471344715d8
BLAKE2b-256 f869ddd4a2a19d3e977a1b280620d4d51f8bd2aeadcd482d9a90318589a181f5

File details

Details for the file jaxutils_nightly-0.0.8.dev20240606-py3-none-any.whl.

File metadata

File hashes

Hashes for jaxutils_nightly-0.0.8.dev20240606-py3-none-any.whl
Algorithm Hash digest
SHA256 8eb7c0859c71ab2a9e28f617a442ddae88955c1faa3577c86a875f755cb78d96
MD5 8377910e92de0ed2d90e978a333e1d08
BLAKE2b-256 6239cd5e6c27276f36a4f99eacd81c5d5110071dd9ec721d78e5bb15f7178da4
