Utility functions for JaxGaussianProcesses
This project has now been incorporated into GPJax.
JaxUtils
JaxUtils provides utility functions for the JaxGaussianProcesses ecosystem.
PyTree
Overview
jaxutils.PyTree is a mixin class for registering a Python class as a JAX PyTree. To register your own class, subclass it:
class MyClass(jaxutils.PyTree):
    ...
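The README does not show how the mixin works internally. As a rough sketch of the general technique, the hypothetical PyTreeMixin below registers every subclass with jax.tree_util.register_pytree_node at class-definition time and treats all instance attributes as leaves. The names PyTreeMixin and Scale are illustrative assumptions, not the jaxutils source.

```python
import jax
import jax.numpy as jnp


class PyTreeMixin:
    """Hypothetical stand-in for a PyTree mixin (not the jaxutils source)."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Register each subclass with JAX as soon as it is defined.
        jax.tree_util.register_pytree_node(cls, cls._tree_flatten, cls._tree_unflatten)

    def _tree_flatten(self):
        # Sort attribute names so the leaf order is deterministic.
        keys = tuple(sorted(vars(self)))
        children = tuple(getattr(self, k) for k in keys)
        return children, keys  # the names travel as static auxiliary data

    @classmethod
    def _tree_unflatten(cls, keys, children):
        # Rebuild without calling __init__, whose signature is unknown here.
        obj = object.__new__(cls)
        for k, v in zip(keys, children):
            setattr(obj, k, v)
        return obj


class Scale(PyTreeMixin):
    # Hypothetical example subclass with a single array attribute.
    def __init__(self, factor):
        self.factor = factor

    def __call__(self, x):
        return self.factor * x


s = Scale(jnp.array(2.0))
out = jax.jit(lambda m, x: m(x))(s, jnp.arange(3.0))  # s passes through jit
g = jax.grad(lambda m: m(3.0))(s)  # the gradient is itself a Scale
```

Rebuilding through object.__new__ keeps the sketch independent of each subclass's constructor; a real library may choose a different strategy.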
Example
import jaxutils
from jaxtyping import Float, Array

class Line(jaxutils.PyTree):
    def __init__(self, gradient: Float[Array, "1"], intercept: Float[Array, "1"]) -> None:
        self.gradient = gradient
        self.intercept = intercept

    def y(self, x: Float[Array, "N"]) -> Float[Array, "N"]:
        return x * self.gradient + self.intercept
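Registering Line as a PyTree lets JAX transformations treat gradient and intercept as differentiable leaves. To stay self-contained, the sketch below does not import jaxutils: it registers an equivalent Line with JAX's own jax.tree_util.register_pytree_node_class decorator (dropping the jaxtyping annotations) and fits it to noiseless data by gradient descent. The mse and step helpers, the learning rate of 0.1, and the 500 steps are arbitrary illustrative choices.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Line:
    """Plain-JAX equivalent of the Line example, registered explicitly."""

    def __init__(self, gradient, intercept):
        self.gradient = gradient
        self.intercept = intercept

    def y(self, x):
        return x * self.gradient + self.intercept

    def tree_flatten(self):
        # Both parameters are leaves; there is no static auxiliary data.
        return (self.gradient, self.intercept), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def mse(line, x, y):
    return jnp.mean((line.y(x) - y) ** 2)


@jax.jit
def step(line, x, y):
    grads = jax.grad(mse)(line, x, y)  # grads has the same Line structure
    return jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, line, grads)


x = jnp.array([0.0, 1.0, 2.0])
y = 2.0 * x + 1.0  # data generated by gradient 2 and intercept 1

line = Line(jnp.array(0.0), jnp.array(0.0))
for _ in range(500):
    line = step(line, x, y)
```

Because the gradient of the loss comes back as a Line, the update is a single tree_map over matching leaves rather than per-attribute bookkeeping.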
Dataset
Overview
jaxutils.Dataset is a dataset abstraction. In the future, we wish to extend it to heterotopic and isotopic data.
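The example that follows uses the attributes n, in_dim and out_dim and the methods is_supervised() and is_unsupervised(). A minimal sketch of how such a container could derive them from array shapes is shown below. SimpleDataset is a hypothetical name, and the rule that a dataset with inputs but no outputs counts as unsupervised is an assumption; the library's exact rule may differ.

```python
from dataclasses import dataclass
from typing import Optional

import jax
import jax.numpy as jnp


@dataclass
class SimpleDataset:
    """Hypothetical minimal dataset container (not the jaxutils source)."""

    X: Optional[jax.Array] = None  # inputs, shape (n, in_dim)
    y: Optional[jax.Array] = None  # outputs, shape (n, out_dim)

    def __post_init__(self):
        # Inputs and outputs must describe the same number of datapoints.
        if self.X is not None and self.y is not None and self.X.shape[0] != self.y.shape[0]:
            raise ValueError("X and y must have the same number of rows")

    @property
    def n(self) -> int:
        # Number of datapoints: rows of whichever array is present.
        return (self.X if self.X is not None else self.y).shape[0]

    @property
    def in_dim(self) -> int:
        return self.X.shape[1]

    @property
    def out_dim(self) -> int:
        return self.y.shape[1]

    def is_supervised(self) -> bool:
        return self.X is not None and self.y is not None

    def is_unsupervised(self) -> bool:
        # Assumption: inputs without outputs count as unsupervised.
        return self.X is not None and self.y is None


X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([[7.0], [8.0], [9.0]])
D = SimpleDataset(X=X, y=y)
```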
Example
import jaxutils
import jax.numpy as jnp
# Inputs
X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
# Outputs
y = jnp.array([[7.0], [8.0], [9.0]])
# Dataset
D = jaxutils.Dataset(X=X, y=y)
print(f'The number of datapoints is {D.n}')
print(f'The input dimension is {D.in_dim}')
print(f'The output dimension is {D.out_dim}')
print(f'The input data is {D.X}')
print(f'The output data is {D.y}')
print(f'The data is supervised {D.is_supervised()}')
print(f'The data is unsupervised {D.is_unsupervised()}')
This prints:
The number of datapoints is 3
The input dimension is 2
The output dimension is 1
The input data is [[1. 2.]
 [3. 4.]
 [5. 6.]]
The output data is [[7.]
 [8.]
 [9.]]
The data is supervised True
The data is unsupervised False
You can also add datasets together to concatenate them.
# New inputs
X_new = jnp.array([[1.5, 2.5], [3.5, 4.5], [5.5, 6.5]])
# New outputs
y_new = jnp.array([[7.0], [8.0], [9.0]])
# New dataset
D_new = jaxutils.Dataset(X=X_new, y=y_new)
# Concatenate the two datasets
D = D + D_new
print(f'The number of datapoints is {D.n}')
print(f'The input dimension is {D.in_dim}')
print(f'The output dimension is {D.out_dim}')
print(f'The input data is {D.X}')
print(f'The output data is {D.y}')
print(f'The data is supervised {D.is_supervised()}')
print(f'The data is unsupervised {D.is_unsupervised()}')
This prints:
The number of datapoints is 6
The input dimension is 2
The output dimension is 1
The input data is [[1.  2. ]
 [3.  4. ]
 [5.  6. ]
 [1.5 2.5]
 [3.5 4.5]
 [5.5 6.5]]
The output data is [[7.]
 [8.]
 [9.]
 [7.]
 [8.]
 [9.]]
The data is supervised True
The data is unsupervised False
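The printed result shows that addition appends the datapoints of the second dataset below those of the first, leaving in_dim and out_dim unchanged. In array terms this is a row-wise concatenation. The hypothetical helper concat_data below reproduces those shapes with jnp.concatenate; whether jaxutils also validates matching dimensions in the same way is an assumption.

```python
import jax.numpy as jnp


def concat_data(X1, y1, X2, y2):
    """Hypothetical helper mirroring the effect of D + D_new above."""
    # Both datasets must share the same input and output widths.
    if X1.shape[1] != X2.shape[1] or y1.shape[1] != y2.shape[1]:
        raise ValueError("datasets must share in_dim and out_dim")
    # Stack rows: the second dataset's points come after the first's.
    return jnp.concatenate([X1, X2], axis=0), jnp.concatenate([y1, y2], axis=0)


X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([[7.0], [8.0], [9.0]])
X_new = jnp.array([[1.5, 2.5], [3.5, 4.5], [5.5, 6.5]])
y_new = jnp.array([[7.0], [8.0], [9.0]])

X_all, y_all = concat_data(X, y, X_new, y_new)
```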
Download files
Hashes for jaxutils-nightly-0.0.8.dev20240516.tar.gz (source distribution)
Algorithm | Hash digest
---|---
SHA256 | adec1ea47d45d39f1db478e539c836c7189f638a9afbd6fdbbbeb5a000b865f8
MD5 | c77c5531a1d5d83647a7a115bfbf32f5
BLAKE2b-256 | ddf59ba8a95a96d812f2b356bf1f5f74fab01501eeddb3abd629a8ecf629cd85

Hashes for jaxutils_nightly-0.0.8.dev20240516-py3-none-any.whl (built distribution)
Algorithm | Hash digest
---|---
SHA256 | 53ebb893034378dfd2d6b88d3ab7a98c813d486a9f5a6f694cd93a898396fc96
MD5 | 5fb94fda3fb9ec2bc71ce85768892ff1
BLAKE2b-256 | 105b634fe1c90e77994be8b18e6b412430685df89595d8d2bbc4364ef9236466