deepepochs
A simple training tool for PyTorch models
Usage
Standard training workflow
from deepepochs import Trainer, Checker, rename
import torch
from torch import nn
from torch.nn import functional as F
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, random_split
from torchmetrics import functional as MF
# datasets
data_dir = './dataset'
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
mnist_full = MNIST(data_dir, train=True, transform=transform, download=True)
train_ds, val_ds, _ = random_split(mnist_full, [5000, 5000, 50000])
test_ds = MNIST(data_dir, train=False, transform=transform, download=True)
# dataloaders
train_dl = DataLoader(train_ds, batch_size=32)
val_dl = DataLoader(val_ds, batch_size=32)
test_dl = DataLoader(test_ds, batch_size=32)
# pytorch model
channels, width, height = (1, 28, 28)
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(channels * width * height, 64),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(64, 64),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(64, 10)
)
# metric returning a single value
def acc(preds, targets):
    return MF.accuracy(preds, targets, task='multiclass', num_classes=10)

# metric returning several values as a dict
@rename('m')
def multi_metrics(preds, targets):
    r = MF.recall(preds, targets, task='multiclass', num_classes=10)
    f1 = MF.f1_score(preds, targets, task='multiclass', num_classes=10)
    return {'@r': r, '@f1': f1}
# a loss improves as it decreases, so it is monitored with mode='min'
checker = Checker('loss', mode='min', patience=2)
opt = torch.optim.Adam(model.parameters(), lr=2e-4)
trainer = Trainer(model, F.cross_entropy, opt=opt, epochs=100, checker=checker, metrics=[acc, multi_metrics])
progress = trainer.fit(train_dl, val_dl)
test_rst = trainer.test(test_dl)
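The effect of a checker's mode and patience settings can be illustrated with a small standalone sketch. It does not use deepepochs: `PatienceChecker` and `epochs_run` are hypothetical names, and the stopping rule (stop after `patience` consecutive epochs without improvement) is an assumption about how such a checker typically behaves, not a description of the library's internals.

```python
class PatienceChecker:
    """Hypothetical stand-in for an early-stopping checker (not deepepochs.Checker)."""

    def __init__(self, mode="min", patience=2):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.mode = mode
        self.patience = patience
        self.best = None        # best monitored value seen so far
        self.bad_epochs = 0     # consecutive epochs without improvement

    def should_stop(self, value):
        """Record one monitored value; return True once patience is used up."""
        improved = (
            self.best is None
            or (self.mode == "min" and value < self.best)
            or (self.mode == "max" and value > self.best)
        )
        if improved:
            self.best = value
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


def epochs_run(values, mode, patience=2):
    """Return how many epochs run before the checker stops training."""
    checker = PatienceChecker(mode=mode, patience=patience)
    for epoch, value in enumerate(values, start=1):
        if checker.should_stop(value):
            return epoch
    return len(values)


# Made-up validation losses, for illustration only.
val_losses = [0.90, 0.70, 0.60, 0.65, 0.66, 0.50]
stopped_min = epochs_run(val_losses, mode="min")
stopped_max = epochs_run(val_losses, mode="max")
```

With these made-up values, monitoring in `'min'` mode keeps training while the loss falls, whereas `'max'` mode treats every decrease as a failure to improve and stops much earlier.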
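Custom training workflow
- Step 1: Subclass `deepepochs.TrainerBase` to define a Trainer that meets your needs, implementing the `train_step` and `evaluate_step` methods.
- Step 2: Train the model with the custom Trainer.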
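The two steps follow a subclass-and-override pattern, which the sketch below shows in plain Python. It deliberately does not import deepepochs: `MiniTrainerBase` is a hypothetical stand-in, and the signatures of `train_step` and `evaluate_step` used here are assumptions that may differ from those of `deepepochs.TrainerBase`. The toy model, the data, and the learning rate are also invented for illustration.

```python
class MiniTrainerBase:
    """Hypothetical base class (NOT deepepochs.TrainerBase): owns the epoch loop."""

    def __init__(self, epochs):
        self.epochs = epochs

    def train_step(self, batch):
        raise NotImplementedError

    def evaluate_step(self, batch):
        raise NotImplementedError

    def fit(self, train_batches, val_batches):
        """Run training, returning the mean validation loss of each epoch."""
        history = []
        for _ in range(self.epochs):
            for batch in train_batches:
                self.train_step(batch)
            val_losses = [self.evaluate_step(b) for b in val_batches]
            history.append(sum(val_losses) / len(val_losses))
        return history


# Step 1: subclass the base class and implement the two step methods.
class LinearTrainer(MiniTrainerBase):
    """Fits y = w * x by gradient descent on the mean squared error."""

    def __init__(self, epochs, lr=0.05):
        super().__init__(epochs)
        self.w = 0.0
        self.lr = lr

    def _loss_and_grad(self, batch):
        n = len(batch)
        loss = sum((self.w * x - y) ** 2 for x, y in batch) / n
        grad = sum(2 * (self.w * x - y) * x for x, y in batch) / n
        return loss, grad

    def train_step(self, batch):
        # One parameter update per batch.
        loss, grad = self._loss_and_grad(batch)
        self.w -= self.lr * grad
        return loss

    def evaluate_step(self, batch):
        # Loss only; parameters are left unchanged.
        loss, _ = self._loss_and_grad(batch)
        return loss


# Step 2: train with the custom trainer (toy data generated from y = 2x).
train_batches = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0), (0.5, 1.0)]]
val_batches = [[(1.5, 3.0), (2.5, 5.0)]]
trainer = LinearTrainer(epochs=50)
history = trainer.fit(train_batches, val_batches)
```

The base class keeps the loop logic in one place, so a custom trainer only has to say what happens for a single batch during training and during evaluation.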