Classical machine learning using JAX and Equinox

eqx-learn

A minimal, classical machine learning library built on JAX and Equinox.

Overview

eqx-learn implements classical machine learning algorithms (Linear Regression, PCA, Gaussian Processes, etc.) within the JAX/Equinox ecosystem. Its API is heavily inspired by scikit-learn, but adapted to JAX's functional programming paradigm. All models are Equinox modules and therefore JAX PyTrees, which makes them fully differentiable.

NB: This library is in the very early stages of development and currently focuses mainly on simple regression algorithms. However, the API has been carefully thought out, and pull requests for additional models, algorithms and utilities are more than welcome!

Installation

You can install eqx-learn directly from GitHub. First ensure you have a working JAX installation (CPU or GPU). Then, using pip:

pip install git+https://github.com/eqx-learn/eqx-learn.git

Example: Linear Regression

This example demonstrates a simple linear regression implementation using the analytic OLS solution.

import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt

from eqxlearn.linear_model import LinearRegressor
from eqxlearn import fit

# Create synthetic data
X = jnp.linspace(0, 10, 50)[:, None]
y = 3.5 * X.squeeze() + 2.0 + jr.normal(jr.key(0), (50,))

# Solve the OLS problem using fit(). In its default mode, this calls the analytic solution provided by
# LinearRegressor.solve(...). To use an iterative solution (e.g. for larger problems), simply pass
# fit(..., solution='iterative'). Note: LinearRegressor inherits from Regressor -> BaseModel -> eqx.Module
model = LinearRegressor()
model, _ = fit(model, X, y)

# Plot model prediction
# Keep the test inputs 2-D, matching the (N, 1) shape of the training inputs
X_test = jnp.linspace(0, 10, 100)[:, None]
y_test = model.predict(X_test)
plt.scatter(X, y, color='black')
plt.plot(X_test, y_test)
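
For readers curious what an analytic OLS solution computes, here is a minimal sketch in plain JAX, independent of eqx-learn. It assumes an ordinary least-squares fit with an intercept. The helper name ols_solve is hypothetical and is not part of the eqx-learn API, and LinearRegressor.solve may be implemented differently.

```python
import jax.numpy as jnp
import jax.random as jr


def ols_solve(X, y):
    """Hypothetical helper: least-squares weights and intercept for y ~ X @ w + b."""
    # Append a column of ones so the intercept is estimated alongside the weights.
    A = jnp.concatenate([X, jnp.ones((X.shape[0], 1))], axis=1)
    # lstsq minimises ||A @ coef - y||^2 without explicitly inverting A.T @ A.
    coef, _, _, _ = jnp.linalg.lstsq(A, y)
    return coef[:-1], coef[-1]


# Same synthetic data as in the example above.
X = jnp.linspace(0, 10, 50)[:, None]
y = 3.5 * X.squeeze() + 2.0 + jr.normal(jr.key(0), (50,))
weights, intercept = ols_solve(X, y)
```

With this data the recovered slope and intercept should land close to the generating values of 3.5 and 2.0, up to the noise.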

Example: Gaussian Process Regression with Scaled Data

This example demonstrates a full machine learning pipeline for Gaussian process regression. The code fits an RBF kernel to scaled input data, and then makes predictions on the original output scale.

import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt

from eqxlearn.preprocessing import StandardScaler
from eqxlearn.gaussian_process import GaussianProcessRegressor
from eqxlearn.gaussian_process.kernels import RBFKernel, WhiteNoiseKernel
from eqxlearn.pipeline import Pipeline
from eqxlearn.compose import TransformedTargetRegressor
from eqxlearn import fit

# 1. Generate scaled data
X_SCALE, Y_SCALE = 0.01, 100.0
X = X_SCALE * jnp.linspace(0, 10, 30)[:, None]
y = Y_SCALE * (jnp.sin(X / X_SCALE).squeeze() + 0.1 * jr.normal(jr.key(1), (30,)))

# 2. Create the model with scaled inputs and outputs
kernel = RBFKernel() + WhiteNoiseKernel(0.01)
pipeline = Pipeline([
    ("scaler_x", StandardScaler()),
    ("gp", GaussianProcessRegressor(kernel=kernel))
])
model = TransformedTargetRegressor(
    regressor=pipeline,
    transformer=StandardScaler()
)

# 3. Fit the model
# fit() inspects the requirements/capabilities of the model being passed.
# This includes conditioning on data via model.condition(), and exact solutions via model.solve().
# For wrapper models (e.g. Pipeline, TransformedTargetRegressor), these are forwarded appropriately.
# Then, fit() runs iterative optimization on the model (using e.g. the optax adam optimizer).
model, losses = fit(model, X, y)

# 4. Make predictions
X_test = X_SCALE * jnp.linspace(0, 10, 100)[:, None]
predictions, variance = model.predict(X_test, return_var=True)
std_dev = jnp.sqrt(variance)

# Plot
plt.figure(figsize=(10, 5))
plt.scatter(X, y, label="Training Data", color="black")
plt.plot(X_test, predictions, label="Model Prediction", color="blue")
plt.fill_between(
    X_test.squeeze(), 
    predictions - 1.96 * std_dev, 
    predictions + 1.96 * std_dev, 
    alpha=0.2, color="blue", label="95% CI"
)
plt.legend()
plt.title("GP with Feature & Target Scaling")
plt.show()
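
The conditioning step mentioned in the comments above can be illustrated with the textbook Gaussian process posterior. The following is a sketch in plain JAX under stated assumptions: a zero-mean prior, an RBF kernel with fixed hyperparameters, and white noise of variance 0.01. The function names rbf and gp_posterior are hypothetical, and this is not necessarily how GaussianProcessRegressor is implemented. In particular, fit() also optimizes the kernel hyperparameters, which this sketch leaves fixed.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def rbf(xa, xb, lengthscale=1.0, variance=1.0):
    """Squared-exponential kernel matrix between point sets of shape (N, D) and (M, D)."""
    sq_dist = jnp.sum((xa[:, None, :] - xb[None, :, :]) ** 2, axis=-1)
    return variance * jnp.exp(-0.5 * sq_dist / lengthscale**2)


def gp_posterior(X, y, X_test, noise=0.01):
    """Posterior mean and latent variance of a zero-mean GP (hypothetical helper)."""
    K = rbf(X, X) + noise * jnp.eye(X.shape[0])  # noisy training covariance
    K_s = rbf(X, X_test)                         # train/test cross-covariance
    chol = cho_factor(K, lower=True)             # Cholesky factor for stable solves
    mean = K_s.T @ cho_solve(chol, y)
    # Prior variance minus the part explained by the training data.
    var = jnp.diag(rbf(X_test, X_test)) - jnp.sum(K_s * cho_solve(chol, K_s), axis=0)
    return mean, var


X = jnp.linspace(0, 10, 30)[:, None]
y = jnp.sin(X).squeeze()
mean, var = gp_posterior(X, y, X)
```

Near the training points the posterior variance shrinks towards zero, and far away from them it returns to the prior variance.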

Key Differences from scikit-learn

Immutability & The fit Function

In eqx-learn, models are immutable PyTrees, since they derive from Equinox's Module. Unlike in scikit-learn, fit is not a method that updates attributes in place. It is a free function that returns a new instance of the model with updated parameters, so you must capture the return value:

model, history = fit(model, X, y)
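
Equinox modules are frozen dataclasses, so the standard library is enough to illustrate the pattern. The sketch below uses a hypothetical ToyModel and toy_fit, neither of which is part of eqx-learn, to show that fitting leaves the original object untouched and hands back an updated copy.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToyModel:
    """Hypothetical stand-in for an immutable model; not an eqx-learn class."""
    weight: float = 0.0


def toy_fit(model, xs, ys):
    """Return a NEW model holding the least-squares slope through the origin."""
    slope = sum(x * y for x, y in zip(xs, ys)) / sum(x * x for x in xs)
    # replace() builds a copy with the changed field; the input is not modified.
    return dataclasses.replace(model, weight=slope)


model = ToyModel()
fitted = toy_fit(model, [1.0, 2.0, 3.0], [2.0, 4.0, 6.0])
print(model.weight, fitted.weight)  # 0.0 2.0
```

Discarding the return value here would leave you with the unfitted model, which is the same mistake as calling fit(model, X, y) without capturing its result.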

Native Equinox Compatibility

Every estimator and transformer is a standard eqx.Module. This means you can use them anywhere in the JAX ecosystem:

  • Differentiable: You can take gradients through the model parameters or inputs using jax.grad.
  • JIT-compatible: The entire forward pass (__call__) is compatible with jax.jit.
  • Composable: You can use them inside your own custom training loops or neural network architectures.

Explicit Protocols and Single-Sample Logic

Instead of a monolithic fit method, models implement specific protocols based on their mathematical nature:

  • solve(X, y): Returns exact analytical parameters (e.g., OLS, PCA).
  • condition(X, y): Updates belief state (e.g., GPs).
  • loss(): Defines a custom objective for gradient descent.

Furthermore, models implement single-sample logic via __call__(x). Batching is handled automatically by the base class via jax.vmap, simplifying implementation.
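
The single-sample pattern can be sketched in plain JAX as follows. The function names predict_one and predict_batch are hypothetical. In eqx-learn the batching is described as living in the base class, whose exact implementation is not shown here.

```python
import jax
import jax.numpy as jnp


def predict_one(params, x):
    """Single-sample logic: x has shape (n_features,) and the result is a scalar."""
    weight, bias = params
    return jnp.dot(weight, x) + bias


def predict_batch(params, X):
    """Map the single-sample function over the leading axis of X."""
    # in_axes=(None, 0): share params across the batch, split X along axis 0.
    return jax.vmap(predict_one, in_axes=(None, 0))(params, X)


params = (jnp.array([2.0, -1.0]), 0.5)
X = jnp.array([[1.0, 1.0], [0.0, 2.0], [3.0, 0.0]])
preds = predict_batch(params, X)  # shape (3,), one prediction per row of X
```

Writing only the per-sample function keeps shape handling out of the model code, since the batch axis never appears inside predict_one.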

Strict Dimensionality

eqx-learn avoids silent broadcasting. A Regressor strictly expects a target vector of shape (N,).

  • If your target is (N, 1) or (N, M), you must explicitly wrap your model in MultiOutputRegressor.
  • Transformers (like StandardScaler) support polymorphic inversion, accepting (mean, variance) tuples to correctly propagate uncertainty through pipelines.
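
To see why inverting a (mean, variance) pair needs special handling, consider undoing a standardization y_scaled = (y - mu) / sigma. The mean is shifted and scaled, whereas the variance is only scaled, and by sigma squared. The sketch below shows that arithmetic in plain JAX. The function name inverse_standardize is hypothetical, and this is not the StandardScaler API.

```python
import jax.numpy as jnp


def inverse_standardize(mean_scaled, var_scaled, mu, sigma):
    """Map a predictive mean and variance back to the original target scale."""
    mean = mean_scaled * sigma + mu  # the mean is scaled and shifted
    var = var_scaled * sigma**2      # the variance is scaled by sigma^2, with no shift
    return mean, var


mean, var = inverse_standardize(
    jnp.array([0.0, 1.0]), jnp.array([0.04, 0.25]), mu=10.0, sigma=100.0
)
```

Applying the mean transformation to a variance would give the wrong answer, which is why a transformer needs to know which kind of quantity it is inverting.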
