torchsvm, a PyTorch-based library that trains kernel SVMs and other large-margin classifiers
torchsvm
This is a PyTorch-based package for solving kernel SVMs on the GPU.
Introduction
torchsvm is a PyTorch-based library that trains kernel SVMs and other large-margin classifiers and computes exact leave-one-out cross-validation (LOOCV) errors. Conventional SVM solvers often struggle to scale, especially on large datasets or when cross-validation has to be repeated many times. torchsvm computes LOOCV at the same cost as training a single SVM, and it gains further speed and scalability from CUDA-accelerated matrix operations. Benchmark experiments indicate that TorchKSVM outperforms existing kernel SVM solvers in speed and efficiency.
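Naive LOOCV refits the model n times, which is why it is usually expensive. torchsvm's own exact-LOOCV algorithm for SVMs is not shown on this page, so the sketch below instead illustrates the general idea with a different, classic technique: for kernel ridge regression (no intercept), every leave-one-out residual can be recovered from a single fit via the hat matrix H = K (K + λI)⁻¹. The function names are hypothetical, and the brute-force version is included only to check the shortcut.

```python
import torch


def krr_loo_residuals(K, y, lam):
    """Exact LOO residuals for kernel ridge regression (no intercept), from one fit.

    Fitted values are y_hat = H y with H = K (K + lam I)^{-1}; the residual for
    point i when it is held out is (y_i - y_hat_i) / (1 - H_ii).
    """
    n = K.shape[0]
    eye = torch.eye(n, dtype=K.dtype, device=K.device)
    # (K + lam I)^{-1} K equals K (K + lam I)^{-1} because the two matrices commute.
    H = torch.linalg.solve(K + lam * eye, K)
    y_hat = H @ y
    return (y - y_hat) / (1 - torch.diagonal(H))


def krr_loo_bruteforce(K, y, lam):
    """The same residuals computed naively with n separate refits (for checking)."""
    n = K.shape[0]
    out = torch.empty(n, dtype=K.dtype, device=K.device)
    for i in range(n):
        keep = torch.arange(n, device=K.device) != i
        K_tr = K[keep][:, keep]
        eye = torch.eye(n - 1, dtype=K.dtype, device=K.device)
        alpha = torch.linalg.solve(K_tr + lam * eye, y[keep])
        out[i] = y[i] - K[i, keep] @ alpha
    return out


# Usage: LOO misclassification rate for +/-1 labels, obtained from a single fit.
torch.manual_seed(0)
X = torch.randn(50, 3, dtype=torch.float64)
y = torch.sign(X[:, 0])
K = torch.exp(-torch.cdist(X, X).pow(2))
r = krr_loo_residuals(K, y, lam=0.1)
loo_pred = y - r  # prediction for each point from the model fit without it
loo_error = (torch.sign(loo_pred) != y).double().mean()
```

The shortcut needs one n × n solve instead of n of them; an SVM's hinge loss does not admit this simple closed form, which is what makes exact SVM LOOCV at single-fit cost notable.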
Installation
You can use pip to install this package.
```shell
pip install torchsvm
```
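The examples on this page pass device='cuda'. Before running them, it may help to confirm that PyTorch can see a GPU. A minimal check is sketched below; the CPU fallback is an assumption, since this page does not say whether torchsvm accepts device='cpu'.

```python
import torch

# The examples pass device='cuda'; fall back to the CPU when PyTorch cannot see a GPU
# (whether torchsvm supports device='cpu' is an assumption, not documented here).
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"PyTorch {torch.__version__}, using device: {device}")
```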
Quick start
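Usage follows a scikit-learn-like pattern: construct the model, then call fit():
```python
# Kmat, y_train, nlam and ulam are constructed in "Generate simulation data" below;
# foldid and nfolds specify the cross-validation folds.
model = cvksvm(Kmat=Kmat, y=y_train, nlam=nlam, ulam=ulam, foldid=foldid, nfolds=nfolds, eps=1e-5, maxit=1000, gamma=1e-8, is_exact=0, device='cuda')
model.fit()
```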
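The calls above pass foldid and nfolds, which this page never defines. A minimal sketch for building them, assuming (as in glmnet-style R packages) that foldid is an integer tensor with one fold label in 1..nfolds per training sample; make_foldid is a hypothetical helper, and the exact convention torchsvm expects should be checked against its source.

```python
import torch


def make_foldid(n, nfolds, seed=0):
    """Hypothetical helper: balanced random fold labels 1..nfolds for n samples.

    Assumes 1-based fold labels (glmnet convention); torchsvm's convention is not
    documented on this page.
    """
    g = torch.Generator().manual_seed(seed)
    base = torch.arange(n) % nfolds + 1  # 1, 2, ..., nfolds, 1, 2, ... (balanced)
    perm = torch.randperm(n, generator=g)  # shuffle so folds are random
    return base[perm]


nfolds = 5
foldid = make_foldid(10000, nfolds, seed=42)
```

Cycling through 1..nfolds before shuffling keeps fold sizes within one sample of each other, which a plain random draw would not guarantee.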
Usage
Generate simulation data
torchsvm provides a function for generating simulated data, which is useful for trying out the library:
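```python
import torch
# data_gen, standardize, sigest and rbf_kernel are torchsvm helpers
# (their import path is not shown on this page).

# Sample data
nn = 10000  # Number of training samples
nm = 5  # Number of clusters per class
pp = 10  # Number of features
p1 = p2 = pp // 2  # Number of positive/negative centers
mu = 2.0  # Mean shift
ro = 3  # Standard deviation for normal distribution
sdn = 42  # Seed for reproducibility
nlam = 50  # Number of lambda values on the tuning grid
torch.manual_seed(sdn)
ulam = torch.logspace(3, -3, steps=nlam)  # lambda grid from 1e3 down to 1e-3

X_train, y_train, means_train = data_gen(nn, nm, pp, p1, p2, mu, ro, sdn)
X_test, y_test, means_test = data_gen(nn // 10, nm, pp, p1, p2, mu, ro, sdn)  # 1,000 test samples
X_train = standardize(X_train)
X_test = standardize(X_test)
sig = sigest(X_train)  # Kernel bandwidth estimate
Kmat = rbf_kernel(X_train, sig)  # nn x nn RBF kernel matrix on the training data
```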
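This page does not document the kernel parameterization behind rbf_kernel and sigest. A minimal sketch, assuming the kernlab-style form k(x, z) = exp(-σ‖x − z‖²) and using a simple median heuristic for σ (kernlab's sigest instead returns quantile-based estimates); rbf_kernel_sketch and median_sigma_sketch are hypothetical names, not torchsvm functions. The cross-kernel form is what scoring test points against the training set would require.

```python
import torch


def rbf_kernel_sketch(X, Z, sigma):
    """Hypothetical RBF kernel: K[i, j] = exp(-sigma * ||X[i] - Z[j]||^2)."""
    return torch.exp(-sigma * torch.cdist(X, Z).pow(2))


def median_sigma_sketch(X):
    """Median heuristic: sigma = 1 / (median squared pairwise distance).

    torch.median returns the lower of the two middle values for even counts.
    """
    d2 = torch.pdist(X).pow(2)  # all n(n-1)/2 squared pairwise distances
    return 1.0 / d2.median()


torch.manual_seed(0)
X_tr = torch.randn(200, 10)
X_te = torch.randn(20, 10)
sigma = median_sigma_sketch(X_tr)
K_train = rbf_kernel_sketch(X_tr, X_tr, sigma)  # 200 x 200, plays the role of Kmat
K_test = rbf_kernel_sketch(X_te, X_tr, sigma)  # 20 x 200, test rows vs training columns
```

Memory grows quadratically: with nn = 10,000 the kernel matrix has 10⁸ entries, roughly 400 MB in float32 or 800 MB in float64.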
Basic operation
The main tool in torchsvm is cvksvm, which tunes a kernel SVM over a grid of λ values with GPU acceleration and, when requested, also computes exact leave-one-out cross-validation (LOOCV) errors.
```python
model = cvksvm(Kmat=Kmat, y=y_train, nlam=nlam, ulam=ulam, foldid=foldid, nfolds=nfolds, eps=1e-5, maxit=1000, gamma=1e-8, is_exact=0, device='cuda')
model.fit()
```
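For reference, the textbook representer-form kernel SVM problem, solved once per λ on the grid, is written below; the scaling of the loss and penalty terms that torchsvm uses internally is an assumption here. Each entry of ulam plays the role of one λ, nlam is the grid size, and $\mathbf{K}_i$ is the $i$-th column of Kmat:

$$
\min_{b,\ \boldsymbol{\alpha}\in\mathbb{R}^n}\ \frac{1}{n}\sum_{i=1}^{n} L\!\left(y_i\,(b + \mathbf{K}_i^{\top}\boldsymbol{\alpha})\right) + \frac{\lambda}{2}\,\boldsymbol{\alpha}^{\top}\mathbf{K}\,\boldsymbol{\alpha},
\qquad L(t) = \max(0,\ 1 - t).
$$

The LOOCV error for a given λ is the misclassification rate of the $n$ fits that each leave out one point, where $\hat f^{[-i]}$ is the classifier trained without observation $i$:

$$
\mathrm{LOOCV}(\lambda) = \frac{1}{n}\sum_{i=1}^{n} \mathbf{1}\!\left\{ y_i \neq \operatorname{sign}\!\big(\hat f^{[-i]}(x_i)\big) \right\}.
$$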
It also provides solvers for other large-margin classifiers, called the same way:
- Kernel logistic regression
```python
model = cvklogit(Kmat=Kmat, y=y_train, nlam=nlam, ulam=ulam, foldid=foldid, nfolds=nfolds, eps=1e-5, maxit=1000, gamma=1e-8, is_exact=0, device='cuda')
model.fit()
```
- Kernel SVM with Huber loss
```python
model = cvkhuber(Kmat=Kmat, y=y_train, nlam=nlam, ulam=ulam, foldid=foldid, nfolds=nfolds, eps=1e-5, maxit=1000, gamma=1e-8, is_exact=0, device='cuda')
model.fit()
```
- Kernel squared SVM
```python
model = cvksqsvm(Kmat=Kmat, y=y_train, nlam=nlam, ulam=ulam, foldid=foldid, nfolds=nfolds, eps=1e-5, maxit=1000, gamma=1e-8, is_exact=0, device='cuda')
model.fit()
```
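The classifiers above differ mainly in the loss applied to the margin t = y·f(x). The sketch below uses the standard textbook definitions of these losses; it is not torchsvm code, and the Huber parameter delta is an illustrative choice, since the value cvkhuber uses is not documented on this page.

```python
import torch
import torch.nn.functional as F

# Losses as functions of the margin t = y * f(x), with labels y in {-1, +1}.


def hinge_loss(t):
    """Standard SVM hinge loss: max(0, 1 - t)."""
    return torch.clamp(1 - t, min=0)


def logistic_loss(t):
    """Logistic loss log(1 + exp(-t)), computed stably via softplus."""
    return F.softplus(-t)


def squared_hinge_loss(t):
    """Squared hinge loss: max(0, 1 - t)^2."""
    return torch.clamp(1 - t, min=0) ** 2


def huberized_hinge_loss(t, delta=1.0):
    """Huberized hinge: 0 for t > 1, quadratic on (1 - delta, 1], linear below."""
    quad = (1 - t) ** 2 / (2 * delta)
    lin = 1 - t - delta / 2  # matches the quadratic piece at t = 1 - delta
    zero = torch.zeros_like(t)
    return torch.where(t > 1, zero, torch.where(t > 1 - delta, quad, lin))


t = torch.tensor([2.0, 1.0, 0.5, -1.0])
losses = {
    "hinge": hinge_loss(t),
    "logistic": logistic_loss(t),
    "squared hinge": squared_hinge_loss(t),
    "huberized hinge": huberized_hinge_loss(t, delta=1.0),
}
```

The hinge loss is not differentiable at t = 1; the logistic, squared-hinge and huberized variants are smooth there, which makes them easier targets for gradient-style solvers.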
Getting help
For questions or suggestions, please contact: yikai-zhang@uiowa.edu
Download files
Source distribution: torchsvm-0.1.1.tar.gz
Built distribution: torchsvm-0.1.1-py3-none-any.whl
File details
Details for the file torchsvm-0.1.1.tar.gz.
File metadata
- Download URL: torchsvm-0.1.1.tar.gz
- Upload date:
- Size: 18.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.8
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 210be85e29bf60a0c71c8e118707a41c81c117b433979b92173d485cf13cce92 |
| MD5 | 851110684011f2ad1e2aa18685fa2d86 |
| BLAKE2b-256 | a25cd2eb0a4cd4cc94457e7fd5da047e7a11ec1dec55f12a29279ff494377136 |
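To check a downloaded file against the published SHA256 digest, a minimal sketch using Python's standard hashlib module is shown below; sha256_of is a hypothetical helper name, and the usage assumes the source archive has been saved to the current directory.

```python
import hashlib


def sha256_of(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of a file, reading it in 1 MiB chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        # iter() with a sentinel keeps reading until f.read() returns b"" (end of file)
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


# Published SHA256 digest of torchsvm-0.1.1.tar.gz (copied from the table above).
EXPECTED_SDIST_SHA256 = "210be85e29bf60a0c71c8e118707a41c81c117b433979b92173d485cf13cce92"
```

Usage: `sha256_of("torchsvm-0.1.1.tar.gz") == EXPECTED_SDIST_SHA256` should be True for an intact download. Reading in chunks keeps memory use constant regardless of file size.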
File details
Details for the file torchsvm-0.1.1-py3-none-any.whl.
File metadata
- Download URL: torchsvm-0.1.1-py3-none-any.whl
- Upload date:
- Size: 24.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.8
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 5ea6af6f14a39d43f57c05a8c81967c3de144dad9df0f2e09f046517d023af25 |
| MD5 | 265ab642e11054712d423df1ff6c04d4 |
| BLAKE2b-256 | 854171c55241273e6d187fdb6ed9be06afc23feb15def467153cae2a7abea2a6 |