eqx-learn

A minimal, classical machine learning library built on JAX and Equinox.

Overview

eqx-learn implements classical machine learning algorithms (Linear Regression, PCA, Gaussian Processes, etc.) within the JAX/Equinox ecosystem. It provides an API heavily inspired by scikit-learn, but adapted to JAX's functional programming paradigm. All models are Equinox modules and therefore JAX PyTrees, making them fully differentiable.

NB: This library is currently in the very early stages of development, focusing mainly on simple regression algorithms. However, the API has been carefully thought out, and pull requests for additional models, algorithms, and utilities are more than welcome!

Installation

You can install eqx-learn directly from GitHub. Ensure you have a working JAX installation (CPU or GPU) first. Then, using pip:

pip install git+https://github.com/eqx-learn/eqx-learn.git
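
If you do not yet have JAX, a CPU-only build can be installed via pip first. The command below reflects current upstream JAX packaging and may change; see the JAX documentation for GPU/TPU wheels:

pip install -U jax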

Example: Linear Regression

This example demonstrates a simple linear regression implementation using the analytic OLS solution.

import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt

from eqxlearn.linear_model import LinearRegressor
from eqxlearn import fit

# Create synthetic data
X = jnp.linspace(0, 10, 50)[:, None]
y = 3.5 * X.squeeze() + 2.0 + jr.normal(jr.key(0), (50,))

# Solve the OLS problem using fit(). In its default mode, this calls the analytic solution provided by
# LinearRegressor.solve(...). To use an iterative solution (e.g. for larger problems), simply pass
# fit(..., solution='iterative'). Note: LinearRegressor inherits from Regressor -> BaseModel -> eqx.Module.
model = LinearRegressor()
model, _ = fit(model, X, y)

# Plot model prediction
X_test = jnp.linspace(0, 10, 100)[:, None]
y_test = model.predict(X_test)
plt.scatter(X, y, color='black')
plt.plot(X_test, y_test)
plt.show()
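
As noted in the comments above, larger problems can use the iterative solver instead of the analytic solution by passing the solution keyword:

model, losses = fit(LinearRegressor(), X, y, solution='iterative')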

Example: Gaussian Process Regression with Scaled Data

This example demonstrates a full machine learning pipeline for Gaussian process regression. The code standardizes both inputs and targets, fits a GP with an RBF kernel, and inverse-transforms the predictions (including their uncertainty) back to the original scale.

import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt

from eqxlearn.preprocessing import StandardScaler
from eqxlearn.gaussian_process import GaussianProcessRegressor
from eqxlearn.gaussian_process.kernels import RBFKernel, WhiteNoiseKernel
from eqxlearn.pipeline import Pipeline
from eqxlearn.compose import TransformedTargetRegressor
from eqxlearn import fit

# 1. Generate scaled data
X_SCALE, Y_SCALE = 0.01, 100.0
X = X_SCALE * jnp.linspace(0, 10, 30)[:, None]
y = Y_SCALE * (jnp.sin(X / X_SCALE).squeeze() + 0.1 * jr.normal(jr.key(1), (30,)))

# 2. Create the model with scaled inputs and outputs
kernel = RBFKernel() + WhiteNoiseKernel(0.01)
pipeline = Pipeline([
    ("scaler_x", StandardScaler()),
    ("gp", GaussianProcessRegressor(kernel=kernel))
])
model = TransformedTargetRegressor(
    regressor=pipeline,
    transformer=StandardScaler()
)

# 3. Fit the model
# fit() inspects the requirements/capabilities of the model being passed.
# These include conditioning on data via model.condition() and exact solutions via model.solve().
# For wrapper models (e.g. Pipeline, TransformedTargetRegressor), these are forwarded appropriately.
# fit() then runs iterative optimization on the model (using e.g. the optax Adam optimizer).
model, losses = fit(model, X, y)

# 4. Make predictions
X_test = X_SCALE * jnp.linspace(0, 10, 100)[:, None]
predictions, variance = model.predict(X_test, return_var=True)
std_dev = jnp.sqrt(variance)

# Plot
plt.figure(figsize=(10, 5))
plt.scatter(X, y, label="Training Data", color="black")
plt.plot(X_test, predictions, label="Model Prediction", color="blue")
plt.fill_between(
    X_test.squeeze(), 
    predictions - 1.96 * std_dev, 
    predictions + 1.96 * std_dev, 
    alpha=0.2, color="blue", label="95% CI"
)
plt.legend()
plt.title("GP with Feature & Target Scaling")
plt.show()

Key Differences from scikit-learn

Immutability & The fit Function

In eqx-learn, models are immutable PyTrees, since they derive from Equinox's Module. Unlike in scikit-learn, fit is not a method that updates attributes in place; it is a free function that returns a new instance of the model with updated parameters. You must capture this return value:

model, history = fit(model, X, y)

Native Equinox Compatibility

Every estimator and transformer is a standard eqx.Module. This means you can use them anywhere in the JAX ecosystem:

  • Differentiable: You can take gradients through the model parameters or inputs using jax.grad (see the sketch below).
  • JIT-able: The entire forward pass (__call__) is compatible with jax.jit.
  • Composable: You can use them inside your own custom training loops or neural network architectures.
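
For example, taking gradients of a loss with respect to a fitted model's parameters is a one-liner. This is a minimal sketch: eqx.filter_grad is plain Equinox (not an eqx-learn API), and it assumes predict is differentiable as described above.

import equinox as eqx
import jax.numpy as jnp

from eqxlearn.linear_model import LinearRegressor
from eqxlearn import fit

X = jnp.linspace(0, 1, 20)[:, None]
y = 2.0 * X.squeeze() + 1.0

model, _ = fit(LinearRegressor(), X, y)

def mse(m, X, y):
    # Scalar loss computed through the model's predictions.
    return jnp.mean((m.predict(X) - y) ** 2)

# grads is a PyTree with the same structure as the model itself.
grads = eqx.filter_grad(mse)(model, X, y)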

Explicit Protocols and Single-Sample Logic

Instead of a monolithic fit method, models implement specific protocols based on their mathematical nature:

  • solve(X, y): Returns exact analytical parameters (e.g., OLS, PCA).
  • condition(X, y): Updates belief state (e.g., GPs).
  • loss(): Defines a custom objective for gradient descent.

Furthermore, models implement single-sample logic via __call__(x); batching is handled automatically by the base class via jax.vmap, which simplifies implementation (see the sketch below).
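
To make this concrete, here is a sketch of a toy regressor following these protocols. The eqxlearn.base import path and the exact base-class hooks are assumptions for illustration, not documented API:

import equinox as eqx
import jax
import jax.numpy as jnp

from eqxlearn.base import Regressor  # hypothetical import path

class MeanRegressor(Regressor):
    """Toy model: always predicts the training-set mean."""
    mean_: jax.Array = jnp.array(0.0)

    def solve(self, X, y):
        # Exact "fit": return a *new* instance with the analytic parameter,
        # in keeping with the immutability described above.
        return eqx.tree_at(lambda m: m.mean_, self, jnp.mean(y))

    def __call__(self, x):
        # Single-sample logic only; the base class batches this via jax.vmap.
        return self.mean_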

Strict Dimensionality

eqx-learn avoids silent broadcasting. A Regressor strictly expects a target vector of shape (N,).

  • If your target is (N, 1) or (N, M), you must explicitly wrap your model in MultiOutputRegressor (see the sketch after this list).
  • Transformers (like StandardScaler) support polymorphic inversion, accepting (mean, variance) tuples to correctly propagate uncertainty through pipelines.
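
Wrapping a regressor for an (N, 2) target might look like the following sketch. The eqxlearn.multioutput import path is an assumption, and per-column fitting mirrors scikit-learn's convention rather than anything documented here:

import jax.numpy as jnp

from eqxlearn.linear_model import LinearRegressor
from eqxlearn.multioutput import MultiOutputRegressor  # hypothetical path
from eqxlearn import fit

X = jnp.linspace(0, 1, 20)[:, None]
Y = jnp.stack([2.0 * X.squeeze(), -X.squeeze()], axis=-1)  # (N, 2) targets

# A bare Regressor would reject (N, 2) targets; the wrapper handles each column.
model, _ = fit(MultiOutputRegressor(LinearRegressor()), X, Y)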
