A PyTorch-based library for kernel machines (SVM, DWD, kernel logistic, etc.)

TorchKM: Fast Kernel Machines in PyTorch

TorchKM is a GPU-accelerated, PyTorch-based library for kernel machines, including kernel SVM, with a focus on fast, integrated train-and-tune workflows.

It currently provides:

  • Kernel classification: kernel SVM, kernel DWD, and kernel logistic regression
  • Kernel regression: kernel quantile regression
  • Fast model selection: pathwise solutions over a grid of regularization values (λ)
  • Exact cross-validation for kernel machines
  • GPU acceleration via PyTorch/CUDA, with safe CPU fallback
  • A scikit-learn–style API for easy integration into existing Python workflows

Why TorchKM?

Kernel methods are competitive supervised learning methods on tabular data. In practice, the dominant cost often arises not from a single model fit alone, but from repeated kernel-matrix computations and linear solves across cross-validation folds and tuning parameters.

TorchKM is built for that setting: an integrated training and tuning pipeline. The project's benchmarks show competitive predictive performance together with substantial speedups over standard baselines.
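To make the cost argument concrete, here is a minimal NumPy sketch of sharing work across a regularization grid. It uses kernel ridge regression as a stand-in, because its exact leave-one-out residuals have a closed form; it is not TorchKM's solver, and the helper names `rbf_gram` and `loo_errors_along_path` are hypothetical.

```python
import numpy as np

def rbf_gram(X, gamma):
    """Gram matrix K[i, j] = exp(-gamma * ||x_i - x_j||^2)."""
    sq = np.sum(X**2, axis=1)
    d2 = np.maximum(sq[:, None] + sq[None, :] - 2.0 * X @ X.T, 0.0)
    return np.exp(-gamma * d2)

def loo_errors_along_path(K, y, lambdas):
    """Exact leave-one-out MSE of kernel ridge for each lambda.

    A single eigendecomposition of K is reused for every lambda, so
    there is no per-fold or per-lambda refit. (Illustrative stand-in,
    not TorchKM's algorithm.)
    """
    s, U = np.linalg.eigh(K)               # K = U diag(s) U^T
    Uty = U.T @ y
    errors = []
    for lam in lambdas:
        shrink = s / (s + lam)             # eigenvalues of the hat matrix H
        fitted = U @ (shrink * Uty)        # H y
        h_diag = (U**2) @ shrink           # diag(H)
        loo_resid = (y - fitted) / (1.0 - h_diag)
        errors.append(float(np.mean(loo_resid**2)))
    return np.array(errors)

rng = np.random.default_rng(0)
X = rng.normal(size=(60, 3))
y = np.sin(X[:, 0]) + 0.1 * rng.normal(size=60)

K = rbf_gram(X, gamma=0.5)                 # computed once
lambdas = np.logspace(1, -2, 4)
loo_mse = loo_errors_along_path(K, y, lambdas)
best_lambda = lambdas[int(np.argmin(loo_mse))]
print("LOO MSE per lambda:", loo_mse)
print("best lambda:", best_lambda)
```

The kernel matrix and its factorization are paid for once; each extra value of λ then costs only a few matrix-vector products.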

Documentation

Full documentation, examples, API reference, benchmark-reproduction notes, and developer guides are available at:

https://yikaizhang95.github.io/torchkm/

Installation

Standard install

pip install torchkm

The default installation includes the high-level scikit-learn-style estimator API used in the examples.

Development install

git clone https://github.com/YikaiZhang95/torchkm.git
cd torchkm
pip install -e ".[dev,examples,viz]"
python -m pytest -q

Quickstart

sklearn-style estimators

import numpy as np
import torch
from sklearn.datasets import make_circles
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

from torchkm.estimators import TorchKMSVC

# Toy nonlinear classification task
X, y = make_circles(n_samples=120, factor=0.4, noise=0.08, random_state=0)
X = StandardScaler().fit_transform(X)
y = np.where(y == 0, -1, 1)

Xtr, Xte, ytr, yte = train_test_split(X, y, test_size=0.2, random_state=0)

Cs = np.logspace(2, -2, num=4)
device = "cuda" if torch.cuda.is_available() else "cpu"

clf = TorchKMSVC(
    kernel="rbf",
    Cs=Cs,
    cv=5,
    device=device,
    probability=True,
    max_iter=40,
)

clf.fit(Xtr, ytr)

print("best C:", clf.best_C_)
print("test accuracy:", (clf.predict(Xte) == yte).mean())
print("first 3 probabilities:\n", clf.predict_proba(Xte[:3]))

Set probability=True to enable Platt scaling and predict_proba.
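For readers unfamiliar with Platt scaling, the idea is to fit a one-dimensional sigmoid that maps decision scores to probabilities. The sketch below is a simplified NumPy illustration using plain gradient descent on the log-loss. It is not TorchKM's implementation, and `fit_platt` and `platt_proba` are hypothetical helper names.

```python
import numpy as np

def fit_platt(scores, labels, n_iter=500, lr=0.5):
    """Fit p = sigmoid(a * score + b) by gradient descent on the log-loss.

    `labels` are in {-1, +1}. In practice the scores should come from
    held-out data so the calibration is not fit on training scores.
    """
    t = (labels + 1) / 2.0                 # map {-1, +1} to {0, 1}
    a, b = 1.0, 0.0
    for _ in range(n_iter):
        p = 1.0 / (1.0 + np.exp(-(a * scores + b)))
        a -= lr * np.mean((p - t) * scores)
        b -= lr * np.mean(p - t)
    return a, b

def platt_proba(scores, a, b):
    """Probability of the positive class for each decision score."""
    return 1.0 / (1.0 + np.exp(-(a * scores + b)))

# Synthetic decision scores: positive class tends to score higher.
rng = np.random.default_rng(0)
labels = rng.choice([-1, 1], size=200)
scores = 1.5 * labels + rng.normal(size=200)

a, b = fit_platt(scores, labels)
proba = platt_proba(scores, a, b)
print("slope a:", a, "intercept b:", b)
print("first 3 probabilities:", proba[:3])
```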

Low-rank approximation

Use low_rank=True to fit larger data sets with a low-rank Nyström approximation of the kernel matrix. The recommended scikit-learn-style API sets this in the constructor:

clf = TorchKMSVC(
    kernel="rbf",
    low_rank=True,
    num_landmarks=40,
    nys_k=20,
    nC=4,
    cv=5,
    device=device,
    probability=True,
    max_iter=40,
)

clf.fit(Xtr, ytr)
print("test accuracy:", (clf.predict(Xte) == yte).mean())

For convenience, low-rank Nyström fitting can also be enabled at fit time: clf.fit(X, y, low_rank=True).

clf = TorchKMSVC(kernel="rbf", Cs=Cs, cv=5, device=device, probability=True)
clf.fit(Xtr, ytr, low_rank=True, num_landmarks=40, nys_k=20)
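As background, the following NumPy sketch shows the textbook Nyström construction: sample m landmark points, eigendecompose the small m × m kernel block, keep the top k eigenpairs, and build n × k features whose inner products approximate the full kernel matrix. Reading `num_landmarks` as m and `nys_k` as k is an assumption made for illustration. The helper names are hypothetical and this is not TorchKM's internal code.

```python
import numpy as np

def rbf_cross(A, B, gamma):
    """Cross-kernel matrix with entries exp(-gamma * ||a_i - b_j||^2)."""
    d2 = (
        np.sum(A**2, axis=1)[:, None]
        + np.sum(B**2, axis=1)[None, :]
        - 2.0 * A @ B.T
    )
    return np.exp(-gamma * np.maximum(d2, 0.0))

def nystrom_features(X, num_landmarks, rank, gamma, seed=0):
    """Return Phi (n x rank) such that Phi @ Phi.T approximates K."""
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(X), size=num_landmarks, replace=False)
    Z = X[idx]                                  # landmark points
    K_mm = rbf_cross(Z, Z, gamma)               # small m x m block
    K_nm = rbf_cross(X, Z, gamma)               # n x m cross block
    s, U = np.linalg.eigh(K_mm)
    top = np.argsort(s)[::-1][:rank]            # indices of top-k eigenpairs
    s_k = np.maximum(s[top], 1e-12)             # guard against round-off
    return K_nm @ U[:, top] / np.sqrt(s_k)

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
gamma = 0.5

Phi = nystrom_features(X, num_landmarks=40, rank=20, gamma=gamma)
K_full = rbf_cross(X, X, gamma)                 # formed here only to check the error
rel_err = np.linalg.norm(K_full - Phi @ Phi.T) / np.linalg.norm(K_full)
print("feature shape:", Phi.shape)
print("relative Frobenius error:", rel_err)
```

The payoff is that downstream solves work with an n × k matrix instead of the n × n kernel matrix, which is what makes larger data sets tractable.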

Kernel quantile regression

import numpy as np
import torch
from sklearn.model_selection import train_test_split

from torchkm.estimators import TorchKMKQR

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
y = np.sin(X[:, 0]) + 0.2 * rng.normal(size=200)

Xtr, Xte, ytr, yte = train_test_split(X, y, test_size=0.25, random_state=0)

Cs = np.logspace(2, -2, 4)
device = "cuda" if torch.cuda.is_available() else "cpu"

qr = TorchKMKQR(
    kernel="rbf",
    Cs=Cs,
    nC=len(Cs),
    cv=3,
    tau=0.5,
    device=device,
    max_iter=40,
)
qr.fit(Xtr, ytr)
print("best C:", qr.best_C_)
print("predictions:", qr.predict(Xte[:3]))

qr_nys = TorchKMKQR(
    kernel="rbf",
    Cs=Cs,
    nC=len(Cs),
    cv=3,
    tau=0.5,
    low_rank=True,
    num_landmarks=40,
    nys_k=20,
    device=device,
    max_iter=40,
)
qr_nys.fit(Xtr, ytr)
print("Nyström predictions:", qr_nys.predict(Xte[:3]))
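The `tau` argument selects which conditional quantile is estimated. Quantile regression is usually defined through the pinball (check) loss, sketched below in NumPy. This is the standard definition, not TorchKM's code, and `pinball_loss` is a hypothetical helper name.

```python
import numpy as np

def pinball_loss(y_true, y_pred, tau):
    """Mean pinball loss: tau * r if r >= 0, else (tau - 1) * r, with r = y - f."""
    r = np.asarray(y_true) - np.asarray(y_pred)
    return float(np.mean(np.maximum(tau * r, (tau - 1.0) * r)))

rng = np.random.default_rng(0)
y = rng.normal(size=500)

# Among constant predictions, the pinball loss is minimized near the
# empirical tau-quantile of y.
tau = 0.9
candidates = np.linspace(-3.0, 3.0, 601)
losses = [pinball_loss(y, np.full_like(y, c), tau) for c in candidates]
best_constant = candidates[int(np.argmin(losses))]

print("best constant for tau=0.9:", best_constant)
print("empirical 0.9-quantile:   ", np.quantile(y, tau))
```

With tau=0.5, as in the example above, the loss reduces to half the mean absolute error, so the fit targets the conditional median.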

What TorchKM provides

sklearn-style estimators

  • TorchKMSVC — kernel SVM classifier
  • TorchKMDWD — kernel DWD classifier
  • TorchKMLogit — kernel logistic regression
  • TorchKMKQR — kernel quantile regression

For the Nyström approximation, use TorchKMKQR(low_rank=True). There is no separate TorchKMNysKQR class.

Common methods:

  • fit(X, y)
  • decision_function(X)
  • predict(X)
  • predict_proba(X) (when probability=True)
  • platt_plot(X, y)

Utilities

  • rbf_kernel, kernelMult, sigest

Device behavior

For portability, prefer:

device = "cuda" if torch.cuda.is_available() else "cpu"

TorchKM is designed to run on GPU when CUDA is available and fall back safely to CPU otherwise.

When to use TorchKM

TorchKM is a good fit when you want:

  • nonlinear kernel classifiers without leaving the PyTorch ecosystem
  • fast model selection across many regularization values
  • exact LOOCV support for kernel SVM
  • a wrapper API that feels familiar if you already use scikit-learn
  • lower-level access to kernel matrices and solver internals

Testing and coverage

python -m pytest -q

Run coverage locally with:

python -m pytest -q --cov=torchkm --cov-report=term-missing:skip-covered --cov-report=xml --cov-report=html

The GitHub Actions test workflow runs the CPU-safe test suite across Python 3.10, 3.11, and 3.12. On Python 3.11, CI runs the suite with coverage and enforces a minimum coverage threshold of 90%.

The most recent passing CUDA validation currently reports 214 tests passed and 98% coverage on commit 29fbdf4 using an NVIDIA L40S.

CUDA smoke tests are marked with pytest.mark.cuda and skip automatically when CUDA is unavailable. On a CUDA machine, run:

python -m pytest -q -m cuda

Contributing

Issues, bug reports, benchmarks, documentation improvements, and pull requests are welcome. See CONTRIBUTING.md for development setup, testing, pull request guidelines, and bug-report instructions.

Good first contributions include:

  • additional kernels
  • multiclass wrappers
  • more benchmark scripts
  • expanded examples and tutorials

Citation

If you use TorchKM in academic work, please cite:

@article{zhang2026torchkm,
  title   = {TorchKM: GPU-Accelerated Kernel Machines with Fast Model Selection in PyTorch},
  author  = {Zhang, Yikai and Jia, Gaoxiang and Ding, Jie and Wang, Boxiang},
  year    = {2026},
  note    = {Software paper submission}
}

License

MIT License. See LICENSE.

Contact

Yikai Zhang
skyezhang1995@gmail.com
