
A PyTorch-based library for kernel machines (SVM, DWD, kernel logistic, etc.)


TorchKM: Fast Kernel Machines in PyTorch


TorchKM is a PyTorch-based library for kernel machines with a focus on fast training and tuning workflows.

It currently provides:

  • Kernel classification: kernel SVM, kernel DWD, and kernel logistic regression
  • Fast model selection: pathwise solutions over a grid of regularization values (λ)
  • Exact leave-one-out cross-validation (LOOCV) for kernel SVM
  • GPU acceleration via PyTorch/CUDA, with safe CPU fallback
  • A scikit-learn–style API for easy integration into existing Python workflows
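For orientation, the three classifiers can be written as one regularized problem that differs only in the margin loss V. This is the standard textbook formulation, stated here as an assumption about what TorchKM solves; its scaling of λ, intercept handling, and DWD variant may differ.

```latex
\min_{b \in \mathbb{R},\; \alpha \in \mathbb{R}^{n}} \;
  \frac{1}{n} \sum_{i=1}^{n} V\!\bigl(y_i \,(b + K_i^{\top} \alpha)\bigr)
  + \frac{\lambda}{2}\, \alpha^{\top} K \alpha,
  \qquad y_i \in \{-1, +1\}

V_{\mathrm{SVM}}(u) = \max(0,\, 1 - u), \qquad
V_{\mathrm{logit}}(u) = \log\bigl(1 + e^{-u}\bigr), \qquad
V_{\mathrm{DWD}}(u) =
  \begin{cases}
    1 - u, & u \le \tfrac{1}{2},\\[2pt]
    \dfrac{1}{4u}, & u > \tfrac{1}{2}
  \end{cases}

\mathrm{LOOCV}(\lambda) = \frac{1}{n} \sum_{i=1}^{n}
  \mathbf{1}\bigl\{\, y_i \, \hat f_{\lambda}^{[-i]}(x_i) \le 0 \,\bigr\}
```

Here K is the n × n kernel matrix with columns K_i, λ is the regularization value varied along the path, and \hat f_λ^{[-i]} is the fit at λ with observation i held out. The DWD loss shown is the common q = 1 member of the generalized DWD family, and "exact" LOOCV presumably means this quantity is obtained without approximating the n leave-one-out fits.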

Why TorchKM?

Kernel methods are still a strong choice when you want nonlinear decision boundaries, convex training objectives, and competitive performance on tabular or moderate-scale datasets. In practice, the bottleneck is often not training one model but training and tuning many: one fit per candidate λ and per cross-validation fold.

TorchKM is built for that workflow.

Installation

Minimal install

pip install torchkm

Install the scikit-learn wrapper

pip install "torchkm[sklearn]"

Development install

git clone https://github.com/YikaiZhang95/torchkm.git
cd torchkm
pip install -e ".[dev,sklearn]"
pytest -q

Quickstart

sklearn-style wrapper (recommended)

import numpy as np
import torch
from sklearn.datasets import make_circles
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

from torchkm.estimators import TorchKMSVC

# Toy nonlinear classification task
X, y = make_circles(n_samples=1200, factor=0.4, noise=0.08, random_state=0)
X = StandardScaler().fit_transform(X)
y = np.where(y == 0, -1, 1)   # TorchKM accepts {-1, +1} labels

Xtr, Xte, ytr, yte = train_test_split(X, y, test_size=0.2, random_state=0)

nfolds = 5
Cs = np.logspace(3, -3, num=12)
device = "cuda" if torch.cuda.is_available() else "cpu"   # use the GPU when available

clf = TorchKMSVC(
    kernel="rbf",
    nC=len(Cs),
    Cs=Cs,
    cv=nfolds,
    device=device,
    probability=True,
    max_iter=200,
)

clf.fit(Xtr, ytr)

print("best lambda:", clf.best_C_)
print("test accuracy:", (clf.predict(Xte) == yte).mean())
print("first 3 probabilities:\n", clf.predict_proba(Xte[:3]))
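The quickstart maps 0/1 labels to {-1, +1} with np.where. When the raw labels are strings or other values, a small pair of helpers keeps the encoding and its inverse together. `to_pm1` and `from_pm1` are hypothetical NumPy helpers, not part of TorchKM, and the inverse assumes `predict` returns labels in {-1, +1}, as the accuracy check above implies.

```python
import numpy as np

def to_pm1(y):
    """Map binary labels of any type to {-1, +1}; return (encoded, classes)."""
    y = np.asarray(y)
    classes = np.unique(y)  # sorted, so classes[0] -> -1 and classes[1] -> +1
    if classes.size != 2:
        raise ValueError(f"expected exactly 2 classes, got {classes.size}")
    encoded = np.where(y == classes[1], 1, -1)
    return encoded, classes

def from_pm1(y_pm1, classes):
    """Inverse of to_pm1: map {-1, +1} predictions back to the original labels."""
    y_pm1 = np.asarray(y_pm1)
    return np.where(y_pm1 > 0, classes[1], classes[0])

y_raw = np.array(["no", "yes", "yes", "no"])
y_enc, classes = to_pm1(y_raw)
print(y_enc)                     # [-1  1  1 -1]
print(from_pm1(y_enc, classes))  # ['no' 'yes' 'yes' 'no']
```

Keeping `classes` alongside the model lets predictions be reported in the original label space.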

Low-rank approximation

Set low_rank=True to fit on a low-rank approximation of the kernel matrix when the dataset is too large for the full n × n kernel.

device = "cuda" if torch.cuda.is_available() else "cpu"

clf = TorchKMSVC(
    kernel="rbf",
    low_rank=True,
    num_landmarks=500,
    nys_k=250,
    nC=20,
    cv=5,
    device=device,
    probability=True,
    platt_device=device,
)

clf.fit(Xtr, ytr)
print("test accuracy:", (clf.predict(Xte) == yte).mean())
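The num_landmarks and nys_k parameters suggest a Nyström-style approximation. The sketch below shows the generic Nyström feature map in plain PyTorch, assuming landmarks are sampled uniformly and the top-k eigenpairs of the landmark kernel are kept. It illustrates the technique only; `rbf` and `nystrom_features` are hypothetical, and TorchKM's landmark selection and parameter meanings may differ.

```python
import torch

def rbf(A, B, sigma):
    # Gaussian kernel, kernlab-style parameterization: exp(-sigma * ||a - b||^2)
    return torch.exp(-sigma * torch.cdist(A, B).pow(2))

def nystrom_features(X, num_landmarks, rank, sigma, generator=None):
    """Return Z (n x rank) such that Z @ Z.T approximates the full n x n kernel matrix."""
    n = X.shape[0]
    idx = torch.randperm(n, generator=generator)[:num_landmarks]
    L = X[idx]                                      # uniformly sampled landmark points
    K_mm = rbf(L, L, sigma)                         # m x m kernel among landmarks
    K_nm = rbf(X, L, sigma)                         # n x m kernel between data and landmarks
    evals, evecs = torch.linalg.eigh(K_mm)          # eigenvalues in ascending order
    evals, evecs = evals[-rank:], evecs[:, -rank:]  # keep the top `rank` eigenpairs
    evals = evals.clamp_min(1e-12)                  # guard against round-off near zero
    return (K_nm @ evecs) / evals.sqrt()            # Z = K_nm U_k diag(evals_k)^(-1/2)

g = torch.Generator().manual_seed(0)
X = torch.randn(2000, 5, dtype=torch.float64, generator=g)
Z = nystrom_features(X, num_landmarks=500, rank=250, sigma=0.5, generator=g)

# Compare against the exact kernel on a 200-point subset.
K_sub = rbf(X[:200], X[:200], 0.5)
rel_err = ((Z[:200] @ Z[:200].T - K_sub).norm() / K_sub.norm()).item()
print(Z.shape, rel_err)
```

Storage drops from n² kernel entries to n × rank features, which is what makes large datasets tractable; mapping num_landmarks and nys_k onto the landmark count and rank here is a guess.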

What TorchKM provides

High-level wrappers

  • TorchKMSVC — kernel SVM classifier
  • TorchKMDWD — kernel DWD classifier
  • TorchKMLogit — kernel logistic regression

Common methods:

  • fit(X, y)
  • decision_function(X)
  • predict(X)
  • predict_proba(X) (when probability=True)
  • platt_plot(X, y)
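With probability=True, predict_proba presumably turns decision values into probabilities; the platt_device and platt_plot names point to Platt scaling. A minimal, generic Platt-scaling sketch in PyTorch follows. It fits the two-parameter sigmoid by plain maximum likelihood and omits Platt's original target smoothing, so it illustrates the idea rather than TorchKM's calibration code; `fit_platt` and `platt_proba` are hypothetical names.

```python
import torch
import torch.nn.functional as F

def fit_platt(f, y, max_iter=100):
    """Fit P(y = +1 | f) = sigmoid(-(A * f + B)) to decision values f and labels y in {-1, +1}."""
    f = torch.as_tensor(f, dtype=torch.float64)
    t = (torch.as_tensor(y, dtype=torch.float64) + 1) / 2            # map {-1, +1} to {0, 1}
    params = torch.zeros(2, dtype=torch.float64, requires_grad=True)  # [A, B]
    opt = torch.optim.LBFGS([params], max_iter=max_iter, line_search_fn="strong_wolfe")

    def closure():
        opt.zero_grad()
        logits = -(params[0] * f + params[1])                   # log-odds of y = +1
        loss = F.binary_cross_entropy_with_logits(logits, t)    # mean negative log-likelihood
        loss.backward()
        return loss

    opt.step(closure)
    A, B = params.detach().tolist()
    return A, B

def platt_proba(f, A, B):
    f = torch.as_tensor(f, dtype=torch.float64)
    return torch.sigmoid(-(A * f + B))

# Synthetic decision values: positives centered at +2, negatives at -2.
g = torch.Generator().manual_seed(0)
y = (torch.rand(500, generator=g) < 0.5).to(torch.float64) * 2 - 1
f = 2.0 * y + torch.randn(500, generator=g, dtype=torch.float64)
A, B = fit_platt(f, y)
print(A, B, platt_proba(torch.tensor([-3.0, 0.0, 3.0]), A, B))
```

In practice A and B should be fit on held-out decision values (for example, from cross-validation folds), because training-set scores are optimistically separated.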

Low-level solvers

  • cvksvm — kernel SVM with cross-validation over λ
  • cvkdwd — kernel DWD with cross-validation over λ
  • cvklogit — kernel logistic regression with cross-validation over λ
  • cvkqr — kernel quantile regression with cross-validation over λ
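The cv* solvers fit every λ on the grid for each fold. The sketch below illustrates why a whole grid can be cheap, but it swaps in kernel ridge classification (squared loss, not the SVM, DWD, logistic, or quantile losses), for which one eigendecomposition per fold yields the solution at every λ. `cv_lambda_path` and its λ scaling are hypothetical and unrelated to the actual cvksvm signature.

```python
import torch

def rbf(A, B, sigma):
    # Gaussian kernel exp(-sigma * ||a - b||^2)
    return torch.exp(-sigma * torch.cdist(A, B).pow(2))

def cv_lambda_path(X, y, lambdas, nfolds=5, sigma=0.5, seed=0):
    """K-fold misclassification rate for every lambda in `lambdas`.

    Stand-in model (NOT SVM/DWD/logistic): kernel ridge classifier
    alpha = (K + n_tr * lam * I)^{-1} y, solved for the whole grid from one
    eigendecomposition of the training kernel per fold.
    """
    n = X.shape[0]
    g = torch.Generator().manual_seed(seed)
    fold_id = torch.randperm(n, generator=g) % nfolds      # balanced random folds
    errors = torch.zeros(len(lambdas), dtype=torch.float64)
    for k in range(nfolds):
        tr, te = fold_id != k, fold_id == k
        evals, U = torch.linalg.eigh(rbf(X[tr], X[tr], sigma))  # K = U diag(evals) U^T
        Uty = U.T @ y[tr]
        K_te = rbf(X[te], X[tr], sigma)
        n_tr = int(tr.sum())
        for j, lam in enumerate(lambdas):
            alpha = U @ (Uty / (evals + n_tr * lam))  # no refactorization per lambda
            errors[j] += (torch.sign(K_te @ alpha) != y[te]).sum().item()
    return errors / n

g = torch.Generator().manual_seed(1)
X = torch.randn(400, 2, dtype=torch.float64, generator=g)
y = (X.norm(dim=1) < 1.0).to(torch.float64) * 2 - 1       # +1 inside the unit circle
lambdas = torch.logspace(1, -4, 11).tolist()
cv_err = cv_lambda_path(X, y, lambdas)
best = lambdas[int(torch.argmin(cv_err))]
print("best lambda:", best, "cv error:", cv_err.min().item())
```

For hinge, DWD, or logistic losses there is no closed form, so pathwise solvers typically warm-start each λ from the previous solution instead; the reuse principle is the same.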

Utilities

  • rbf_kernel, kernelMult, sigest
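The signatures of these utilities are not documented here. Assuming `sigest` mirrors the heuristic of the same name in the R package kernlab (quantiles of squared distances between random pairs of points, inverted to give a range for the RBF parameter), the idea can be sketched as follows. `rbf_kernel_sketch` and `sigest_sketch` are hypothetical stand-ins, not TorchKM's functions, and kernlab's optional feature scaling is omitted.

```python
import torch

def rbf_kernel_sketch(X, Z, sigma):
    """K[i, j] = exp(-sigma * ||X[i] - Z[j]||^2), with sigma an inverse squared width."""
    return torch.exp(-sigma * torch.cdist(X, Z).pow(2))

def sigest_sketch(X, frac=0.5, generator=None):
    """kernlab-style heuristic: sigma range from quantiles of squared pairwise distances.

    Returns tensor([low, mid, high]) = 1 / quantile(d2, [0.9, 0.5, 0.1]).
    """
    n = X.shape[0]
    m = max(1, int(frac * n))
    i = torch.randint(n, (m,), generator=generator)  # random pairs (i, j), with replacement
    j = torch.randint(n, (m,), generator=generator)
    d2 = (X[i] - X[j]).pow(2).sum(dim=1)
    d2 = d2[d2 > 0]                                  # drop pairs of identical points
    probs = torch.tensor([0.9, 0.5, 0.1], dtype=X.dtype)
    return 1.0 / torch.quantile(d2, probs)

g = torch.Generator().manual_seed(0)
X = torch.randn(300, 4, dtype=torch.float64, generator=g)
low, mid, high = sigest_sketch(X, generator=g).tolist()
K = rbf_kernel_sketch(X, X, mid)   # any value in [low, high] is a reasonable starting point
print(low, mid, high, K.shape)
```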

Device behavior

For portability, prefer:

device = "cuda" if torch.cuda.is_available() else "cpu"

TorchKM is designed to run on GPU when CUDA is available and fall back safely to CPU otherwise.
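The same device-selection pattern carries through a whole computation in plain PyTorch: choose the device once, move NumPy data onto it, and copy results back to host memory with .cpu() before converting to NumPy. This is a generic sketch; whether TorchKM's functions accept tensors directly or expect NumPy arrays is not stated here.

```python
import numpy as np
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

X_np = np.random.default_rng(0).standard_normal((1000, 8))
X = torch.as_tensor(X_np, dtype=torch.float64, device=device)  # copy onto the chosen device

# Heavy computation now runs on `device`, e.g. a Gram matrix of inner products.
G = X @ X.T

# Tensors on a GPU must be moved to host memory before NumPy conversion.
G_np = G.cpu().numpy()
print(device, G_np.shape)
```

Calling .numpy() directly on a CUDA tensor raises an error, so the explicit .cpu() keeps the code correct on both paths.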

When to use TorchKM

TorchKM is a good fit when you want:

  • nonlinear kernel classifiers without leaving the PyTorch ecosystem
  • fast model selection across many regularization values
  • exact LOOCV support for kernel SVM
  • a wrapper API that feels familiar if you already use scikit-learn
  • lower-level access to kernel matrices and solver internals

Testing

From the repository root, after the development install:

pytest -q

Contributing

Issues, bug reports, benchmarks, documentation improvements, and pull requests are welcome.

Good first contributions include:

  • additional kernels
  • multiclass wrappers
  • more benchmark scripts
  • expanded examples and tutorials

Citation

If you use TorchKM in academic work, please cite:

@article{zhang2026torchkm,
  title   = {TorchKM: GPU-Accelerated Kernel Machines with Fast Model Selection in PyTorch},
  author  = {Zhang, Yikai and Jia, Gaoxiang and Ding, Jie and Wang, Boxiang},
  journal = {Journal of Machine Learning Research (MLOSS track)},
  year    = {2026},
  note    = {Software paper submission}
}

License

MIT License. See LICENSE.

Contact

Yikai Zhang
yikai-zhang@uiowa.edu

Download files


Source Distribution

torchkm-4.2.3.tar.gz (54.0 kB)

Built Distribution


torchkm-4.2.3-py3-none-any.whl (74.6 kB)

File details

Details for the file torchkm-4.2.3.tar.gz.

File metadata

  • Download URL: torchkm-4.2.3.tar.gz
  • Upload date:
  • Size: 54.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for torchkm-4.2.3.tar.gz
Algorithm Hash digest
SHA256 c3dd86d62a41012c8f9f774bc0cf7af2942c7b45e879817aab6e8d93fd280dcc
MD5 29bb2e30d102d5d56364cf410faec5a1
BLAKE2b-256 1f3848875e3127bc36e628c62aad00adef9c0e3299727f0cf3b085e227fc2baf


File details

Details for the file torchkm-4.2.3-py3-none-any.whl.

File metadata

  • Download URL: torchkm-4.2.3-py3-none-any.whl
  • Upload date:
  • Size: 74.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for torchkm-4.2.3-py3-none-any.whl
Algorithm Hash digest
SHA256 a71bd940ca9f1204ea4ef5561afe5593bb444a02b5a8ac42196ce91a8dc6fea0
MD5 e99f710c6a43a84e6332938212858d4a
BLAKE2b-256 b696d08f0e1d11b7cdf58d2066533faa62840fbea5ce244de4b917c85c1863be

