
A PyTorch-based library for kernel machines (SVM, DWD, kernel logistic, etc.)


TorchKM: Fast Kernel Machines in PyTorch


TorchKM is a PyTorch-based library for kernel machines with a focus on fast train + tune workflows.

It currently provides:

  • Kernel classification: kernel SVM, kernel DWD, and kernel logistic regression
  • Fast model selection: pathwise solutions over a grid of regularization values (λ)
  • Exact LOOCV support for kernel SVM
  • GPU acceleration via PyTorch/CUDA, with safe CPU fallback
  • A scikit-learn–style API for easy integration into existing Python workflows

Why TorchKM?

Kernel methods are still a strong choice when you want nonlinear decision boundaries, convex training objectives, and competitive performance on tabular or moderate-scale datasets. In practice, the bottleneck is often not training one model — it is training and tuning many models.

TorchKM is built for that workflow.
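One reason pathwise solutions over a λ grid are cheap is that much of the work can be shared across the grid. As an illustration (using kernel ridge regression as a stand-in, not TorchKM's actual solvers), a single eigendecomposition of the kernel matrix yields the solution for every λ at once:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy 1-D regression data and an RBF kernel matrix.
X = rng.uniform(-3, 3, size=(200, 1))
y = np.sin(X[:, 0]) + 0.1 * rng.standard_normal(200)

sq = np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
K = np.exp(-0.5 * sq)

# One eigendecomposition of K serves the entire lambda grid:
# alpha(lam) = Q diag(1 / (eigvals + n*lam)) Q^T y.
n = len(y)
eigvals, Q = np.linalg.eigh(K)
Qty = Q.T @ y

lambdas = np.logspace(2, -4, num=12)
alphas = [Q @ (Qty / (eigvals + n * lam)) for lam in lambdas]

# Training error shrinks as lambda decreases along the path.
errors = [np.mean((K @ a - y) ** 2) for a in alphas]
print(errors[0] > errors[-1])  # True
```

TorchKM's classifiers solve different objectives, but the same principle, reusing work across the regularization path, is what makes train + tune workflows fast.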

Installation

Minimal install

pip install torchkm

Install the scikit-learn wrapper

pip install "torchkm[sklearn]"

Development install

git clone https://github.com/YikaiZhang95/torchkm.git
cd torchkm
pip install -e ".[dev,sklearn]"
pytest -q

Quickstart

sklearn-style wrapper (recommended)

import numpy as np
import torch
from sklearn.datasets import make_circles
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

from torchkm.estimators import TorchKMSVC

# Toy nonlinear classification task
X, y = make_circles(n_samples=1200, factor=0.4, noise=0.08, random_state=0)
X = StandardScaler().fit_transform(X)
y = np.where(y == 0, -1, 1)   # TorchKM accepts {-1, +1} labels

Xtr, Xte, ytr, yte = train_test_split(X, y, test_size=0.2, random_state=0)

nfolds = 5
Cs = np.logspace(3, -3, num=12)

clf = TorchKMSVC(
    kernel="rbf",
    nC=len(Cs),
    Cs=Cs,
    nfolds=nfolds,
    device="cuda" if torch.cuda.is_available() else "cpu",
    probability=True,
    max_iter=200,
)

clf.fit(Xtr, ytr)

print("best C:", clf.best_C_)
print("test accuracy:", (clf.predict(Xte) == yte).mean())
print("first 3 probabilities:\n", clf.predict_proba(Xte[:3]))

Low-rank approximation

Set low_rank=True to handle large datasets with a low-rank kernel approximation.

clf = TorchKMSVC(
    kernel="rbf",
    low_rank=True,
    num_landmarks=500,
    nys_k=250,
    nC=20,
    nfolds=5,
    device="cuda" if torch.cuda.is_available() else "cpu",
    probability=True,
    platt_device="cuda" if torch.cuda.is_available() else "cpu",
)

clf.fit(Xtr, ytr)
print("test accuracy:", (clf.predict(Xte) == yte).mean())
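The num_landmarks and nys_k parameters suggest a Nyström-style approximation. The general technique, sketched below in plain NumPy (this is not TorchKM's internal code; rbf is a local helper), replaces the full n x n kernel matrix with a rank-k feature map built from m landmark points:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((2000, 10))

def rbf(A, B, gamma=0.1):
    sq = ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)
    return np.exp(-gamma * sq)

# Pick m landmarks and keep the top-k eigenpairs of the m x m block.
m, k = 200, 100
landmarks = X[rng.choice(len(X), size=m, replace=False)]
K_mm = rbf(landmarks, landmarks)
K_nm = rbf(X, landmarks)

eigvals, U = np.linalg.eigh(K_mm)   # eigenvalues in ascending order
top = np.argsort(eigvals)[-k:]

# Feature map Z with Z @ Z.T approximating K, at rank k instead of n x n.
Z = K_nm @ U[:, top] / np.sqrt(np.maximum(eigvals[top], 1e-12))
print(Z.shape)  # (2000, 100)
```

Training then works with the n x k matrix Z rather than the full kernel, which is what makes large datasets tractable.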

What TorchKM provides

High-level wrappers

  • TorchKMSVC — kernel SVM classifier
  • TorchKMDWD — kernel DWD classifier
  • TorchKMLogit — kernel logistic regression

Common methods:

  • fit(X, y)
  • decision_function(X)
  • predict(X)
  • predict_proba(X) (when probability=True)
  • platt_plot(X, y)
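For margin classifiers such as SVM and DWD, probability=True is typically implemented via Platt scaling: a sigmoid p(y=1 | f) = 1 / (1 + exp(A*f + B)) is fitted to decision values. A minimal NumPy sketch of the general technique (not necessarily TorchKM's exact procedure):

```python
import numpy as np

rng = np.random.default_rng(0)

# Pretend decision values f(x) from a margin classifier, labels in {-1, +1}.
f = np.concatenate([rng.normal(-1.5, 1.0, 300), rng.normal(1.5, 1.0, 300)])
y = np.concatenate([-np.ones(300), np.ones(300)])

# Platt scaling: fit p(y=1|f) = 1 / (1 + exp(A*f + B)) by gradient
# descent on the logistic loss, with targets mapped to {0, 1}.
A, B = 0.0, 0.0
t = (y + 1) / 2
for _ in range(2000):
    p = 1.0 / (1.0 + np.exp(A * f + B))
    A -= 0.1 * np.mean((p - t) * (-f))
    B -= 0.1 * np.mean((p - t) * (-1.0))

print(A < 0)  # True: larger f means higher p(y=1), so the slope is negative
```

In practice the sigmoid is fitted on held-out decision values to avoid biased probabilities; predict_proba then applies the fitted sigmoid to new decision values.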

Low-level solvers

  • cvksvm — kernel SVM with cross-validation over λ
  • cvkdwd — kernel DWD with cross-validation over λ
  • cvklogit — kernel logistic regression with cross-validation over λ

Utilities

  • rbf_kernel, kernelMult, sigest
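As a rough guide to what these utilities compute (the sketch below is a generic NumPy reimplementation with illustrative names, not TorchKM's code): an RBF kernel evaluates K[i, j] = exp(-gamma * ||x_i - x_j||^2), and sigest-style bandwidth estimation, as in kernlab's sigest, commonly uses quantiles of pairwise squared distances:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((300, 5))

def rbf_kernel_np(A, B, gamma):
    """Gaussian kernel K[i, j] = exp(-gamma * ||a_i - b_j||^2)."""
    sq = ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)
    return np.exp(-gamma * sq)

def sigest_np(X, n_pairs=500):
    """Median-heuristic bandwidth: 1 / median of pairwise squared distances."""
    idx = np.random.default_rng(1)
    i = idx.integers(0, len(X), n_pairs)
    j = idx.integers(0, len(X), n_pairs)
    d2 = ((X[i] - X[j]) ** 2).sum(-1)
    return 1.0 / np.median(d2[d2 > 0])

gamma = sigest_np(X)
K = rbf_kernel_np(X, X, gamma)
print(K.shape, np.allclose(np.diag(K), 1.0))  # (300, 300) True
```

Sampling random pairs rather than forming all n^2 distances keeps the bandwidth estimate cheap even for large datasets.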

Device behavior

For portability, prefer:

device = "cuda" if torch.cuda.is_available() else "cpu"

TorchKM is designed to run on GPU when CUDA is available and fall back safely to CPU otherwise.

When to use TorchKM

TorchKM is a good fit when you want:

  • nonlinear kernel classifiers without leaving the PyTorch ecosystem
  • fast model selection across many regularization values
  • exact LOOCV support for kernel SVM
  • a wrapper API that feels familiar if you already use scikit-learn
  • lower-level access to kernel matrices and solver internals
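For intuition on why exact LOOCV is valuable: in kernel ridge regression, all n leave-one-out residuals follow in closed form from a single fit via the hat matrix, so no refitting is needed. The identity below is the classical kernel ridge result, shown as a conceptual analogue rather than TorchKM's SVM-specific algorithm:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.uniform(-3, 3, (100, 1))
y = np.sin(X[:, 0]) + 0.1 * rng.standard_normal(100)

K = np.exp(-0.5 * (X[:, None, 0] - X[None, :, 0]) ** 2)
lam = 1e-2
n = len(y)

# Kernel ridge fit: alpha = (K + lam*I)^{-1} y, hat matrix H = K (K + lam*I)^{-1}.
G_inv = np.linalg.inv(K + lam * np.eye(n))
alpha = G_inv @ y
H = K @ G_inv

# Exact LOO residuals in closed form: (y_i - yhat_i) / (1 - H_ii).
loo_resid = (y - K @ alpha) / (1.0 - np.diag(H))

# Cross-check point 0 against an explicit refit without it.
mask = np.arange(n) != 0
a0 = np.linalg.solve(K[np.ix_(mask, mask)] + lam * np.eye(n - 1), y[mask])
brute = y[0] - K[0, mask] @ a0
print(np.isclose(loo_resid[0], brute))  # True
```

Shortcuts of this flavor turn n model refits into a single fit plus cheap per-point corrections, which is what makes exact LOOCV usable for model selection.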

Testing

pytest -q

Contributing

Issues, bug reports, benchmarks, documentation improvements, and pull requests are welcome.

Good first contributions include:

  • additional kernels
  • multiclass wrappers
  • more benchmark scripts
  • expanded examples and tutorials

Citation

If you use TorchKM in academic work, please cite:

@article{zhang2026torchkm,
  title   = {TorchKM: GPU-Accelerated Kernel Machines with Fast Model Selection in PyTorch},
  author  = {Zhang, Yikai and Jia, Gaoxiang and Wang, Boxiang},
  journal = {Journal of Machine Learning Research (MLOSS track)},
  year    = {2026},
  note    = {Software paper submission}
}

License

MIT License. See LICENSE.

Contact

Yikai Zhang
yikai-zhang@uiowa.edu
