An easy-to-use tool for training Pytorch deep learning models
Project description
DeepEpochs
Pytorch深度学习模型训练工具。
安装
pip install deepepochs
使用
数据要求
- 训练集、验证集和测试集是
torch.utils.data.Dataloader
对象 Dataloaer
中每个mini-batch数据是一个tuple
或list
,其中最后一个是标签- 如果数据不包含标签,则请将最后一项置为
None
- 如果数据不包含标签,则请将最后一项置为
指标计算
- 每个指标是一个函数
- 它有两个参数,分别为模型的预测结果和标签
- 返回值为当前mini-batch上的指标值
常规训练流程应用示例
from deepepochs import Trainer, CheckCallback, EpochTask, rename
import torch
from torch import nn
from torch.nn import functional as F
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, random_split
from torchmetrics import functional as MF
# datasets
data_dir = './dataset'
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
mnist_full = MNIST(data_dir, train=True, transform=transform, download=True)
train_ds, val_ds, _ = random_split(mnist_full, [5000, 5000, 50000])
test_ds = MNIST(data_dir, train=False, transform=transform, download=True)
# dataloaders
train_dl = DataLoader(train_ds, batch_size=32)
val_dl = DataLoader(val_ds, batch_size=32)
test_dl = DataLoader(test_ds, batch_size=32)
# pytorch model
channels, width, height = (1, 28, 28)
model = nn.Sequential(
nn.Flatten(),
nn.Linear(channels * width * height, 64),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(64, 64),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(64, 10)
)
def acc(preds, targets):
return MF.accuracy(preds, targets, task='multiclass', num_classes=10)
@rename('')
def multi_metrics(preds, targets):
return {
'p': MF.precision(preds, targets, task='multiclass', num_classes=10),
'r': MF.recall(preds, targets, task='multiclass', num_classes=10)
}
checker = CheckCallback('loss', on_stage='train', mode='min', patience=2)
opt = torch.optim.Adam(model.parameters(), lr=2e-4)
trainer = Trainer(model, F.cross_entropy, opt=opt, epochs=100, callbacks=checker, metrics=[acc])
# 应用示例1:
progress = trainer.fit(train_dl, val_dl, metrics=[multi_metrics])
test_rst = trainer.test(test_dl)
# 应用示例2:
# t1 = EpochTask(train_dl, metrics=[acc])
# t2 = EpochTask(val_dl, metrics=[multi_metrics], do_loss=True)
# progress = trainer.fit(train_tasks=t1, val_tasks=t2)
# test_rst = trainer.test(tasks=t2)
# 应用示例3:
# t1 = EpochTask(train_dl, metrics=[acc])
# t2 = EpochTask(val_dl, metrics=[acc, multi_metrics], do_loss=True)
# progress = trainer.fit(train_dl, val_tasks=[t1, t2])
# test_rst = trainer.test(tasks=[t1, t2])
非常规训练流程
- 方法1:
- 第1步:继承
deepepochs.Callback
类,定制满足需要的Callback
- 第2步:使用
deepepochs.Trainer
训练模型,将定制的Callback
对象作为Trainer
的callbacks
参数
- 第1步:继承
- 方法2:
- 第1步:继承
deepepochs.TrainerBase
类,定制满足需要的Trainer
,实现train_step
方法和evaluate_step
方法- 返回值为字典:key为指标名称,value为
DeepEpochs.PatchBase
子类对象,可用的Patch有ValuePatch
: 根据每个batch指标均值(提前计算好)和batch_size,累积计算Epoch指标均值TensorPatch
: 保存每个batch模型预测输出及标签,根据指定指标函数累积计算Epoch指标均值MeanPatch
: 保存每个batch指标均值,根据指定指标函数累积计算Epoch指标均值ConfusionPatch
:累积计算基于混淆矩阵的指标
- 返回值为字典:key为指标名称,value为
- 第2步:调用定制
Trainer
训练模型。
- 第1步:继承
- 方法3:
- 第1步:继承
deepepochs.EpochTask
类,在其中定义step
、train_step
、val_step
、test_step
或者evaluate_step
方法- 参数分别为:
batch_x
,batch_y
,metrics
,**kwargs
- 返回值为字典:key为指标名称,value为
DeepEpochs.PatchBase
子类对象
- 参数分别为:
- 第2步:使用将新的
EpochTask
任务进行训练。- 将
EpochTask
对象作为Trainer.fit
中train_tasks
和val_tasks
的参数值,或者Trainer.test
方法中tasks
的参数值
- 将
- 第1步:继承
数据流程图
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
deepepochs-0.3.1.tar.gz
(112.1 kB
view details)
Built Distribution
File details
Details for the file deepepochs-0.3.1.tar.gz
.
File metadata
- Download URL: deepepochs-0.3.1.tar.gz
- Upload date:
- Size: 112.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.6.0 importlib_metadata/4.8.2 pkginfo/1.8.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.8.12
File hashes
Algorithm | Hash digest | |
---|---|---|
SHA256 | 7600bc96e1176c46d96212f0ba6aad894a9cce6da3b5b1640e2dcfc3771ccde9 |
|
MD5 | da8d7564811b65c841633264f10db09c |
|
BLAKE2b-256 | da1edefd9255d8054cc867f4af0571b77434af9343eaf19f1539a5a377a36dba |
File details
Details for the file deepepochs-0.3.1-py3-none-any.whl
.
File metadata
- Download URL: deepepochs-0.3.1-py3-none-any.whl
- Upload date:
- Size: 19.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.6.0 importlib_metadata/4.8.2 pkginfo/1.8.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.8.12
File hashes
Algorithm | Hash digest | |
---|---|---|
SHA256 | 26f5a7c8b67b208ad27ed09c0526bde7c7365e3a3768a6868a8c469081b0ea35 |
|
MD5 | d805409781b1869f0a0b7fa6af2ad5e3 |
|
BLAKE2b-256 | 9171e6e67393fae2c153ad2af73c8ad7b63362be87f6f06bb467a5c595cb5037 |