torchsvm, a PyTorch-based library that trains kernel SVMs and other large-margin classifiers

Project description

torchsvm

A PyTorch-based package for solving kernel SVMs on the GPU.

Introduction

torchsvm is a PyTorch-based library that trains kernel SVMs and other large-margin classifiers and computes exact leave-one-out cross-validation (LOOCV) errors. Conventional SVM solvers often struggle to scale, especially on large datasets or when cross-validation must be run many times. torchsvm computes LOOCV errors at the same cost as training a single SVM, and uses CUDA-accelerated matrix operations for speed and scalability. Benchmark experiments indicate that torchsvm is faster and more efficient than existing kernel SVM solvers.

Installation

Install the package with pip:

pip install torchsvm
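Every example on this page passes device='cuda', so after installing it can help to confirm that your PyTorch build actually sees a GPU. The check below is a minimal sketch that uses only PyTorch's public API. The CPU fallback assumes the torchsvm solvers also accept device='cpu', which this page does not confirm:

```python
import torch

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
# Assumption: torchsvm solvers also accept device='cpu'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"PyTorch {torch.__version__}, using device: {device}")
```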

Quick start

Usage is similar to scikit-learn: construct a model object, then call fit():

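# Kmat, y_train, nlam and ulam are built as in Usage below; foldid and nfolds must also be defined.
model = cvksvm(Kmat=Kmat, y=y_train, nlam=nlam, ulam=ulam, foldid=foldid, nfolds=nfolds,
               eps=1e-5, maxit=1000, gamma=1e-8, is_exact=0, device='cuda')
model.fit()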
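The call above needs foldid and nfolds, but this page never defines them. The sketch below builds a random, balanced fold assignment. It assumes, without confirmation from this page, that foldid is a length-n integer tensor of 1-based fold labels, following the convention of R's glmnet. make_foldid is a hypothetical helper and not part of torchsvm:

```python
import torch

def make_foldid(n, nfolds, seed=0):
    """Hypothetical helper: assign each of n samples to one of nfolds folds.

    Assumes 1-based integer fold labels (glmnet convention); check the
    torchsvm source before relying on this format.
    """
    gen = torch.Generator().manual_seed(seed)
    # Repeat 1..nfolds to cover all n samples, then shuffle so folds are random
    labels = torch.arange(n) % nfolds + 1
    return labels[torch.randperm(n, generator=gen)]

nfolds = 5
foldid = make_foldid(10000, nfolds)
```

Pass n = len(y_train) so the labels line up with the training samples.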

Usage

Generate simulation data

torchsvm provides a function for generating simulated data, which is useful for trying out the library:

import torch
# data_gen, standardize, sigest and rbf_kernel are provided by torchsvm; import them first.

# Sample data
nn = 10000          # Number of training samples
nm = 5              # Number of clusters per class
pp = 10             # Number of features
p1 = p2 = pp // 2   # Number of positive/negative centers
mu = 2.0            # Mean shift
ro = 3              # Standard deviation of the normal distribution
sdn = 42            # Seed for reproducibility

nlam = 50                                  # Number of lambda values
torch.manual_seed(sdn)
ulam = torch.logspace(3, -3, steps=nlam)   # Decreasing lambda grid from 1e3 to 1e-3

X_train, y_train, means_train = data_gen(nn, nm, pp, p1, p2, mu, ro, sdn)
X_test, y_test, means_test = data_gen(nn // 10, nm, pp, p1, p2, mu, ro, sdn)
X_train = standardize(X_train)
X_test = standardize(X_test)

sig = sigest(X_train)             # Estimate the RBF kernel parameter from the training data
Kmat = rbf_kernel(X_train, sig)   # Training kernel matrix
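The preprocessing above standardizes each feature and builds a Gaussian (RBF) kernel matrix. This page does not document torchsvm's exact conventions. The sketch below assumes the parametrization k(x, x') = exp(-sig * ||x - x'||^2) used by R's kernlab, whose sigest function shares its name. It also standardizes the test set with the training mean and standard deviation, whereas the example above standardizes each set separately. Finally, it builds the test-versus-training kernel matrix that kernel classifiers need for prediction. fit_standardizer and rbf_kernel_matrix are hypothetical helpers, not torchsvm functions:

```python
import torch

def fit_standardizer(X):
    """Hypothetical helper: per-feature mean and std computed on training data."""
    mean = X.mean(dim=0)
    std = X.std(dim=0).clamp_min(1e-12)  # guard against constant features
    return mean, std

def rbf_kernel_matrix(A, B, sig):
    """Gaussian kernel K[i, j] = exp(-sig * ||A[i] - B[j]||^2).

    Assumes the kernlab-style parametrization; torchsvm's rbf_kernel may differ.
    """
    sq_dist = torch.cdist(A, B).pow(2)
    return torch.exp(-sig * sq_dist)

torch.manual_seed(0)
X_train = torch.randn(200, 10)
X_test = torch.randn(50, 10)

mean, std = fit_standardizer(X_train)
Z_train = (X_train - mean) / std
Z_test = (X_test - mean) / std   # reuse the training statistics for the test set

K_train = rbf_kernel_matrix(Z_train, Z_train, sig=0.1)  # n_train x n_train
K_test = rbf_kernel_matrix(Z_test, Z_train, sig=0.1)    # n_test x n_train
```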

Basic operation

The main entry point is cvksvm, which tunes a kernel SVM over the lambda sequence with GPU acceleration and, if requested, computes exact leave-one-out cross-validation (LOOCV) errors.

model = cvksvm(Kmat=Kmat, y=y_train, nlam=nlam, ulam=ulam, foldid=foldid, nfolds=nfolds,
               eps=1e-5, maxit=1000, gamma=1e-8, is_exact=0, device='cuda')
model.fit()
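After fitting, the tuning parameter is usually chosen as the lambda with the smallest cross-validation error. This page does not document which attribute of the fitted model holds those errors. In the sketch below, cv_err is a placeholder for a length-nlam tensor of error rates aligned with ulam, and select_lambda is a hypothetical helper. Because ulam decreases, taking the first minimizer breaks ties toward the larger, more regularized lambda:

```python
import torch

def select_lambda(ulam, cv_err):
    """Hypothetical helper: index and value of the lambda with minimal CV error.

    On ties, the first (largest) lambda in the decreasing grid is returned.
    """
    best = cv_err.min()
    idx = int(torch.nonzero(cv_err == best)[0])
    return idx, float(ulam[idx])

# Illustration with a small grid and made-up placeholder errors
ulam = torch.logspace(3, -3, steps=7)        # 1e3, 1e2, ..., 1e-3
cv_err = torch.tensor([0.30, 0.20, 0.10, 0.10, 0.15, 0.20, 0.25])
idx, lam = select_lambda(ulam, cv_err)
print(idx, lam)
```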

It also provides cross-validated solvers for other large-margin classifiers:

  1. Kernel logistic regression

     model = cvklogit(Kmat=Kmat, y=y_train, nlam=nlam, ulam=ulam, foldid=foldid, nfolds=nfolds,
                      eps=1e-5, maxit=1000, gamma=1e-8, is_exact=0, device='cuda')
     model.fit()

  2. Kernel SVM with Huber loss

     model = cvkhuber(Kmat=Kmat, y=y_train, nlam=nlam, ulam=ulam, foldid=foldid, nfolds=nfolds,
                      eps=1e-5, maxit=1000, gamma=1e-8, is_exact=0, device='cuda')
     model.fit()

  3. Kernel squared SVM

     model = cvksqsvm(Kmat=Kmat, y=y_train, nlam=nlam, ulam=ulam, foldid=foldid, nfolds=nfolds,
                      eps=1e-5, maxit=1000, gamma=1e-8, is_exact=0, device='cuda')
     model.fit()
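These solvers differ mainly in the loss applied to the margin t = y f(x), with labels y in {-1, +1}. The kernel SVM uses the hinge loss, and the others use the logistic loss, a Huberized (smoothed) hinge, and the squared hinge. The sketch below evaluates these standard losses. The Huberized hinge form and its delta parameter are assumptions, because this page does not state how cvkhuber defines its loss:

```python
import torch
import torch.nn.functional as F

# Margin-based losses as functions of t = y * f(x), with labels y in {-1, +1}.
# Assumption: the Huberized hinge below (with smoothing width delta) is one
# common form; torchsvm's cvkhuber may parametrize it differently.

def hinge(t):
    return torch.clamp(1 - t, min=0)            # kernel SVM

def logistic(t):
    return F.softplus(-t)                        # log(1 + exp(-t)), kernel logistic regression

def squared_hinge(t):
    return torch.clamp(1 - t, min=0) ** 2       # kernel squared SVM

def huberized_hinge(t, delta=0.5):
    # 0 for t > 1, quadratic on (1 - delta, 1], linear for t <= 1 - delta
    quad = (1 - t) ** 2 / (2 * delta)
    lin = 1 - t - delta / 2
    return torch.where(t > 1, torch.zeros_like(t),
                       torch.where(t > 1 - delta, quad, lin))

t = torch.tensor([-1.0, 0.0, 1.0, 2.0])
for name, loss in [("hinge", hinge), ("logistic", logistic),
                   ("squared hinge", squared_hinge), ("huberized hinge", huberized_hinge)]:
    print(f"{name:>16}: {loss(t).tolist()}")
```

The Huberized hinge matches the hinge loss (shifted by delta/2) far from the margin but is differentiable at t = 1, which makes it friendlier to smooth optimizers.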
    

Getting help

For questions or suggestions, please contact: yikai-zhang@uiowa.edu

Download files

Source Distribution

torchsvm-2.0.0.tar.gz (26.8 kB)

Uploaded Source

Built Distribution

torchsvm-2.0.0-py3-none-any.whl (40.2 kB)

Uploaded Python 3

File details

Details for the file torchsvm-2.0.0.tar.gz.

File metadata

  • Download URL: torchsvm-2.0.0.tar.gz
  • Upload date:
  • Size: 26.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for torchsvm-2.0.0.tar.gz
Algorithm Hash digest
SHA256 8804014d43a2fc73b82aae749ec81f5fc678204e7a66f5e0594029e76787f4fe
MD5 f48e9da8e6e6587772bd9197ad4c66a1
BLAKE2b-256 3a41900f7e8de3c4438e47c1b778c81a9e78df7edf20ecf49d14815cb1fa7e68
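To check a downloaded file against the SHA256 digests listed here, the sketch below uses Python's standard hashlib module. sha256_of and verify are illustrative helpers:

```python
import hashlib

# SHA256 digest of torchsvm-2.0.0.tar.gz, copied from the table above
EXPECTED_SDIST = "8804014d43a2fc73b82aae749ec81f5fc678204e7a66f5e0594029e76787f4fe"

def sha256_of(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of a file, read in 1 MiB chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

def verify(path, expected):
    """True if the file's SHA256 digest matches the expected hex string."""
    return sha256_of(path) == expected.lower()
```

For example, verify("torchsvm-2.0.0.tar.gz", EXPECTED_SDIST) should return True for an intact download. pip can also enforce digests at install time through its hash-checking mode (--require-hashes).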

File details

Details for the file torchsvm-2.0.0-py3-none-any.whl.

File metadata

  • Download URL: torchsvm-2.0.0-py3-none-any.whl
  • Upload date:
  • Size: 40.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for torchsvm-2.0.0-py3-none-any.whl
Algorithm Hash digest
SHA256 5cd4c315c004ea12b25dca51b6a1660452275d95b37d115509fcdd82d54c6619
MD5 98821dc0fbdb89abc0d888d39750675c
BLAKE2b-256 e44f8a9d89ec98f3924d5b97b9659de7474a7f3d0d1f127c58aadd639e4bbf2b
