
torchsvm


This is a PyTorch-based package to solve kernel SVM with GPU.


Introduction

torchsvm is a PyTorch-based library that trains kernel SVMs and other large-margin classifiers with exact leave-one-out cross-validation (LOOCV) error computation. Conventional SVM solvers often face scalability and efficiency challenges, especially on large datasets or when multiple cross-validation runs are required. torchsvm computes the LOOCV error at the same cost as training a single SVM, while boosting speed and scalability via CUDA-accelerated matrix operations. Benchmark experiments indicate that torchsvm outperforms existing kernel SVM solvers in efficiency and speed. This document shows how to use the torchsvm package to fit kernel SVMs.

For low-dimensional problems, or in scenarios that require non-linear decision boundaries or higher accuracy, the SVM can be formulated with the kernel method in a reproducing kernel Hilbert space (RKHS).

Given a random sample $\{y_i, \mathbf{x}_i\}_{i=1}^n$ with $y_i \in \{-1, +1\}$, the kernel SVM can be formulated as a function estimation problem:

$$\min_{f \in \mathcal{H}_K} \frac{1}{n} \sum_{i=1}^{n} \left[ 1 - y_i f(\mathbf{x}_i) \right]_{+} + \lambda \lVert f \rVert_{\mathcal{H}_K}^2,$$

where $\lVert \cdot \rVert_{\mathcal{H}_K}$ is the RKHS norm that acts as a regularizer, and $\lambda > 0$ is a tuning parameter.

According to the representer theorem for reproducing kernels (Wahba, 1990), the solution to our problem takes the form:

$$\hat{f}(\mathbf{x}) = b + \sum_{i=1}^{n} \alpha_i K(\mathbf{x}, \mathbf{x}_i).$$

The coefficients $\hat{\boldsymbol{\alpha}}^{\mathrm{SVM}}$ are obtained by solving the optimization problem:

$$(\hat{b}, \hat{\boldsymbol{\alpha}}^{\mathrm{SVM}}) = \operatorname*{argmin}_{b,\, \boldsymbol{\alpha}} \frac{1}{n} \sum_{i=1}^{n} \left[ 1 - y_i \left( b + \mathbf{K}_i^{\top} \boldsymbol{\alpha} \right) \right]_{+} + \lambda\, \boldsymbol{\alpha}^{\top} \mathbf{K} \boldsymbol{\alpha},$$

where $\mathbf{K}$ is the kernel matrix with entries $\mathbf{K}_{ij} = K(\mathbf{x}_i, \mathbf{x}_j)$ and $\mathbf{K}_i$ denotes its $i$-th column.
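
To make the representer form concrete, here is a minimal sketch that evaluates $\hat{f}$ at new points from a fitted intercept and coefficient vector. The helper rbf_cross_kernel and the parameterization $K(\mathbf{x}, \mathbf{x}') = \exp(-\sigma \lVert \mathbf{x} - \mathbf{x}' \rVert^2)$ are illustrative assumptions, not torchsvm's API; the examples below use torchsvm's own rbf_kernel and kernelMult.

import torch

def rbf_cross_kernel(X_new, X_train, sig):
    # K_new[i, j] = exp(-sig * ||x_new_i - x_train_j||^2)  (assumed parameterization)
    return torch.exp(-sig * torch.cdist(X_new, X_train).pow(2))

n, p = 100, 5
X_train = torch.randn(n, p, dtype=torch.double)
X_new = torch.randn(10, p, dtype=torch.double)
alpha = torch.randn(n, dtype=torch.double)  # stands in for the fitted alpha^SVM
b = 0.1                                     # stands in for the fitted intercept
K_new = rbf_cross_kernel(X_new, X_train, sig=0.5)
f_new = torch.mv(K_new, alpha) + b          # decision values f(x) at the new points
y_pred = torch.where(f_new > 0, 1, -1)      # predicted labels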

Installation

You can use pip to install this package.

pip install torchsvm

Quick start

Import necessary libraries and functions:

from torchsvm.cvksvm import cvksvm
from torchsvm.functions import *
import torch
import numpy

The usage is similar to scikit-learn:

model = cvksvm(Kmat=Kmat, y=y_train, nlam=nlam, ulam=ulam, nfolds=nfolds, eps=1e-5, maxit=1000, gamma=1e-8, is_exact=0, device='cuda')
model.fit()
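
Here Kmat is a precomputed kernel matrix, y the vector of $\pm 1$ labels, and ulam a descending sequence of nlam candidate values for the tuning parameter $\lambda$ (see the Usage section below). Judging from the examples that follow, eps and maxit control the solver's convergence tolerance and iteration cap, and is_exact toggles the exact LOOCV computation; consult the package documentation for the full parameter list.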

Usage

Generate simulation data

The functions module provides a simulation data generator, data_gen, which samples from a mixture-of-Gaussians model. It also provides kernel operations such as rbf_kernel and kernelMult, and data-processing utilities such as standardize.

# Sample data
nn = 10000 # Number of samples
nm = 5    # Number of clusters per class
pp = 10   # Number of features
p1 = p2 = pp // 2    # Number of positive/negative centers
mu = 2.0  # Mean shift
ro = 3  # Standard deviation for normal distribution
sdn = 42  # Seed for reproducibility

nlam = 50
torch.manual_seed(sdn)
ulam = torch.logspace(3, -3, steps=nlam)

X_train, y_train, means_train = data_gen(nn, nm, pp, p1, p2, mu, ro, sdn)
X_test, y_test, means_test = data_gen(nn // 10, nm, pp, p1, p2, mu, ro, sdn)
X_train = standardize(X_train)
X_test = standardize(X_test)

sig = sigest(X_train)            # estimate the RBF kernel bandwidth
Kmat = rbf_kernel(X_train, sig)  # n x n kernel (Gram) matrix
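
As a quick sanity check, an RBF Gram matrix should be square, symmetric, and, under the usual parameterization $K(\mathbf{x}, \mathbf{x}') = \exp(-\sigma \lVert \mathbf{x} - \mathbf{x}' \rVert^2)$ (assumed here; torchsvm's rbf_kernel may differ in detail), have a unit diagonal:

# Sanity checks on the Gram matrix (assumes the usual RBF parameterization)
assert Kmat.shape == (nn, nn)
assert torch.allclose(Kmat, Kmat.T)
assert torch.allclose(Kmat.diagonal(), torch.ones(nn, dtype=Kmat.dtype, device=Kmat.device))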

Basic operation

torchsvm's main entry point is cvksvm, which tunes kernel SVMs quickly with GPU acceleration and, if requested, computes exact leave-one-out cross-validation (LOOCV) errors.

model = cvksvm(Kmat=Kmat, y=y_train, nlam=nlam, ulam=ulam, nfolds=5, eps=1e-5, maxit=1000, gamma=1e-8, is_exact=0, device='cuda')
model.fit()
### Tune parameter
cv_mis = model.cv(model.pred, y_train).numpy()
best_ind = numpy.argmin(cv_mis)
best_ind

### Test Error and objective value
Kmat = Kmat.double()
alpmat = model.alpmat.to('cpu')
intcpt = alpmat[0, best_ind]   # fitted intercept b at the best lambda
alp = alpmat[1:, best_ind]     # fitted coefficients alpha at the best lambda
ka = torch.mv(Kmat, alp)       # K @ alpha
aka = torch.dot(alp, ka)       # alpha' K alpha
obj_magic = model.objfun(intcpt, aka, ka, y_train, ulam[best_ind], nn)  # objective value

Kmat_new = kernelMult(X_test, X_train, sig)  # cross-kernel between test and training points
Kmat_new = Kmat_new.double()

result = torch.mv(Kmat_new, alpmat[1:,best_ind]) + alpmat[0, best_ind]

ypred = torch.where(result > 0, torch.tensor(1), torch.tensor(-1))

torch.mean((ypred == y_test).float())  # test accuracy

Probability estimation
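
SVM decision values are scores, not probabilities. Platt scaling (Platt, 1999) calibrates them by fitting a two-parameter sigmoid, $P(y = 1 \mid f) = 1 / (1 + \exp(Af + B))$, to held-out decision values; below, the out-of-fold predictions from cross-validation play that role, which avoids the optimistic bias of calibrating on in-sample scores.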

### Platt scaling
from torchsvm import platt  # assumed import path for PlattScalerTorch

# Out-of-fold predictions (thresholded to signs) at the selected lambda
oof_f = torch.where(model.pred > 0, 1, -1).to(device='cpu')[:, best_ind]
scaler = platt.PlattScalerTorch(dtype=torch.double, device='cuda').fit(oof_f, y_train)

# Raw decision values on the test set
f_test = torch.mv(Kmat_new, alpmat[1:, best_ind]) + alpmat[0, best_ind]

with torch.no_grad():
    p_platt = scaler.predict_proba(f_test)[:, 1].cpu().numpy()
    y_test_np = torch.as_tensor(y_test).cpu().numpy()

# Reliability data for Platt
bc, mp, fp, cnt = scaler.reliability_curve(y_test_np, p_platt, n_bins=15)
ece_platt = scaler.expected_calibration_error(mp, fp, cnt)
brier_platt = scaler.brier_score(y_test_np, p_platt)

scaler.plot_calibration(bc, mp, fp, cnt, label=f"Platt (ECE={ece_platt:.3f}, Brier={brier_platt:.3f})")
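
For intuition, the core sigmoid fit behind Platt scaling is small enough to sketch directly. This is a generic gradient-based version assuming labels in $\{-1, +1\}$; it is not PlattScalerTorch's implementation, which may use a different optimizer and numerical safeguards.

def fit_platt(f, y, iters=500, lr=0.1):
    # Fit P(y = 1 | f) = sigmoid(a * f + b) by minimizing the logistic loss.
    # Equivalent to Platt's 1 / (1 + exp(A f + B)) with a = -A, b = -B.
    f = f.detach().double().cpu()
    t = ((y.detach().double() + 1) / 2).cpu()  # map {-1, +1} labels to {0, 1}
    a = torch.zeros((), dtype=torch.double, requires_grad=True)
    b = torch.zeros((), dtype=torch.double, requires_grad=True)
    opt = torch.optim.Adam([a, b], lr=lr)
    for _ in range(iters):
        opt.zero_grad()
        loss = torch.nn.functional.binary_cross_entropy_with_logits(a * f + b, t)
        loss.backward()
        opt.step()
    return a.detach(), b.detach()

a, b = fit_platt(oof_f, y_train)                        # same inputs as PlattScalerTorch.fit
p_test = torch.sigmoid(a * f_test.double().cpu() + b)   # calibrated P(y = 1 | f)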

Real data

torchsvm works well with sklearn datasets. We only need to convert them to torch.Tensor with labels $y \in \{1, -1\}$.

from sklearn.datasets import make_moons

# Generate non-linear dataset
X, y = make_moons(n_samples=300, noise=0.2, random_state=42)

X = torch.from_numpy(X).float()
y = torch.from_numpy(y).float()
y = 2 * y - 1  # map {0, 1} labels to {-1, +1}

sig = sigest(X)
Kmat = rbf_kernel(X, sig)

model = cvksvm(Kmat=Kmat, y=y, nlam=nlam, ulam=ulam, nfolds=5, eps=1e-5, maxit=1000, gamma=1e-8, is_exact=0, device='cuda')
model.fit()
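
Tuning and evaluation then follow the same pattern as the simulated-data example above, reusing only attributes already demonstrated there:

# Select the best lambda by cross-validation
cv_mis = model.cv(model.pred, y).numpy()
best_ind = numpy.argmin(cv_mis)

# In-sample decision values and accuracy at the best lambda
alpmat = model.alpmat.to('cpu')
Kmat = Kmat.double()
f_hat = torch.mv(Kmat, alpmat[1:, best_ind]) + alpmat[0, best_ind]
y_hat = torch.where(f_hat > 0, 1, -1)
torch.mean((y_hat == y).float())  # training accuracy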

Extensions to large-margin classifiers

torchsvm also provides other large-margin classifiers:

  1. Kernel logistic regression
     model = cvklogit(Kmat=Kmat, y=y_train, nlam=nlam, ulam=ulam, foldid=foldid, nfolds=nfolds, eps=1e-5, maxit=1000, gamma=1e-8, is_exact=0, device='cuda')
     model.fit()
    
  2. Kernel distance weighted discrimination
     model = cvkdwd(Kmat=Kmat, y=y_train, nlam=nlam, ulam=ulam, foldid=foldid, nfolds=nfolds, eps=1e-5, maxit=1000, gamma=1e-8, is_exact=0, device='cuda')
     model.fit()
    

Getting help

For questions or suggestions, please contact: yikai-zhang@uiowa.edu
