An easy-to-use tool for training Pytorch deep learning models
Project description
DeepEpochs
Pytorch深度学习模型训练工具。
安装
pip install deepepochs
使用
数据要求
- 训练集、验证集和测试集是
torch.utils.data.Dataloader对象 Dataloaer所构造的每个mini-batch数据(collate_fn返回值)是一个tuple或list,其中最后一个是标签- 如果训练中不需要标签,则需将最后一项置为
None
- 如果训练中不需要标签,则需将最后一项置为
指标计算
- 每个指标是一个函数
- 有两个参数,分别为模型预测和数据标签
- 返回值为当前mini-batch上计算的指标值或字典
- 支持基于
torchmetrics.functional定义指标
实例
import torch
from torch import nn
from torch.nn import functional as F
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, random_split
from deepepochs import Trainer
# 1. --- datasets
data_dir = './datasets'
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
mnist_full = MNIST(data_dir, train=True, transform=transform, download=True)
train_ds, val_ds = random_split(mnist_full, [55000, 5000])
test_ds = MNIST(data_dir, train=False, transform=transform, download=True)
train_dl = DataLoader(train_ds, batch_size=32)
val_dl = DataLoader(val_ds, batch_size=32)
test_dl = DataLoader(test_ds, batch_size=32)
# 2. --- model
channels, width, height = (1, 28, 28)
model = nn.Sequential(
nn.Flatten(),
nn.Linear(channels * width * height, 64), nn.ReLU(), nn.Dropout(0.1),
nn.Linear(64, 64), nn.ReLU(), nn.Dropout(0.1),
nn.Linear(64, 10)
)
# 3. --- optimizer
opt = torch.optim.Adam(model.parameters(), lr=2e-4)
# 4. --- train
trainer = Trainer(model, F.cross_entropy, opt, epochs=2) # 训练器
trainer.fit(train_dl, val_dl) # 训练、验证
trainer.test(test_dl) # 测试
更多实例
定制
- 方法1(示例14)
- 第1步:继承
deepepochs.Callback类,定制满足需要的Callback - 第2步:使用
deepepochs.Trainer训练模型,将定制的Callback对象作为Trainer的callbacks参数
- 第1步:继承
- 方法2(示例15)
- 第1步:继承
deepepochs.TrainerBase类定制满足需要的Trainer,实现step、train_step、val_step、test_step或evaluate_step方法,它们的定义方法完全相同- 参数
batch_x: 一个mini-batch的模型输入数据batch_y: 一个mini-batch的标签**step_args:可变参数字典,即EpochTask的step_args参数
- 返回值为
None或字典- key:指标名称
- value:
deepepochs.PatchBase子类对象,可用的Patch有(示例18)ValuePatch: 根据每个mini-batch指标均值(提前计算好)和batch_size,累积计算Epoch指标均值TensorPatch: 保存每个mini-batch的(preds, targets),Epoch指标利用所有mini-batch的(preds, targets)数据重新计算MeanPatch: 保存每个batch指标均值,Epoch指标值利用每个mini-batch的均值计算- 一般
MeanPatch与TensorPatch结果相同,但占用存储空间更小、运算速度更快 - 不可用于计算'precision', 'recall', 'f1', 'fbeta'等指标
- 一般
ConfusionPatch:用于计算基于混淆矩阵的指标,包括'accuracy', 'precision', 'recall', 'f1', 'fbeta'等
- 也可以继承
PatchBase定义新的Patch,需要实现如下方法 (示例19)PatchBase.add- 用于将两个Patch对象相加得到更大的Patch对象
PatchBase.forward- 用于计算指标,返回指标值或字典
- 参数
- 第2步:调用定制
Trainer训练模型。
- 第1步:继承
- 方法(示例16、17)
- 第1步:继承
deepepochs.EpochTask类,在其中定义step、train_step、val_step、test_step或evaluate_step方法- 它们的定义方式与
Trainer中的*step方法相同 step方法优先级最高,即可用于训练也可用于验证和测试(定义了step方法,其他方法就会失效)val_step、test_step优先级高于evaluate_step方法EpochTask中的*step方法优先级高于Trainer中的*step方法EpochTask的__ini__方法的**step_args会被注入*step方法的step_args参数
- 它们的定义方式与
- 第2步:使用新的
EpochTask任务训练- 将
EpochTask对象作为Trainer.fit中train_tasks和val_tasks的参数值,或者Trainer.test方法中tasks的参数值
- 将
- 第1步:继承
数据流图
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
deepepochs-0.6.16.tar.gz
(38.9 kB
view details)
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
File details
Details for the file deepepochs-0.6.16.tar.gz.
File metadata
- Download URL: deepepochs-0.6.16.tar.gz
- Upload date:
- Size: 38.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.13
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
f187bf3311aeef3759d71a8fffaa8394c18433974cb05bdf66cbd96389d20527
|
|
| MD5 |
05c74ecdac10867623d4ee1e1552c743
|
|
| BLAKE2b-256 |
ece61e740d132f1528d28c7a1fa1e55057c8596cce2556570eeb5354e3383308
|
File details
Details for the file deepepochs-0.6.16-py3-none-any.whl.
File metadata
- Download URL: deepepochs-0.6.16-py3-none-any.whl
- Upload date:
- Size: 43.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.13
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
238f770cf7f4381946cf91ffdb42de1388d9af89edcb77f683a3ea208d8772ca
|
|
| MD5 |
94c30bb22852053a0544bc422c3c174e
|
|
| BLAKE2b-256 |
85c213894384d28eca4966c6f7d304af1e9a926da97a016936e56a7c28c6cb4d
|