A PyTorch-based library for kernel machines (SVM, DWD, kernel logistic, etc.)
TorchKM: Fast Kernel Machines in PyTorch
TorchKM is a GPU-accelerated, PyTorch-based library for kernel machines, including kernel SVM, with a focus on fast, integrated train-and-tune workflows.
It currently provides:
- Kernel classification: kernel SVM, kernel DWD, and kernel logistic regression
- Kernel regression: kernel quantile regression
- Fast model selection: pathwise solutions over a grid of regularization values (λ)
- Exact cross-validation for kernel machines
- GPU acceleration via PyTorch/CUDA, with safe CPU fallback
- A scikit-learn–style API for easy integration into existing Python workflows
Why TorchKM?
Kernel methods are competitive supervised learning methods on tabular data. In practice, the dominant cost often arises not from a single model fit alone, but from repeated kernel-matrix computations and linear solves across cross-validation folds and tuning parameters.
TorchKM is built for that setting: an integrated training and tuning pipeline. Benchmarks show competitive predictive performance together with substantial speedups over standard baselines.
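To make the cost argument concrete, here is a minimal NumPy sketch. It is not TorchKM's solver: it uses kernel ridge regression as a stand-in, because that model has a well-known closed-form leave-one-out identity. The point is only that one kernel matrix and one eigendecomposition can serve every regularization value and every held-out point. The helper names `rbf_gram` and `loocv_path`, the bandwidth `gamma`, and the `lambdas` grid are illustrative assumptions.

```python
import numpy as np


def rbf_gram(X, gamma):
    """Dense RBF Gram matrix K[i, j] = exp(-gamma * ||x_i - x_j||^2)."""
    sq = np.sum(X**2, axis=1)
    d2 = np.maximum(sq[:, None] + sq[None, :] - 2.0 * X @ X.T, 0.0)
    return np.exp(-gamma * d2)


def loocv_path(K, y, lambdas):
    """Exact leave-one-out MSE of kernel ridge regression for each lambda.

    One eigendecomposition of K is shared by the whole grid, so each lambda
    costs matrix-vector work instead of a fresh linear solve per fold.
    """
    evals, U = np.linalg.eigh(K)            # K = U diag(evals) U^T, done once
    Uty = U.T @ y
    errors = []
    for lam in lambdas:
        shrink = evals / (evals + lam)      # eigenvalues of H = K (K + lam I)^-1
        fitted = U @ (shrink * Uty)         # H y
        h_diag = (U**2) @ shrink            # diag(H)
        loo_resid = (y - fitted) / (1.0 - h_diag)
        errors.append(float(np.mean(loo_resid**2)))
    return np.array(errors)


rng = np.random.default_rng(0)
X = rng.normal(size=(60, 3))
y = np.sin(X[:, 0]) + 0.1 * rng.normal(size=60)
lambdas = np.logspace(1, -3, 5)

K = rbf_gram(X, gamma=0.5)                  # computed once, reused below
cv_errors = loocv_path(K, y, lambdas)
best_lambda = lambdas[int(np.argmin(cv_errors))]
print("LOOCV errors:", cv_errors)
print("best lambda:", best_lambda)
```

A naive loop would refit the model once per held-out point and per grid value. For margin-based losses such as the SVM hinge loss this ridge identity does not apply directly, so treat the sketch as an illustration of the reuse pattern rather than of TorchKM's algorithm.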
Documentation
Full documentation, examples, API reference, benchmark-reproduction notes, and developer guides are available at:
https://yikaizhang95.github.io/torchkm/
Installation
Standard install
pip install torchkm
The default installation includes the high-level scikit-learn-style estimator API used in the examples.
Development install
git clone https://github.com/YikaiZhang95/torchkm.git
cd torchkm
pip install -e ".[dev,examples,viz]"
python -m pytest -q
Quickstart
sklearn-style estimators
import numpy as np
import torch
from sklearn.datasets import make_circles
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torchkm.estimators import TorchKMSVC
# Toy nonlinear classification task
X, y = make_circles(n_samples=120, factor=0.4, noise=0.08, random_state=0)
X = StandardScaler().fit_transform(X)
y = np.where(y == 0, -1, 1)
Xtr, Xte, ytr, yte = train_test_split(X, y, test_size=0.2, random_state=0)
Cs = np.logspace(2, -2, num=4)
device = "cuda" if torch.cuda.is_available() else "cpu"
clf = TorchKMSVC(
    kernel="rbf",
    Cs=Cs,
    cv=5,
    device=device,
    probability=True,
    max_iter=40,
)
clf.fit(Xtr, ytr)
print("best C:", clf.best_C_)
print("test accuracy:", (clf.predict(Xte) == yte).mean())
print("first 3 probabilities:\n", clf.predict_proba(Xte[:3]))
Set probability=True to enable Platt scaling and predict_proba.
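For readers unfamiliar with Platt scaling, the following NumPy sketch shows the general idea: a one-dimensional logistic model is fitted on top of real-valued decision scores so that they can be read as probabilities. This is a plain, unregularized variant written for illustration and makes no claim about how TorchKM implements it. The function names `fit_platt` and `platt_proba` and the synthetic scores are assumptions.

```python
import numpy as np


def fit_platt(scores, y, n_iter=25):
    """Fit P(y=+1 | s) = sigmoid(a * s + b) by Newton's method on the log-loss.

    `scores` are real-valued decision values and `y` holds labels in {-1, +1}.
    """
    t = (y == 1).astype(float)                      # targets in {0, 1}
    Z = np.column_stack([scores, np.ones_like(scores)])
    a, b = 0.0, 0.0
    for _ in range(n_iter):
        p = 1.0 / (1.0 + np.exp(-(a * scores + b)))
        grad = Z.T @ (p - t)
        w = p * (1.0 - p)
        hess = Z.T @ (Z * w[:, None]) + 1e-8 * np.eye(2)  # tiny ridge for stability
        step = np.linalg.solve(hess, grad)
        a, b = a - step[0], b - step[1]
    return a, b


def platt_proba(scores, a, b):
    """Map decision scores to calibrated probabilities of the +1 class."""
    return 1.0 / (1.0 + np.exp(-(a * scores + b)))


# Synthetic, overlapping decision scores standing in for decision_function output
rng = np.random.default_rng(0)
y = np.where(rng.random(400) < 0.5, -1, 1)
scores = 1.0 * y + rng.normal(size=400)

a, b = fit_platt(scores, y)
probs = platt_proba(np.array([-2.0, 0.0, 2.0]), a, b)
print("slope, intercept:", a, b)
print("P(y=+1) at scores -2, 0, 2:", probs)
```

In practice the calibration map is usually fitted on held-out scores rather than on the training scores, to avoid overconfident probabilities.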
Low-rank approximation
Use low_rank=True when you want to handle a larger data set. The recommended
scikit-learn-style API sets this in the constructor:
clf = TorchKMSVC(
    kernel="rbf",
    low_rank=True,
    num_landmarks=40,
    nys_k=20,
    nC=4,
    cv=5,
    device=device,
    probability=True,
    max_iter=40,
)
clf.fit(Xtr, ytr)
print("test accuracy:", (clf.predict(Xte) == yte).mean())
For convenience, low-rank Nyström fitting can also be enabled at fit time, as in clf.fit(X, y, low_rank=True):
clf = TorchKMSVC(kernel="rbf", Cs=Cs, cv=5, device=device, probability=True)
clf.fit(Xtr, ytr, low_rank=True, num_landmarks=40, nys_k=20)
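As background on what a Nyström approximation does, here is a generic NumPy sketch of the textbook construction. It is not TorchKM's internal code, and the mapping of its arguments to `num_landmarks` and `nys_k` is only an assumed analogy. A small set of landmark points yields a feature matrix `Z` whose product `Z @ Z.T` approximates the full Gram matrix, so downstream solves work with an n-by-rank matrix instead of an n-by-n one. The names `rbf` and `nystrom_features` and the bandwidth `gamma` are illustrative.

```python
import numpy as np


def rbf(A, B, gamma):
    """RBF kernel block between the rows of A and the rows of B."""
    d2 = np.sum(A**2, axis=1)[:, None] + np.sum(B**2, axis=1)[None, :] - 2.0 * A @ B.T
    return np.exp(-gamma * np.maximum(d2, 0.0))


def nystrom_features(X, landmarks, gamma, rank):
    """Return Z such that Z @ Z.T is a rank-`rank` Nystrom approximation of K(X, X)."""
    K_mm = rbf(landmarks, landmarks, gamma)         # m x m landmark kernel
    K_nm = rbf(X, landmarks, gamma)                 # n x m cross kernel
    evals, evecs = np.linalg.eigh(K_mm)
    top = np.argsort(evals)[::-1][:rank]            # keep the leading eigenpairs
    evals = np.maximum(evals[top], 1e-12)           # guard against tiny eigenvalues
    return K_nm @ evecs[:, top] / np.sqrt(evals)


rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
landmarks = X[rng.choice(len(X), size=40, replace=False)]

Z = nystrom_features(X, landmarks, gamma=0.2, rank=20)
K_full = rbf(X, X, gamma=0.2)
rel_err = np.linalg.norm(K_full - Z @ Z.T) / np.linalg.norm(K_full)
print("feature matrix shape:", Z.shape)
print("relative Frobenius error:", rel_err)
```

More landmarks and a higher rank generally reduce the approximation error at the price of more memory and compute, which is the trade-off the two parameters expose.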
Kernel quantile regression
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torchkm.estimators import TorchKMKQR
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
y = np.sin(X[:, 0]) + 0.2 * rng.normal(size=200)
Xtr, Xte, ytr, yte = train_test_split(X, y, test_size=0.25, random_state=0)
Cs = np.logspace(2, -2, 4)
device = "cuda" if torch.cuda.is_available() else "cpu"
qr = TorchKMKQR(
    kernel="rbf",
    Cs=Cs,
    nC=len(Cs),
    cv=3,
    tau=0.5,
    device=device,
    max_iter=40,
)
qr.fit(Xtr, ytr)
print("best C:", qr.best_C_)
print("predictions:", qr.predict(Xte[:3]))
qr_nys = TorchKMKQR(
    kernel="rbf",
    Cs=Cs,
    nC=len(Cs),
    cv=3,
    tau=0.5,
    low_rank=True,
    num_landmarks=40,
    nys_k=20,
    device=device,
    max_iter=40,
)
qr_nys.fit(Xtr, ytr)
print("Nyström predictions:", qr_nys.predict(Xte[:3]))
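The role of `tau` is easiest to see through the pinball (check) loss that quantile regression minimizes in general. The sketch below is plain NumPy and independent of TorchKM; the function name `pinball_loss` and the synthetic sample are assumptions. It checks numerically that the constant minimizing the pinball loss is the empirical `tau`-quantile, so `tau=0.5` targets the median.

```python
import numpy as np


def pinball_loss(y_true, y_pred, tau):
    """Mean check loss: tau * r if r >= 0, (tau - 1) * r otherwise, with r = y - f."""
    r = np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float)
    return float(np.mean(np.maximum(tau * r, (tau - 1.0) * r)))


rng = np.random.default_rng(0)
sample = rng.normal(size=2001)
tau = 0.9

# Search over constant predictions on a grid with spacing 0.01
grid = np.linspace(-3.0, 3.0, 601)
losses = [pinball_loss(sample, c, tau) for c in grid]
best_constant = grid[int(np.argmin(losses))]

print("grid minimizer of the pinball loss:", best_constant)
print("empirical 0.9-quantile:", np.quantile(sample, tau))
```

Under-predictions are penalized with weight `tau` and over-predictions with weight `1 - tau`, which is why values of `tau` above 0.5 pull the fitted function upward.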
What TorchKM provides
sklearn-style estimators
- TorchKMSVC: kernel SVM classifier
- TorchKMDWD: kernel DWD classifier
- TorchKMLogit: kernel logistic regression
- TorchKMKQR: kernel quantile regression
TorchKMKQR provides kernel quantile regression. Use
TorchKMKQR(low_rank=True) for the Nyström approximation. There is no separate
TorchKMNysKQR class.
Common methods:
- fit(X, y)
- decision_function(X)
- predict(X)
- predict_proba(X) (when probability=True)
- platt_plot(X, y)
Utilities
rbf_kernel, kernelMult, sigest
Device behavior
For portability, prefer:
device = "cuda" if torch.cuda.is_available() else "cpu"
TorchKM is designed to run on GPU when CUDA is available and fall back safely to CPU otherwise.
When to use TorchKM
TorchKM is a good fit when you want:
- nonlinear kernel classifiers without leaving the PyTorch ecosystem
- fast model selection across many regularization values
- exact LOOCV support for kernel SVM
- a wrapper API that feels familiar if you already use scikit-learn
- lower-level access to kernel matrices and solver internals
Testing and coverage
python -m pytest -q
Run coverage locally with:
python -m pytest -q --cov=torchkm --cov-report=term-missing:skip-covered --cov-report=xml --cov-report=html
The GitHub Actions test workflow runs the CPU-safe test suite across Python 3.10, 3.11, and 3.12. On Python 3.11, CI runs the suite with coverage and enforces a minimum coverage threshold of 90%.
The most recent passing CUDA validation currently reports 214 tests passed and 98% coverage on commit 29fbdf4 using an NVIDIA L40S.
CUDA smoke tests are marked with pytest.mark.cuda and skip automatically when
CUDA is unavailable. On a CUDA machine, run:
python -m pytest -q -m cuda
Contributing
Issues, bug reports, benchmarks, documentation improvements, and pull requests are welcome. See CONTRIBUTING.md for development setup, testing, pull request guidelines, and bug-report instructions.
Good first contributions include:
- additional kernels
- multiclass wrappers
- more benchmark scripts
- expanded examples and tutorials
Citation
If you use TorchKM in academic work, please cite:
@article{zhang2026torchkm,
title = {TorchKM: GPU-Accelerated Kernel Machines with Fast Model Selection in PyTorch},
author = {Zhang, Yikai and Jia, Gaoxiang and Ding, Jie and Wang, Boxiang},
year = {2026},
note = {Software paper submission}
}
License
MIT License. See LICENSE.
Contact
Yikai Zhang
skyezhang1995@gmail.com