An easy-to-use tool for training PyTorch deep learning models
Project description
DeepEpochs
A training tool for PyTorch deep learning models.
Usage
Standard training workflow
from deepepochs import Trainer, Checker, rename
import torch
from torch import nn
from torch.nn import functional as F
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, random_split
from torchmetrics import functional as MF
# datasets
data_dir = './dataset'
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
mnist_full = MNIST(data_dir, train=True, transform=transform, download=True)
train_ds, val_ds, _ = random_split(mnist_full, [5000, 5000, 50000])
test_ds = MNIST(data_dir, train=False, transform=transform, download=True)
# dataloaders
train_dl = DataLoader(train_ds, batch_size=32)
val_dl = DataLoader(val_ds, batch_size=32)
test_dl = DataLoader(test_ds, batch_size=32)
# pytorch model
channels, width, height = (1, 28, 28)
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(channels * width * height, 64),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(64, 64),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(64, 10)
)

# single metric: multiclass accuracy
def acc(preds, targets):
    return MF.accuracy(preds, targets, task='multiclass', num_classes=10)

# several metrics returned together as a dict
@rename('')
def multi_metrics(preds, targets):
    return {
        'p': MF.precision(preds, targets, task='multiclass', num_classes=10),
        'r': MF.recall(preds, targets, task='multiclass', num_classes=10)
    }

checker = Checker('loss', mode='min', patience=2)
opt = torch.optim.Adam(model.parameters(), lr=2e-4)
trainer = Trainer(model, F.cross_entropy, opt=opt, epochs=100, checker=checker, metrics=[acc, multi_metrics])
progress = trainer.fit(train_dl, val_dl)
test_rst = trainer.test(test_dl)
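The `Checker('loss', mode='min', patience=2)` argument above reads like a conventional early-stopping monitor: watch the loss, treat lower as better, and stop after two epochs without improvement. The source does not document its exact behaviour, so the following is only a minimal sketch of that general technique in plain Python. `PatienceMonitor` and `should_stop` are hypothetical names for illustration, not part of deepepochs.

```python
class PatienceMonitor:
    """Hypothetical early-stopping helper; NOT the deepepochs Checker implementation."""

    def __init__(self, mode='min', patience=2):
        if mode not in ('min', 'max'):
            raise ValueError("mode must be 'min' or 'max'")
        self.mode = mode
        self.patience = patience
        self.best = None        # best monitored value seen so far
        self.bad_epochs = 0     # consecutive epochs without improvement

    def should_stop(self, value):
        improved = (
            self.best is None
            or (self.mode == 'min' and value < self.best)
            or (self.mode == 'max' and value > self.best)
        )
        if improved:
            self.best = value
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        # stop once `patience` consecutive epochs have failed to improve
        return self.bad_epochs >= self.patience


# Made-up validation losses, used only to exercise the monitor.
monitor = PatienceMonitor(mode='min', patience=2)
val_losses = [0.90, 0.70, 0.72, 0.71, 0.65]
stopped_at = None
for epoch, loss in enumerate(val_losses):
    if monitor.should_stop(loss):
        stopped_at = epoch
        break
```

With these illustrative losses the loop stops at epoch index 3, because neither 0.72 nor 0.71 beats the best value of 0.70. A real checker may differ in details such as a minimum improvement threshold or restoring the best weights.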
Custom training workflow
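- Method 1:
  - Step 1: Subclass deepepochs.TrainerBase to define a Trainer that fits your needs, implementing the train_step and evaluate_step methods.
  - Step 2: Train the model with the custom Trainer.
- Method 2:
  - Step 1: Subclass deepepochs.Callback to define a Callback that fits your needs.
  - Step 2: Train the model with deepepochs.Learner, passing the custom Callback as an argument to Learner.
  - Note: Learner is a Trainer with Callback support.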
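The two approaches above follow two common extension patterns: overriding per-batch step methods in a subclass, and attaching callbacks to an unchanged training loop. The source does not show the signatures of `TrainerBase`, `Callback`, or `Learner`, so the sketch below does not use deepepochs at all. `SketchTrainerBase`, `LinearTrainer`, `LossRecorder`, and the `on_epoch_end` hook are hypothetical stand-ins, and a one-parameter linear model in plain Python replaces a PyTorch model so the example runs on its own.

```python
class SketchTrainerBase:
    """Hypothetical base class (NOT deepepochs.TrainerBase): owns the epoch loop."""

    def __init__(self, epochs, callbacks=()):
        self.epochs = epochs
        self.callbacks = list(callbacks)

    def train_step(self, batch):
        raise NotImplementedError

    def evaluate_step(self, batch):
        raise NotImplementedError

    def fit(self, train_batches, val_batches):
        history = []
        for epoch in range(self.epochs):
            train_loss = sum(self.train_step(b) for b in train_batches) / len(train_batches)
            val_loss = sum(self.evaluate_step(b) for b in val_batches) / len(val_batches)
            record = {'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss}
            history.append(record)
            for cb in self.callbacks:   # callbacks observe each finished epoch
                cb.on_epoch_end(record)
        return history


class LinearTrainer(SketchTrainerBase):
    """Pattern 1: customise training by overriding the two step methods."""

    def __init__(self, epochs, lr=0.05, callbacks=()):
        super().__init__(epochs, callbacks)
        self.w = 0.0    # single weight of the model y = w * x
        self.lr = lr

    def _loss_and_grad(self, batch):
        n = len(batch)
        loss = sum((self.w * x - y) ** 2 for x, y in batch) / n
        grad = sum(2 * (self.w * x - y) * x for x, y in batch) / n
        return loss, grad

    def train_step(self, batch):
        loss, grad = self._loss_and_grad(batch)
        self.w -= self.lr * grad    # plain gradient-descent update
        return loss

    def evaluate_step(self, batch):
        loss, _ = self._loss_and_grad(batch)    # no parameter update
        return loss


class LossRecorder:
    """Pattern 2: a callback adds behaviour without changing the trainer."""

    def __init__(self):
        self.val_losses = []

    def on_epoch_end(self, record):
        self.val_losses.append(record['val_loss'])


# Toy data drawn from y = 2x, grouped into two batches of (x, y) pairs.
data = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0), (0.5, 1.0)]]
recorder = LossRecorder()
trainer = LinearTrainer(epochs=50, callbacks=[recorder])
history = trainer.fit(data, data)
```

Subclassing suits cases where the forward pass or loss computation itself is unusual. Callbacks suit cross-cutting concerns such as logging or checkpointing, since they can be combined without touching the training logic.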
Download files
Source Distribution
deepepochs-0.2.1.tar.gz (17.2 kB)
Built Distribution
deepepochs-0.2.1-py3-none-any.whl (21.0 kB)
File details
Details for the file deepepochs-0.2.1.tar.gz.
File metadata
- Download URL: deepepochs-0.2.1.tar.gz
- Upload date:
- Size: 17.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.6.0 importlib_metadata/4.8.2 pkginfo/1.8.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.8.12
File hashes
Algorithm | Hash digest |
---|---|
SHA256 | d1e83519fd0dcf3055a5f0122055f1141f4bc5c9af2d25fe9920a8b725373e4f |
MD5 | 61fbca3f6787c601bec9cce2244c7ed0 |
BLAKE2b-256 | 21e5e2a6837e973693332fd47bdd8af075deceb71400d47cd8b24bd98e6ce819 |
File details
Details for the file deepepochs-0.2.1-py3-none-any.whl.
File metadata
- Download URL: deepepochs-0.2.1-py3-none-any.whl
- Upload date:
- Size: 21.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.6.0 importlib_metadata/4.8.2 pkginfo/1.8.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.8.12
File hashes
Algorithm | Hash digest |
---|---|
SHA256 | 8633357cc8d87a3589970e330b631c6d4d56f6b95cb53a3abaaf321e74d5176d |
MD5 | 40571ea07d3278467d2904611e87b8fa |
BLAKE2b-256 | 67e65fd0eb16cf6f550506cb89facc44dfa97f4faf13c601d3f6f2c936b23014 |