Skip to main content

Utility functions for JaxGaussianProcesses

Project description


This project has now been incorporated into GPJax.

JaxUtils

CircleCI

JaxUtils provides utility functions for the JaxGaussianProcesses ecosystem.

Contents

PyTree

Overview

jaxutils.PyTree is a mixin class for registering a python class as a JAX PyTree. You would define your Python class as follows.

class MyClass(jaxutils.PyTree):
    ...

Example

import jaxutils

from jaxtyping import Float, Array

class Line(jaxutils.PyTree):
    def __init__(self, gradient: Float[Array, "1"], intercept: Float[Array, "1"]) -> None
        self.gradient = gradient
        self.intercept = intercept

    def y(self, x: Float[Array, "N"]) -> Float[Array, "N"]
        return x * self.gradient + self.intercept

Dataset

Overview

jaxutils.Dataset is a datset abstraction. In future, we wish to extend this to a heterotopic and isotopic data abstraction.

Example

import jaxutils
import jax.numpy as jnp

# Inputs
X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

# Outputs
y = jnp.array([[7.0], [8.0], [9.0]])

# Datset
D = jaxutils.Dataset(X=X, y=y)

print(f'The number of datapoints is {D.n}')
print(f'The input dimension is {D.in_dim}')
print(f'The output dimension is {D.out_dim}')
print(f'The input data is {D.X}')
print(f'The output data is {D.y}')
print(f'The data is supervised {D.is_supervised()}')
print(f'The data is unsupervised {D.is_unsupervised()}')
The number of datapoints is 3
The input dimension is 2
The output dimension is 1
The input data is [[1. 2.]
 [3. 4.]
 [5. 6.]]
The output data is [[7.]
 [8.]
 [9.]]
The data is supervised True
The data is unsupervised False

You can also add dataset together to concatenate them.

# New inputs
X_new = jnp.array([[1.5, 2.5], [3.5, 4.5], [5.5, 6.5]])

# New outputs
y_new = jnp.array([[7.0], [8.0], [9.0]])

# New dataset
D_new = jaxutils.Dataset(X=X_new, y=y_new)

# Concatenate the two datasets
D = D + D_new

print(f'The number of datapoints is {D.n}')
print(f'The input dimension is {D.in_dim}')
print(f'The output dimension is {D.out_dim}')
print(f'The input data is {D.X}')
print(f'The output data is {D.y}')
print(f'The data is supervised {D.is_supervised()}')
print(f'The data is unsupervised {D.is_unsupervised()}')
The number of datapoints is 6
The input dimension is 2
The output dimension is 1
The input data is [[1.  2. ]
 [3.  4. ]
 [5.  6. ]
 [1.5 2.5]
 [3.5 4.5]
 [5.5 6.5]]
The output data is [[7.]
 [8.]
 [9.]
 [7.]
 [8.]
 [9.]]
The data is supervised True
The data is unsupervised False

Project details


Release history Release notifications | RSS feed

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

jaxutils-nightly-0.0.8.dev20241028.tar.gz (30.2 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

jaxutils_nightly-0.0.8.dev20241028-py3-none-any.whl (18.0 kB view details)

Uploaded Python 3

File details

Details for the file jaxutils-nightly-0.0.8.dev20241028.tar.gz.

File metadata

File hashes

Hashes for jaxutils-nightly-0.0.8.dev20241028.tar.gz
Algorithm Hash digest
SHA256 79ad30f24af95ed4440f937eead83ffed9ad24483677fd8e866c335627ffb69d
MD5 456a2e7ea517c2f53d4f82cc4a54bb0b
BLAKE2b-256 ae2a9f4dde0004239cd8c34327e1115734e0a8f56015c28a66a0d7e4e59c3e22

See more details on using hashes here.

File details

Details for the file jaxutils_nightly-0.0.8.dev20241028-py3-none-any.whl.

File metadata

File hashes

Hashes for jaxutils_nightly-0.0.8.dev20241028-py3-none-any.whl
Algorithm Hash digest
SHA256 df237005f70d86cb9fee94fd3e9b1259cc4f1135c8bb0cae11471f54ce06bbcd
MD5 f1ba580f64fb6f80c86bdca78eff4f5d
BLAKE2b-256 e77253b42790753beea1600b9353d289386d89aa6d6f955fa9386a2ccb3136ca

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page