
Utility functions for JaxGaussianProcesses



This project has now been incorporated into GPJax.

JaxUtils


JaxUtils provides utility functions for the JaxGaussianProcesses ecosystem.


PyTree

Overview

jaxutils.PyTree is a mixin class that registers a Python class as a JAX PyTree. Define your class as follows:

class MyClass(jaxutils.PyTree):
    ...

Example

import jaxutils

from jaxtyping import Float, Array

class Line(jaxutils.PyTree):
    def __init__(self, gradient: Float[Array, "1"], intercept: Float[Array, "1"]) -> None:
        self.gradient = gradient
        self.intercept = intercept

    def y(self, x: Float[Array, "N"]) -> Float[Array, "N"]:
        # Evaluate the line at each input point.
        return x * self.gradient + self.intercept
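
To illustrate what registering a class as a PyTree makes possible, here is a minimal sketch that rebuilds the Line example using only JAX's own jax.tree_util.register_pytree_node_class decorator instead of jaxutils.PyTree. The hand-written tree_flatten and tree_unflatten methods, the loss function, and the sample data are assumptions made for illustration; they are not taken from the jaxutils implementation, which presumably automates this kind of registration.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Line:
    """Illustrative stand-in for the Line class above, registered by hand."""

    def __init__(self, gradient, intercept):
        self.gradient = gradient
        self.intercept = intercept

    def y(self, x):
        return x * self.gradient + self.intercept

    def tree_flatten(self):
        # Children are the array leaves; there is no static auxiliary data.
        return (self.gradient, self.intercept), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


line = Line(gradient=jnp.array([2.0]), intercept=jnp.array([1.0]))

# Because Line is a PyTree, JAX can enumerate its leaves.
leaves = jax.tree_util.tree_leaves(line)


def loss(line, x, target):
    # Hypothetical squared-error loss, used only to demonstrate jax.grad.
    return jnp.sum((line.y(x) - target) ** 2)


# jax.grad accepts a Line argument and returns gradients as another Line.
grads = jax.grad(loss)(line, jnp.array([1.0, 2.0]), jnp.array([0.0, 0.0]))
print(grads.gradient, grads.intercept)
```

The gradients come back in the same structure as the input, so grads.gradient and grads.intercept line up with the corresponding attributes of line.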

Dataset

Overview

jaxutils.Dataset is a dataset abstraction that holds inputs X and outputs y. In future, we wish to extend this to a heterotopic and isotopic data abstraction.

Example

import jaxutils
import jax.numpy as jnp

# Inputs
X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

# Outputs
y = jnp.array([[7.0], [8.0], [9.0]])

# Dataset
D = jaxutils.Dataset(X=X, y=y)

print(f'The number of datapoints is {D.n}')
print(f'The input dimension is {D.in_dim}')
print(f'The output dimension is {D.out_dim}')
print(f'The input data is {D.X}')
print(f'The output data is {D.y}')
print(f'The data is supervised {D.is_supervised()}')
print(f'The data is unsupervised {D.is_unsupervised()}')

The number of datapoints is 3
The input dimension is 2
The output dimension is 1
The input data is [[1. 2.]
 [3. 4.]
 [5. 6.]]
The output data is [[7.]
 [8.]
 [9.]]
The data is supervised True
The data is unsupervised False

You can also add datasets together to concatenate them.

# New inputs
X_new = jnp.array([[1.5, 2.5], [3.5, 4.5], [5.5, 6.5]])

# New outputs
y_new = jnp.array([[7.0], [8.0], [9.0]])

# New dataset
D_new = jaxutils.Dataset(X=X_new, y=y_new)

# Concatenate the two datasets
D = D + D_new

print(f'The number of datapoints is {D.n}')
print(f'The input dimension is {D.in_dim}')
print(f'The output dimension is {D.out_dim}')
print(f'The input data is {D.X}')
print(f'The output data is {D.y}')
print(f'The data is supervised {D.is_supervised()}')
print(f'The data is unsupervised {D.is_unsupervised()}')

The number of datapoints is 6
The input dimension is 2
The output dimension is 1
The input data is [[1.  2. ]
 [3.  4. ]
 [5.  6. ]
 [1.5 2.5]
 [3.5 4.5]
 [5.5 6.5]]
The output data is [[7.]
 [8.]
 [9.]
 [7.]
 [8.]
 [9.]]
The data is supervised True
The data is unsupervised False
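
One plausible way to support this kind of concatenation is to overload the + operator. The sketch below defines a hypothetical MiniDataset class (not jaxutils.Dataset, whose implementation is not shown here) whose __add__ method stacks inputs and outputs along the datapoint axis with jnp.concatenate. The class name, its fields, and the property definitions are assumptions chosen to mirror the attributes printed above.

```python
from dataclasses import dataclass
from typing import Optional

import jax
import jax.numpy as jnp


@dataclass
class MiniDataset:
    """Hypothetical stand-in for a dataset container; not jaxutils.Dataset."""

    X: Optional[jax.Array] = None
    y: Optional[jax.Array] = None

    @property
    def n(self) -> int:
        # Number of datapoints (rows of X).
        return self.X.shape[0]

    @property
    def in_dim(self) -> int:
        return self.X.shape[1]

    @property
    def out_dim(self) -> int:
        return self.y.shape[1]

    def is_supervised(self) -> bool:
        return self.X is not None and self.y is not None

    def __add__(self, other: "MiniDataset") -> "MiniDataset":
        # Stack both datasets along axis 0, the datapoint axis.
        return MiniDataset(
            X=jnp.concatenate((self.X, other.X), axis=0),
            y=jnp.concatenate((self.y, other.y), axis=0),
        )


X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([[7.0], [8.0], [9.0]])
X_new = jnp.array([[1.5, 2.5], [3.5, 4.5], [5.5, 6.5]])
y_new = jnp.array([[7.0], [8.0], [9.0]])

D = MiniDataset(X=X, y=y) + MiniDataset(X=X_new, y=y_new)
print(D.n, D.in_dim, D.out_dim)
```

Returning a new object from __add__ rather than mutating in place leaves the operands unchanged, which fits JAX's preference for immutable data.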


