A PyTorch-based library for kernel machines (SVM, DWD, kernel logistic, etc.)

TorchKM: Fast Kernel Machines in PyTorch

TorchKM is a PyTorch-based library for kernel machines with a focus on fast training-and-tuning workflows.

It currently provides:

  • Kernel classification: kernel SVM, kernel DWD, and kernel logistic regression
  • Fast model selection: pathwise solutions over a grid of regularization values (λ)
  • Exact LOOCV support for kernel SVM
  • GPU acceleration via PyTorch/CUDA, with safe CPU fallback
  • A scikit-learn–style API for easy integration into existing Python workflows
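Pathwise model selection pays off because work shared across λ values only has to be done once. The source does not say how TorchKM computes its paths, so here is a minimal NumPy sketch of the general idea using kernel ridge regression as a stand-in (it is not one of TorchKM's models): a single eigendecomposition of the kernel matrix K makes every additional λ on the grid cost only matrix-vector products. The helper names rbf and kernel_ridge_path are hypothetical.

```python
import numpy as np


def rbf(X, Z, sigma=1.0):
    # Gaussian kernel exp(-sigma * ||x - z||^2), computed for all pairs of rows
    sq = (X**2).sum(1)[:, None] + (Z**2).sum(1)[None, :] - 2 * X @ Z.T
    return np.exp(-sigma * np.maximum(sq, 0.0))


def kernel_ridge_path(K, y, lambdas):
    """Solve (K + lam * I) alpha = y for every lam, reusing one eigendecomposition."""
    d, U = np.linalg.eigh(K)  # one O(n^3) factorization, shared by the whole grid
    Uty = U.T @ y             # targets in the eigenbasis, computed once
    # Each lambda now costs O(n^2): alpha = U diag(1 / (d + lam)) U^T y
    return np.stack([U @ (Uty / (d + lam)) for lam in lambdas])


rng = np.random.default_rng(0)
X = rng.normal(size=(50, 2))
y = np.sign(X[:, 0] * X[:, 1])             # XOR-like labels in {-1, +1}
lambdas = np.logspace(1, -3, num=12)       # from strong to weak regularization
K = rbf(X, X)
alphas = kernel_ridge_path(K, y, lambdas)  # one row of coefficients per lambda
```

This closed-form shortcut does not carry over unchanged to hinge or logistic losses, which have no closed-form solution; pathwise solvers for those losses typically warm-start each λ from the solution at the previous one instead.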

Why TorchKM?

Kernel methods are still a strong choice when you want nonlinear decision boundaries, convex training objectives, and competitive performance on tabular or moderate-scale datasets. In practice, the bottleneck is often not training one model but training and tuning many of them (one fit per regularization value per cross-validation fold).

TorchKM is built for that workflow.

Installation

Minimal install

pip install torchkm

Install the scikit-learn wrapper

pip install "torchkm[sklearn]"

Development install

git clone https://github.com/YikaiZhang95/torchkm.git
cd torchkm
pip install -e ".[dev,sklearn]"
pytest -q

Quickstart

sklearn-style wrapper (recommended)

import numpy as np
import torch
from sklearn.datasets import make_circles
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

from torchkm.estimators import TorchKMSVC

# Toy nonlinear classification task
X, y = make_circles(n_samples=1200, factor=0.4, noise=0.08, random_state=0)
X = StandardScaler().fit_transform(X)
y = np.where(y == 0, -1, 1)   # TorchKM accepts {-1, +1} labels

Xtr, Xte, ytr, yte = train_test_split(X, y, test_size=0.2, random_state=0)

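nfolds = 5
Cs = np.logspace(3, -3, num=12)   # 12 candidate regularization values
device = "cuda" if torch.cuda.is_available() else "cpu"   # fall back to CPU without CUDA

clf = TorchKMSVC(
    kernel="rbf",
    nC=len(Cs),
    Cs=Cs,
    cv=nfolds,
    device=device,
    probability=True,
    max_iter=200,
)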

clf.fit(Xtr, ytr)

print("best C:", clf.best_C_)
print("test accuracy:", (clf.predict(Xte) == yte).mean())
print("first 3 probabilities:\n", clf.predict_proba(Xte[:3]))
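The quickstart converts the make_circles labels with np.where(y == 0, -1, 1) because TorchKM expects {-1, +1} labels. For other binary encodings, such as strings or {1, 2}, a pair of small helpers can do the mapping and undo it on predictions. to_pm1 and from_pm1 are hypothetical names, not part of TorchKM, and whether the wrappers already remap labels internally is not stated here.

```python
import numpy as np


def to_pm1(y):
    """Map a binary label vector to {-1, +1}; also return the class order used."""
    classes = np.unique(y)  # sorted, so classes[0] -> -1 and classes[1] -> +1
    if classes.size != 2:
        raise ValueError(f"expected 2 classes, got {classes.size}")
    return np.where(y == classes[1], 1, -1), classes


def from_pm1(y_pm, classes):
    """Invert to_pm1: -1 -> classes[0], +1 -> classes[1]."""
    return np.where(y_pm > 0, classes[1], classes[0])


y_raw = np.array(["cat", "dog", "dog", "cat"])
y_pm, classes = to_pm1(y_raw)  # y_pm is [-1, 1, 1, -1]
```

Keeping classes alongside the model lets predictions from clf.predict be translated back with from_pm1(clf.predict(X), classes).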

Low-rank approximation

Use low_rank=True when you want to handle a large dataset; num_landmarks and nys_k control the size of the low-rank (Nyström) kernel approximation.

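# Reuses the imports, TorchKMSVC, and Xtr, Xte, ytr, yte from the quickstart
device = "cuda" if torch.cuda.is_available() else "cpu"

clf = TorchKMSVC(
    kernel="rbf",
    low_rank=True,
    num_landmarks=500,
    nys_k=250,
    nC=20,
    cv=5,
    device=device,
    probability=True,
    platt_device=device,
)

clf.fit(Xtr, ytr)
print("test accuracy:", (clf.predict(Xte) == yte).mean())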
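The parameter names num_landmarks and nys_k point to a Nyström approximation, in which the full kernel is approximated from a random subset of landmark points. Assuming num_landmarks is the number m of sampled landmarks and nys_k the retained rank k (the source defines neither), a minimal NumPy sketch of the resulting feature map looks like this. The helpers rbf and nystrom_features are hypothetical, and TorchKM's landmark selection and truncation may differ.

```python
import numpy as np


def rbf(X, Z, sigma=0.5):
    # Gaussian kernel exp(-sigma * ||x - z||^2) for all pairs of rows
    sq = (X**2).sum(1)[:, None] + (Z**2).sum(1)[None, :] - 2 * X @ Z.T
    return np.exp(-sigma * np.maximum(sq, 0.0))


def nystrom_features(X, num_landmarks, k, sigma=0.5, seed=0):
    """Rank-k Nystrom feature map Phi such that K is approximately Phi @ Phi.T."""
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(X), size=num_landmarks, replace=False)
    L = X[idx]                                # sampled landmark points
    K_nm = rbf(X, L, sigma)                   # n x m cross-kernel
    K_mm = rbf(L, L, sigma)                   # m x m landmark kernel
    d, U = np.linalg.eigh(K_mm)               # eigenvalues in ascending order
    d, U = d[::-1][:k], U[:, ::-1][:, :k]     # keep the top-k eigenpairs
    keep = d > 1e-12 * d[0]                   # drop numerically zero directions
    return K_nm @ U[:, keep] / np.sqrt(d[keep])


rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
Phi = nystrom_features(X, num_landmarks=50, k=20)  # 200 x (at most 20) features
K_approx = Phi @ Phi.T                             # low-rank stand-in for the 200 x 200 kernel
```

The saving is memory and compute: Phi has n x k entries instead of the n x n of the full kernel, so downstream linear algebra scales with n·k rather than n².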

What TorchKM provides

High-level wrappers

  • TorchKMSVC — kernel SVM classifier
  • TorchKMDWD — kernel DWD classifier
  • TorchKMLogit — kernel logistic regression

Common methods:

  • fit(X, y)
  • decision_function(X)
  • predict(X)
  • predict_proba(X) (when probability=True)
  • platt_plot(X, y)

Low-level solvers

  • cvksvm — kernel SVM with cross-validation over λ
  • cvkdwd — kernel DWD with cross-validation over λ
  • cvklogit — kernel logistic regression with cross-validation over λ
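All three solvers select λ by cross-validation. The loop below shows that general pattern under a simplifying assumption: it swaps in a kernel ridge (least-squares) classifier, whose fit is a single linear solve, in place of TorchKM's SVM, DWD or logistic solvers. The kernel is computed once and sliced per fold; cv_error_over_lambdas and rbf are hypothetical names.

```python
import numpy as np


def rbf(X, Z, sigma=1.0):
    # Gaussian kernel exp(-sigma * ||x - z||^2) for all pairs of rows
    sq = (X**2).sum(1)[:, None] + (Z**2).sum(1)[None, :] - 2 * X @ Z.T
    return np.exp(-sigma * np.maximum(sq, 0.0))


def cv_error_over_lambdas(X, y, lambdas, nfolds=5, seed=0):
    """k-fold misclassification rate of a kernel ridge classifier for each lambda."""
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(len(y)), nfolds)
    K = rbf(X, X)                         # full kernel computed once, sliced per fold
    errors = np.zeros(len(lambdas))
    for test_idx in folds:
        train_idx = np.setdiff1d(np.arange(len(y)), test_idx)
        K_tr = K[np.ix_(train_idx, train_idx)]
        K_te = K[np.ix_(test_idx, train_idx)]
        for j, lam in enumerate(lambdas):
            alpha = np.linalg.solve(K_tr + lam * np.eye(len(train_idx)), y[train_idx])
            errors[j] += np.sum(np.sign(K_te @ alpha) != y[test_idx])
    return errors / len(y)


rng = np.random.default_rng(1)
X = rng.normal(size=(120, 2))
y = np.where((X**2).sum(1) < 1.0, 1, -1)  # circular, nonlinear class boundary
lambdas = np.logspace(1, -3, num=9)       # from strong to weak regularization
errs = cv_error_over_lambdas(X, y, lambdas)
best_lambda = lambdas[np.argmin(errs)]
```

Because lambdas runs from strong to weak regularization and np.argmin returns the first minimum, ties are broken in favor of the most regularized model.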

Utilities

  • rbf_kernel, kernelMult, sigest
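The utilities are listed without signatures. The name sigest matches the bandwidth heuristic in R's kernlab package, which returns 1 over the 0.9, 0.5 and 0.1 quantiles of squared distances between randomly paired rows. Assuming TorchKM's sigest follows the same idea (not confirmed by the source), a NumPy sketch looks like this; rbf_kernel_np and sigest_like are hypothetical names chosen so they are not mistaken for TorchKM's own functions, whose signatures may differ.

```python
import numpy as np


def rbf_kernel_np(X, Z, sigma):
    """Gaussian kernel exp(-sigma * ||x - z||^2), with sigma as an inverse width."""
    sq = (X**2).sum(1)[:, None] + (Z**2).sum(1)[None, :] - 2 * X @ Z.T
    return np.exp(-sigma * np.maximum(sq, 0.0))


def sigest_like(X, frac=0.5, seed=0):
    """Sigma candidates from the 0.9, 0.5 and 0.1 quantiles of sampled squared distances."""
    rng = np.random.default_rng(seed)
    n = len(X)
    m = max(2, int(frac * n))
    i = rng.choice(n, size=m, replace=False)  # two random index samples,
    j = rng.choice(n, size=m, replace=False)  # paired position by position
    d2 = ((X[i] - X[j]) ** 2).sum(1)
    d2 = d2[d2 > 0]                           # ignore pairs that hit the same point
    return 1.0 / np.quantile(d2, [0.9, 0.5, 0.1])  # ascending: small to large sigma


rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
sigmas = sigest_like(X)
K = rbf_kernel_np(X, X, sigmas[1])  # middle value of the estimated range
```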

Device behavior

For portability, prefer:

device = "cuda" if torch.cuda.is_available() else "cpu"

TorchKM is designed to run on GPU when CUDA is available and fall back safely to CPU otherwise.

When to use TorchKM

TorchKM is a good fit when you want:

  • nonlinear kernel classifiers without leaving the PyTorch ecosystem
  • fast model selection across many regularization values
  • exact LOOCV support for kernel SVM
  • a wrapper API that feels familiar if you already use scikit-learn
  • lower-level access to kernel matrices and solver internals
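"Exact" LOOCV means obtaining the same held-out residuals as n separate refits, without actually doing the n refits. The source does not describe how TorchKM achieves this for the SVM, so the sketch below illustrates the concept with a different, simpler model: the classic closed-form leave-one-out identity for kernel ridge regression, checked against brute-force refitting. The helper names are hypothetical.

```python
import numpy as np


def rbf(X, Z, sigma=1.0):
    # Gaussian kernel exp(-sigma * ||x - z||^2) for all pairs of rows
    sq = (X**2).sum(1)[:, None] + (Z**2).sum(1)[None, :] - 2 * X @ Z.T
    return np.exp(-sigma * np.maximum(sq, 0.0))


def loo_residuals_closed_form(K, y, lam):
    """Exact LOO residuals of kernel ridge regression from a single fit."""
    n = len(y)
    H = np.linalg.solve(K + lam * np.eye(n), K)  # hat matrix: fitted values = H @ y
    return (y - H @ y) / (1.0 - np.diag(H))


def loo_residuals_brute_force(K, y, lam):
    """The same residuals by refitting n times, holding out one point each time."""
    n = len(y)
    res = np.empty(n)
    for i in range(n):
        keep = np.arange(n) != i
        alpha = np.linalg.solve(K[np.ix_(keep, keep)] + lam * np.eye(n - 1), y[keep])
        res[i] = y[i] - K[i, keep] @ alpha
    return res


rng = np.random.default_rng(0)
X = rng.normal(size=(40, 2))
y = np.where(X[:, 0] + X[:, 1] ** 2 > 0.5, 1.0, -1.0)
K = rbf(X, X)
fast = loo_residuals_closed_form(K, y, lam=0.1)
```

For kernel ridge, one factorization yields every leave-one-out residual; the hinge loss of the SVM requires a different argument, which is not covered here.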

Testing

pytest -q

Contributing

Issues, bug reports, benchmarks, documentation improvements, and pull requests are welcome.

Good first contributions include:

  • additional kernels
  • multiclass wrappers
  • more benchmark scripts
  • expanded examples and tutorials

Citation

If you use TorchKM in academic work, please cite:

@article{zhang2026torchkm,
  title   = {TorchKM: GPU-Accelerated Kernel Machines with Fast Model Selection in PyTorch},
  author  = {Zhang, Yikai and Jia, Gaoxiang and Ding, Jie and Wang, Boxiang},
  journal = {Journal of Machine Learning Research (MLOSS track)},
  year    = {2026},
  note    = {Software paper submission}
}

License

MIT License. See LICENSE.

Contact

Yikai Zhang
yikai-zhang@uiowa.edu

Download files

Download the file for your platform.

Source Distribution

torchkm-4.2.2.tar.gz (54.0 kB)

Uploaded Source

Built Distribution

torchkm-4.2.2-py3-none-any.whl (74.6 kB)

Uploaded Python 3

File details

Details for the file torchkm-4.2.2.tar.gz.

File metadata

  • Download URL: torchkm-4.2.2.tar.gz
  • Size: 54.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for torchkm-4.2.2.tar.gz
Algorithm Hash digest
SHA256 6efd7be5209999dff9cbc125ccb3c9089d032ed81fd8c8cf5ffc89a6dc3e2f0b
MD5 60f1aeb6bb5e9368bdabb46907146e29
BLAKE2b-256 2a67772a767464b746281557f1826cd2863a1d340eea0518a522d9b20cb48863

File details

Details for the file torchkm-4.2.2-py3-none-any.whl.

File metadata

  • Download URL: torchkm-4.2.2-py3-none-any.whl
  • Size: 74.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for torchkm-4.2.2-py3-none-any.whl
Algorithm Hash digest
SHA256 5483b1ab2b2ccb1b3b1b1fa43af6146a421a6317e70d5534e4e7c2018b65fcbc
MD5 04255152bc970abf34577fd68183dca2
BLAKE2b-256 7136d185a29e8999f1014d3309eb7a3f5a4504cf488db2d379bdfe1facc80d15
