PyTorch wrapper for performing grid search
torchgs
PyTorch wrapper for grid search of hyperparameters [https://github.com/danny-1k/torch-gs]
Install
$ pip install torchgs
Example
Finding the best combination of hyperparameters and models for a classification problem:
from sklearn.datasets import make_classification
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset
from torchgs import GridSearch
from torchgs.metrics import Loss

# Synthetic binary classification data: 200 samples, 20 features
X, Y = make_classification(
    n_samples=200, n_features=20, n_informative=10,
    n_classes=2, shuffle=True, random_state=42,
)
X = torch.tensor(X, dtype=torch.float32)  # features as float32
Y = torch.tensor(Y, dtype=torch.long)     # integer class labels for CrossEntropyLoss
traindata = TensorDataset(X, Y)
net1 = nn.Sequential(
    nn.Linear(20, 10),
    nn.ReLU(),
    nn.Linear(10, 2),
)
net2 = nn.Sequential(
    nn.Linear(20, 10),
    nn.Tanh(),
    nn.Linear(10, 2),
)
net3 = nn.Sequential(
    nn.Linear(20, 20),
    nn.ReLU(),
    nn.Linear(20, 10),
    nn.ReLU(),
    nn.Linear(10, 2),
)
net4 = nn.Sequential(
    nn.Linear(20, 20),
    nn.Tanh(),
    nn.Linear(20, 10),
    nn.Tanh(),
    nn.Linear(10, 2),
)
# Each list holds the candidate values for one setting
search_space = {
    'trainer': {
        'net': [net1, net2, net3, net4],
        'optimizer': [torch.optim.Adam],
        'lossfn': [torch.nn.CrossEntropyLoss()],
        'epochs': list(range(11)),  # 0 through 10
        'metric': [Loss(torch.nn.CrossEntropyLoss())],
    },
    'train_loader': {
        'batch_size': [32, 64],
    },
    'optimizer': {
        'lr': [1e-1, 1e-2, 1e-3, 1e-4],
    },
}
searcher = GridSearch(search_space)
results = searcher.fit(traindata)
best = searcher.best(results, using='mean', topk=10, should_print=True)
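How torchgs walks the search space internally is not shown here. As a rough illustration, assuming it evaluates the full Cartesian product with every listed value (including each entry of `epochs`) treated as a separate axis, the sketch below expands a search space of the same shape using only the standard library. The string values are placeholders for the real networks, optimizer class and loss objects, and `expand_grid` is a hypothetical helper, not part of torchgs.

```python
from itertools import product

# Stand-in values with the same shape as the search space in the example.
# The strings are placeholders, not real PyTorch objects.
search_space = {
    "trainer": {
        "net": ["net1", "net2", "net3", "net4"],
        "optimizer": ["Adam"],
        "lossfn": ["CrossEntropyLoss"],
        "epochs": list(range(11)),
        "metric": ["Loss"],
    },
    "train_loader": {"batch_size": [32, 64]},
    "optimizer": {"lr": [1e-1, 1e-2, 1e-3, 1e-4]},
}


def expand_grid(space):
    """Yield one {group: {name: value}} dict per point of the grid (hypothetical helper)."""
    # Flatten to (group, name) keys so a single product covers every axis.
    keys = [(group, name) for group, params in space.items() for name in params]
    axes = [space[group][name] for group, name in keys]
    for combo in product(*axes):
        config = {group: {} for group in space}
        for (group, name), value in zip(keys, combo):
            config[group][name] = value
        yield config


configs = list(expand_grid(search_space))
print(len(configs))  # 4 nets * 11 epoch values * 2 batch sizes * 4 learning rates = 352
print(configs[0])
```

Under that assumption the example amounts to 352 configurations, which is worth keeping in mind before adding more values to any of the lists.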
torchgs
- Trainer
- GridSearch
- metrics
- optimizers
torchgs.metrics
- Metric
- Loss
- Accuracy
- Recall
- Precision
- F1
torchgs.optimizers
- Optimizer
- LRscheduler
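The metric names listed under torchgs.metrics match the standard classification metrics. How torchgs implements them, and what arguments its classes take, is not documented here, so the following is only a plain-Python sketch of the usual definitions for binary labels. The function name and the sample labels are made up for illustration.

```python
def binary_classification_metrics(y_true, y_pred):
    """Return accuracy, precision, recall and F1 for 0/1 labels."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)  # true positives
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)  # false positives
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)  # false negatives
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)  # true negatives

    # Fall back to 0.0 whenever a denominator would be zero.
    accuracy = (tp + tn) / len(pairs) if pairs else 0.0
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = (
        2 * precision * recall / (precision + recall)
        if precision + recall
        else 0.0
    )
    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }


# Hypothetical labels and predictions, for illustration only.
y_true = [1, 0, 1, 1, 0, 0, 1, 0]
y_pred = [1, 0, 0, 1, 0, 1, 1, 0]
metrics = binary_classification_metrics(y_true, y_pred)
print(metrics)
```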
Todo
- Parallel training on multiple GPUs
- TensorBoard integration
Pull requests are welcome, let's collab 🤲.
Download files
Source Distribution
torchgs-0.0.2.tar.gz (8.2 kB)
Built Distribution
torchgs-0.0.2-py3-none-any.whl (8.0 kB)
File details
Details for the file torchgs-0.0.2.tar.gz.
File metadata
- Download URL: torchgs-0.0.2.tar.gz
- Upload date:
- Size: 8.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.9.6
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 34301572366a53b9cdad5365cafc011ce70ad1546cc7d2edc00a622509e61802 |
| MD5 | 24970830d0217a048200d3b542de9437 |
| BLAKE2b-256 | deb9735fb1cf89c7260d4622408bec72d4fa12555aaa4eb6ad4f5f467faae7d0 |
File details
Details for the file torchgs-0.0.2-py3-none-any.whl.
File metadata
- Download URL: torchgs-0.0.2-py3-none-any.whl
- Upload date:
- Size: 8.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.9.6
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 79554f298aed1c383b85c2f0f69a143eec6e8ed3b2d48de7469266acb6267dad |
| MD5 | 59d6a80c6d59e5bb9642b4d99e551f22 |
| BLAKE2b-256 | e0fbe9020dbfff020c91937adfcb3d6d741566987999ae5b69fd1fc4c1b9d65e |