
🌱 rkhs

rkhs is a small Python framework for marginal and conditional two-sample testing with kernels in JAX.


⚙️ Installation

pip install rkhs

🚀 Features

rkhs provides JAX-native two-sample tests (marginal, conditional, mixed) based on kernel embeddings with analytical or bootstrap confidence bounds, a simple API, and pluggable kernels.

Three test modes — one API

You can test the following two-sample hypotheses with one common set of primitives:

  • Marginal: $H_0: P = Q$
  • Conditional: $H_0(x_1,x_2): P(\cdot\mid X=x_1) = Q(\cdot\mid X=x_2)$
  • Mixed: $H_0(x): P(\cdot\mid X=x) = Q$

The test compares kernel embeddings in RKHS norm and rejects $H_0$ at level $\alpha$ if

$$ \lVert \hat\mu_P - \hat\mu_Q \rVert_{\mathcal{H}} > \beta_P + \beta_Q, $$

where $\beta_\ast$ are finite-sample confidence radii from the selected regime.
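To make the test statistic concrete, here is a minimal plain-NumPy sketch (illustrative only, not the library API) of the empirical RKHS distance $\lVert \hat\mu_P - \hat\mu_Q \rVert_{\mathcal{H}}$, computed via the kernel trick with a Gaussian kernel:

```python
import numpy as np

def gaussian_kernel(a, b, bandwidth=1.5):
    # pairwise k(a_i, b_j) = exp(-||a_i - b_j||^2 / (2 * bandwidth^2))
    sq_dists = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    return np.exp(-sq_dists / (2 * bandwidth ** 2))

def rkhs_distance(xs, ys, bandwidth=1.5):
    # ||mu_P_hat - mu_Q_hat||_H expanded via the kernel trick
    # (biased V-statistic estimate of the MMD, always >= 0)
    k_xx = gaussian_kernel(xs, xs, bandwidth).mean()
    k_yy = gaussian_kernel(ys, ys, bandwidth).mean()
    k_xy = gaussian_kernel(xs, ys, bandwidth).mean()
    return np.sqrt(k_xx - 2 * k_xy + k_yy)

rng = np.random.default_rng(0)
same = rkhs_distance(rng.normal(size=(200, 3)), rng.normal(size=(200, 3)))
shifted = rkhs_distance(rng.normal(size=(200, 3)), rng.normal(size=(200, 3)) + 1.0)
```

A distribution shift shows up as a larger distance; the library's tests compare this statistic against the threshold $\beta_P + \beta_Q$.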

Confidence regimes

Analytical bounds. Finite-sample guarantees under the stated assumptions (conservative, little overhead).

Bootstrap bounds. Data-driven thresholds with typically higher power (cost scales with the number of resamples).
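To illustrate the bootstrap idea behind such radii (a generic sketch, not the rkhs implementation), one can estimate a confidence radius for a sample mean by resampling and taking a quantile of the deviation; the library applies the same principle to kernel mean embeddings:

```python
import numpy as np

def bootstrap_radius(xs, n_bootstrap=1000, level=0.05, seed=0):
    # (1 - level)-quantile of how far a resampled mean strays from the
    # sample mean: a data-driven confidence radius for the mean
    rng = np.random.default_rng(seed)
    n = len(xs)
    mean = xs.mean(axis=0)
    deviations = [
        np.linalg.norm(xs[rng.integers(0, n, size=n)].mean(axis=0) - mean)
        for _ in range(n_bootstrap)
    ]
    return float(np.quantile(deviations, 1 - level))

xs = np.random.default_rng(1).normal(size=(200, 3))
radius = bootstrap_radius(xs)
```

The cost grows linearly with `n_bootstrap`, which is the trade-off noted above.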

JAX integration

Works with jit/vmap, runs on CPU/GPU/TPU, and uses explicit PRNG keys (jax.random.key) for reproducibility.
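For example (generic JAX, independent of rkhs): randomness is a pure function of the key, so jit-compiled code reproduces exactly when given the same key.

```python
import jax

@jax.jit
def draw(key):
    # the sample is a deterministic function of the key,
    # so runs are reproducible even under jit compilation
    return jax.random.normal(key, shape=(3,))

a = draw(jax.random.key(0))
b = draw(jax.random.key(0))  # same key -> identical sample
```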

Kernels

Popular kernels are built in: Gaussian, Matérn, Laplacian, polynomial, and linear.

Conditional tests use a scalar kernel on the input domain and a separate kernel on the output domain: VectorKernel(x=..., y=..., regularization=...).


🧩 Usage

1) Marginal two-sample test (analytical bounds)

import jax
from rkhs.testing import TestEmbedding, TwoSampleTest
from rkhs.kernels import GaussianKernel

# toy data: two 3D Gaussians with different means
xs_1 = jax.random.normal(key=jax.random.key(1), shape=(200, 3))
xs_2 = jax.random.normal(key=jax.random.key(2), shape=(200, 3)) + 1.0

# kernel on the sample space
kernel = GaussianKernel(bandwidth=1.5, data_shape=(3,))

# embedding + analytical confidence radius
kme_1 = TestEmbedding.analytical(
    kme=kernel.kme(xs_1),   # embed dataset in RKHS
    kernel_bound=1.0        # sup_x k(x, x)
)
kme_2 = TestEmbedding.analytical(
    kme=kernel.kme(xs_2),   # embed dataset in RKHS
    kernel_bound=1.0        # sup_x k(x, x)
) 

# level-α test
test = TwoSampleTest.from_embeddings(kme_1, kme_2, level=0.05)

decision = test.reject      # boolean (reject H_0?)
distance = test.distance    # RKHS distance
threshold = test.threshold  # β_P + β_Q
print(decision, distance, threshold)

2) Marginal test (bootstrap bounds)

import jax
from rkhs.testing import TestEmbedding, TwoSampleTest
from rkhs.kernels import GaussianKernel

# toy data: two 3D Gaussians with different means
xs_1 = jax.random.normal(key=jax.random.key(1), shape=(200, 3))
xs_2 = jax.random.normal(key=jax.random.key(2), shape=(200, 3)) + 1.0

# kernel on the sample space
kernel = GaussianKernel(bandwidth=1.5, data_shape=(3,))

# embedding + bootstrap confidence radius
kme_1 = TestEmbedding.bootstrap(
    kme=kernel.kme(xs_1),   # embed dataset in RKHS
    key=jax.random.key(3),  # random key
    n_bootstrap=1000        # number of bootstrap resamples
)
kme_2 = TestEmbedding.bootstrap(
    kme=kernel.kme(xs_2),   # embed dataset in RKHS
    key=jax.random.key(4),  # random key
    n_bootstrap=1000        # number of bootstrap resamples
) 

# level-α test
test = TwoSampleTest.from_embeddings(kme_1, kme_2, level=0.05)

decision = test.reject      # boolean (reject H_0?)
distance = test.distance    # RKHS distance
threshold = test.threshold  # β_P + β_Q
print(decision, distance, threshold)

3) Conditional two-sample test at selected covariates

import jax
from rkhs import VectorKernel
from rkhs.testing import ConditionalTestEmbedding, TwoSampleTest
from rkhs.kernels import GaussianKernel

# synthetic x,y pairs with additive noise
xs_1 = jax.random.normal(key=jax.random.key(1), shape=(1000, 3))
ys_1 = xs_1 + jax.random.normal(key=jax.random.key(2), shape=(1000, 3)) * 0.05

xs_2 = jax.random.normal(key=jax.random.key(3), shape=(1000, 3))
ys_2 = xs_2 + jax.random.normal(key=jax.random.key(4), shape=(1000, 3)) * 0.05 + 0.5

# inputs at which to test for distributional equality: H_0(x): P(. | x) =? Q(. | x), for x in grid
covariates = jax.numpy.linspace(jax.numpy.array([-3, -3, -3]), jax.numpy.array([3, 3, 3]), num=100)

# vector-valued kernel over inputs X and outputs Y
kernel = VectorKernel(
    x=GaussianKernel(bandwidth=0.5, data_shape=(3,)),  # kernel used for ridge regression
    y=GaussianKernel(bandwidth=1.0, data_shape=(3,)),  # kernel used for embedding marginal distribution at each covariate
    regularization=0.1
)

# conditional embedding + bootstrap confidence radius
cme_1 = ConditionalTestEmbedding.bootstrap(
    cme=kernel.cme(xs_1, ys_1), # embed dataset in vector-valued RKHS
    grid=covariates,            # covariates used in bootstrap of threshold parameters
    key=jax.random.key(5),      # random key
    n_bootstrap=100             # number of bootstrap resamples
)

cme_2 = ConditionalTestEmbedding.bootstrap(
    cme=kernel.cme(xs_2, ys_2), # embed dataset in vector-valued RKHS
    grid=covariates,            # covariates used in bootstrap of threshold parameters
    key=jax.random.key(6),      # random key
    n_bootstrap=100             # number of bootstrap resamples
)

# evaluate CMEs at covariates -> embeds each distribution over Y in RKHS of `kernel.y`
kme_1 = cme_1(covariates)
kme_2 = cme_2(covariates)

# batched test across all covariates
test = TwoSampleTest.from_embeddings(kme_1, kme_2, level=0.05)
decision = test.reject      # boolean array: reject H_0(x)?, individually for each covariate; shape (len(covariates),)
distance = test.distance    # RKHS distance per covariate; shape (len(covariates),)
threshold = test.threshold  # β_P + β_Q per covariate
print(decision, distance, threshold)

4) Mixed test: $P(\cdot\mid X=x)$ vs. $Q$

import jax
from rkhs import VectorKernel
from rkhs.testing import TestEmbedding, ConditionalTestEmbedding, TwoSampleTest
from rkhs.kernels import GaussianKernel

# dataset from marginal distribution over Y
ys_1 = jax.random.normal(key=jax.random.key(1), shape=(1000, 3)) * 0.05

# synthetic x,y pairs with additive noise
xs_2 = jax.random.normal(key=jax.random.key(2), shape=(1000, 3))
ys_2 = xs_2 + jax.random.normal(key=jax.random.key(3), shape=(1000, 3)) * 0.05

# inputs at which to test for distributional equality: H_0(x): P =? Q(. | x), for x in grid
covariates = jax.numpy.linspace(jax.numpy.array([-3, -3, -3]), jax.numpy.array([3, 3, 3]), num=200)

y_kernel = GaussianKernel(bandwidth=1.0, data_shape=(3,))

# vector-valued kernel over inputs X and outputs Y
vector_kernel = VectorKernel(
    x=GaussianKernel(bandwidth=0.5, data_shape=(3,)),  # kernel used for ridge regression
    y=y_kernel,                                        # kernel used for embedding marginal distribution at each covariate
    regularization=0.1
)

# embedding + analytical confidence radius (can be drop-in replaced with bootstrap radius)
kme_1 = TestEmbedding.analytical(
    kme=y_kernel.kme(ys_1), # embed dataset in RKHS
    kernel_bound=1.0        # sup_x k(x, x)
)

# conditional embedding + bootstrap confidence radius
cme_2 = ConditionalTestEmbedding.bootstrap(
    cme=vector_kernel.cme(xs_2, ys_2),  # embed dataset in vector-valued RKHS
    grid=covariates,                    # covariates used in bootstrap of threshold parameters
    key=jax.random.key(4),              # random key
    n_bootstrap=100                     # number of bootstrap resamples
)

# evaluate CME at covariates -> embeds each distribution over Y in RKHS of `kernel.y`
kme_2 = cme_2(covariates)

# batched test across all covariates
test = TwoSampleTest.from_embeddings(kme_1, kme_2, level=0.05)
decision = test.reject      # boolean array: reject H_0(x)?, individually for each covariate; shape (len(covariates),)
distance = test.distance    # RKHS distance per covariate; shape (len(covariates),)
threshold = test.threshold  # β_P + β_Q per covariate
print(decision, distance, threshold)

🔍 Kernel quick reference

  • LinearKernel — compares means (first moment).
  • PolynomialKernel(degree=d) — compares moments up to degree $d$.
  • Gaussian, Matérn, Laplacian — characteristic kernels; compare full distributions.
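As a quick sanity check in plain NumPy: with the linear kernel $k(x, y) = \langle x, y \rangle$, the empirical RKHS distance is exactly the Euclidean distance between sample means, which is why LinearKernel only detects first-moment differences:

```python
import numpy as np

rng = np.random.default_rng(0)
xs = rng.normal(size=(500, 3))
ys = rng.normal(size=(500, 3)) + 1.0

# with k(x, y) = <x, y>, the mean embedding of a sample is its sample mean,
# so the kernel-trick expansion of the RKHS distance ...
k_xx = (xs @ xs.T).mean()
k_yy = (ys @ ys.T).mean()
k_xy = (xs @ ys.T).mean()
rkhs_distance = np.sqrt(k_xx - 2 * k_xy + k_yy)

# ... collapses to the distance between first moments
mean_distance = np.linalg.norm(xs.mean(axis=0) - ys.mean(axis=0))
```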

For conditional tests:

  • Input kernel (x): used to learn the conditional embedding (not for comparison).
  • Output kernel (y): determines what aspects of the conditional law are compared.
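Under the hood, conditional mean embedding estimators are kernel ridge regressors into the output RKHS: the input kernel produces weights over the training outputs. A minimal NumPy sketch of that mechanism (illustrative only; the regularization convention used here, $n\lambda$, is an assumption and need not match rkhs):

```python
import numpy as np

def gaussian_gram(a, b, bandwidth):
    sq = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    return np.exp(-sq / (2 * bandwidth ** 2))

rng = np.random.default_rng(0)
xs = rng.normal(size=(300, 3))
ys = xs + 0.05 * rng.normal(size=(300, 3))

# ridge weights for one query x*: (K_xx + n * lambda * I)^{-1} k_x(x*);
# the estimated conditional embedding is the weighted sum of output features
lam = 0.1
k_xx = gaussian_gram(xs, xs, bandwidth=0.5)
x_star = np.zeros((1, 3))
weights = np.linalg.solve(k_xx + len(xs) * lam * np.eye(len(xs)),
                          gaussian_gram(xs, x_star, bandwidth=0.5))

# anything linear in the output features inherits these weights, e.g. the
# plug-in estimate of the conditional mean of Y at x*
conditional_mean = weights[:, 0] @ ys
```

Note that the output kernel never enters the weights; it only determines which features of the outputs the weighted sum embeds, i.e. what the test can distinguish.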

🧠 Notes

  • Embeddings preserve batch axes; passing a batch of covariates returns a batch of embeddings.
  • All randomness is explicit via JAX PRNG keys (jax.random.key).
  • You can use your own custom kernel by extending rkhs.Kernel:
from jax import Array
from rkhs import Kernel

class MyCustomKernel(Kernel):
    def __init__(self, data_shape: tuple[int, ...]):
        super().__init__(data_shape)
        ...

    def _dot(self, x1: Array, x2: Array) -> Array:
        ...  # your kernel evaluation k(x1, x2) (must be jit-compilable)

📚 References

  • Marginal test: Gretton, A., Borgwardt, K. M., Rasch, M. J., Schölkopf, B., & Smola, A. (2012). A Kernel Two-Sample Test. Journal of Machine Learning Research, 13(25), 723–773.

  • Conditional test: Massiani, P.-F., et al. (2025). A Kernel Conditional Two-Sample Test. arXiv preprint.
