# TorchKM: Fast Kernel Machines in PyTorch
TorchKM is a PyTorch-based library for kernel machines with a focus on fast train + tune workflows.
It currently provides:
- Kernel classification: kernel SVM, kernel DWD, and kernel logistic regression
- Fast model selection: pathwise solutions over a grid of regularization values (λ)
- Exact LOOCV support for kernel SVM
- GPU acceleration via PyTorch/CUDA, with safe CPU fallback
- A scikit-learn–style API for easy integration into existing Python workflows
## Why TorchKM?
Kernel methods are still a strong choice when you want nonlinear decision boundaries, convex training objectives, and competitive performance on tabular or moderate-scale datasets. In practice, the bottleneck is often not training one model — it is training and tuning many models.
TorchKM is built for that workflow.
## Installation

### Minimal install

```bash
pip install torchkm
```

### With the scikit-learn wrapper

```bash
pip install "torchkm[sklearn]"
```

### Development install

```bash
git clone https://github.com/YikaiZhang95/torchkm.git
cd torchkm
pip install -e ".[dev,sklearn]"
pytest -q
```
## Quickstart

### sklearn-style wrapper (recommended)

```python
import numpy as np
import torch
from sklearn.datasets import make_circles
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

from torchkm.sklearn_wrapper import TorchKMSVC

device = "cuda" if torch.cuda.is_available() else "cpu"

# Toy nonlinear classification task
X, y = make_circles(n_samples=1200, factor=0.4, noise=0.08, random_state=0)
X = StandardScaler().fit_transform(X)
y = np.where(y == 0, -1, 1)  # TorchKM accepts {-1, +1} labels

Xtr, Xte, ytr, yte = train_test_split(X, y, test_size=0.2, random_state=0)

nfolds = 5
foldid = (np.arange(Xtr.shape[0]) % nfolds) + 1
ulam = np.logspace(3, -3, num=12)

clf = TorchKMSVC(
    kernel="rbf",
    nlam=len(ulam),
    ulam=ulam,
    nfolds=nfolds,
    foldid=foldid,
    device=device,
    probability=True,
    maxit=200,
)
clf.fit(Xtr, ytr)

print("best lambda:", clf.best_lambda_)
print("test accuracy:", (clf.predict(Xte) == yte).mean())
print("first 3 probabilities:\n", clf.predict_proba(Xte[:3]))
```
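The `foldid` above assigns rows to folds cyclically and uses 1-based fold labels. If you prefer a randomized assignment with the same balanced fold sizes, a small standalone helper (hypothetical, not part of TorchKM) could look like:

```python
import numpy as np

def make_foldid(n, nfolds, seed=0):
    """Shuffled 1-based fold labels with fold sizes as equal as possible."""
    rng = np.random.default_rng(seed)
    foldid = (np.arange(n) % nfolds) + 1  # balanced labels 1..nfolds
    rng.shuffle(foldid)                   # randomize which row lands in which fold
    return foldid

foldid = make_foldid(960, 5)  # e.g. the size of Xtr above
```

Shuffling matters when the rows of `X` are ordered (say, by class or by time), since a cyclic assignment can then produce systematically unrepresentative folds.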
### Low-level solver API

Use the low-level API when you want explicit control of the kernel matrix or already have a precomputed kernel.

```python
import torch

from torchkm.cvksvm import cvksvm
from torchkm.functions import sigest, rbf_kernel

device = "cuda" if torch.cuda.is_available() else "cpu"

X_train = torch.randn(200, 10, dtype=torch.float64)
y_train = torch.where(torch.rand(200) > 0.5, 1.0, -1.0).to(torch.float64)

sigma = float(sigest(X_train.cpu(), frac=0.5))
K_train = rbf_kernel(X_train.to(device), sigma)

ulam = torch.logspace(3, -3, steps=10, dtype=torch.float64, device=device)
foldid = (torch.arange(X_train.shape[0], device=device) % 5 + 1).to(torch.int64)

model = cvksvm(
    Kmat=K_train,
    y=y_train.to(device),
    nlam=len(ulam),
    ulam=ulam,
    nfolds=5,
    foldid=foldid,
    device=device,
    maxit=200,
)
model.fit()
```
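If you would rather build the Gram matrix yourself instead of calling `rbf_kernel`, a minimal plain-PyTorch sketch follows. Note the caveat: it uses the common exp(-‖x − z‖² / (2σ²)) parameterization, which may differ from the σ convention that `rbf_kernel`/`sigest` use, so check TorchKM's documentation before mixing the two.

```python
import torch

def rbf_gram(X, Z, sigma):
    """K[i, j] = exp(-||X[i] - Z[j]||^2 / (2 * sigma^2))."""
    sq_dists = torch.cdist(X, Z) ** 2  # pairwise squared Euclidean distances
    return torch.exp(-sq_dists / (2.0 * sigma ** 2))

X = torch.randn(200, 10, dtype=torch.float64)
K = rbf_gram(X, X, sigma=1.5)  # symmetric Gram matrix with unit diagonal
```

The same function also yields the train/test cross-kernel needed at prediction time: `rbf_gram(X_test, X_train, sigma)`.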
## What TorchKM provides

### High-level wrappers

- `TorchKMSVC` — kernel SVM classifier
- `TorchKMDWD` — kernel DWD classifier
- `TorchKMLogit` — kernel logistic regression

Common methods:

- `fit(X, y)`
- `decision_function(X)`
- `predict(X)`
- `predict_proba(X)` (when `probability=True`)

### Low-level solvers

- `cvksvm` — kernel SVM with cross-validation over λ
- `cvkdwd` — kernel DWD with cross-validation over λ
- `cvklogit` — kernel logistic regression with cross-validation over λ

### Utilities

- `rbf_kernel`, `kernelMult`, `sigest`
- `PlattScalerTorch` for probability calibration
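Platt scaling maps raw decision values f(x) to probabilities via sigmoid(a·f(x) + b), with a and b fit by logistic loss on held-out scores. Below is a minimal sketch of that idea in plain PyTorch — it illustrates the technique only and is not the `PlattScalerTorch` API:

```python
import torch

def platt_fit(scores, labels01, steps=500, lr=0.05):
    """Fit P(y=1 | f) = sigmoid(a * f + b) by minimizing logistic loss."""
    a = torch.tensor(1.0, requires_grad=True)
    b = torch.tensor(0.0, requires_grad=True)
    opt = torch.optim.Adam([a, b], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        loss = torch.nn.functional.binary_cross_entropy_with_logits(
            a * scores + b, labels01
        )
        loss.backward()
        opt.step()
    return a.detach(), b.detach()

# Toy decision values: positives score high, negatives score low
scores = torch.tensor([-2.0, -1.5, -0.5, 0.5, 1.5, 2.0])
labels01 = torch.tensor([0.0, 0.0, 0.0, 1.0, 1.0, 1.0])
a, b = platt_fit(scores, labels01)
probs = torch.sigmoid(a * scores + b)  # calibrated probabilities
```

In practice the sigmoid should be fit on decision values from data not used to train the classifier (e.g. cross-validated scores), otherwise the calibration is optimistic.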
## Device behavior

For portability, prefer:

```python
device = "cuda" if torch.cuda.is_available() else "cpu"
```

TorchKM is designed to run on GPU when CUDA is available and to fall back safely to CPU otherwise.
## When to use TorchKM
TorchKM is a good fit when you want:
- nonlinear kernel classifiers without leaving the PyTorch ecosystem
- fast model selection across many regularization values
- exact LOOCV support for kernel SVM
- a wrapper API that feels familiar if you already use scikit-learn
- lower-level access to kernel matrices and solver internals
## Testing

```bash
pytest -q
```
## Contributing
Issues, bug reports, benchmarks, documentation improvements, and pull requests are welcome.
Good first contributions include:
- additional kernels
- multiclass wrappers
- more benchmark scripts
- expanded examples and tutorials
## Citation

If you use TorchKM in academic work, please cite:

```bibtex
@article{zhang2026torchkm,
  title   = {TorchKM: GPU-Accelerated Kernel Machines with Fast Model Selection in PyTorch},
  author  = {Zhang, Yikai and Jia, Gaoxiang and Wang, Boxiang},
  journal = {Journal of Machine Learning Research (MLOSS track)},
  year    = {2026},
  note    = {Software paper submission}
}
```
## License
MIT License. See LICENSE.
## Contact
Yikai Zhang
yikai-zhang@uiowa.edu
## File details

### torchkm-4.1.0.tar.gz (source distribution)

- Download URL: torchkm-4.1.0.tar.gz
- Upload date:
- Size: 52.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.7

| Algorithm | Hash digest |
|---|---|
| SHA256 | `6ad00079f279c527ecdcc37063851fe8c53e762238431ff4507082dd0fe9b51e` |
| MD5 | `bb396a34435c2acfe3ecde14846a4df1` |
| BLAKE2b-256 | `f796be5b53ccf11c00f1b38725ac2d0153b732fde00703412d0a05841ff425af` |
### torchkm-4.1.0-py3-none-any.whl (built distribution)

- Download URL: torchkm-4.1.0-py3-none-any.whl
- Upload date:
- Size: 72.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.7

| Algorithm | Hash digest |
|---|---|
| SHA256 | `b3e3a5597d4ef33b4027f08e4625f5e2ae83e0d3af3529971d20d6b0d2ec6ecb` |
| MD5 | `22604e917d6c48d5b1c4e995614bb387` |
| BLAKE2b-256 | `bf60cd1ade0ffb16754c6608f76169359be5121dd42b45711f27a441879d2836` |