Skip to main content

An easy-to-use tool for training Pytorch deep learning models

Project description

DeepEpochs

Pytorch深度学习模型训练工具。

安装

pip install deepepochs

使用

数据要求

  • 训练集、验证集和测试集是torch.utils.data.Dataloader对象
  • Dataloaer所构造的每个mini-batch数据(collate_fn返回值)是一个tuplelist,其中最后一个是标签
    • 如果训练中不需要标签,则需将最后一项置为None

指标计算

  • 每个指标是一个函数
    • 有两个参数,分别为模型预测和数据标签
    • 返回值为当前mini-batch上计算的指标值或字典
    • 支持基于torchmetrics.functional定义指标

实例

import torch
from torch import nn
from torch.nn import functional as F
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, random_split
from deepepochs import Trainer

# 1. --- datasets
data_dir = './datasets'
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
mnist_full = MNIST(data_dir, train=True, transform=transform, download=True)
train_ds, val_ds = random_split(mnist_full, [55000, 5000])
test_ds = MNIST(data_dir, train=False, transform=transform, download=True)
train_dl = DataLoader(train_ds, batch_size=32)
val_dl = DataLoader(val_ds, batch_size=32)
test_dl = DataLoader(test_ds, batch_size=32)

# 2. --- model
channels, width, height = (1, 28, 28)
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(channels * width * height, 64), nn.ReLU(), nn.Dropout(0.1),
    nn.Linear(64, 64), nn.ReLU(), nn.Dropout(0.1),
    nn.Linear(64, 10)
)

# 3. --- optimizer
opt = torch.optim.Adam(model.parameters(), lr=2e-4)

# 4. --- train
trainer = Trainer(model, F.cross_entropy, opt, epochs=2)  # 训练器
trainer.fit(train_dl, val_dl)                             # 训练、验证
trainer.test(test_dl)                                     # 测试

更多实例

序号 功能说明 代码
1 基本使用 examples/1-basic.py
2 Trainer、fit方法、test方法的常用参数 examples/2-basic-params.py
3 模型性能评价指标的使用 examples/3-metrics.py
4 Checkpoint与EarlyStop examples/4-checkpoint-earlystop.py
5 寻找适当的学习率 examples/5-lr-find.py
6 利用Tensorboad记录训练过程 examples/6-logger.py
7 利用Tensorboard记录与可视化超参数 examples/7-log-hyperparameters.py
8 分析、解释或可视化模型的预测效果 examples/8-interprete.py
9 学习率调度 examples/9-lr-schedule.py
10 使用多个优化器 examples/10-multi-optimizers.py
11 在训练、验证、测试中使用多个Dataloader examples/11-multi-dataloaders.py
12 基于图神经网络的节点分类 examples/12-node-classification.py
13 模型前向输出和梯度的可视化 examples/13-weight-grad-visualize.py
14 自定义Callback examples/14-costomize-callback.py
15 通过TrainerBase定制train_stepevaluate_step examples/15-customize-steps-1.py
16 通过EpochTask定制train_stepeval_steptest_step examples/16-customize-steps-2.py
17 通过EpochTask定制step examples/17-costomize-steps-3.py
18 内置Patch的使用1 examples/18-use_patches-1.py
19 内置Patch的使用2 examples/19-use_patches-2.py
20 自定义Patch examples/20-customize-patch.py
21 分布式训练、混合精度训练 examples/21-accelerate.py
22 定制train_step实现累积梯度训练 examples/22-grad_accumulate-1.py
23 定制train_step,利用Accelerate实现累积梯度训练 examples/23-grad_accumulate-2.py

定制

  • 方法1(示例14
    • 第1步:继承deepepochs.Callback类,定制满足需要的Callback
    • 第2步:使用deepepochs.Trainer训练模型,将定制的Callback对象作为Trainercallbacks参数
  • 方法2(示例15
    • 第1步:继承deepepochs.TrainerBase类定制满足需要的Trainer,实现steptrain_stepval_steptest_stepevaluate_step方法,它们的定义方法完全相同
      • 参数
        • batch_x: 一个mini-batch的模型输入数据
        • batch_y: 一个mini-batch的标签
        • **step_args:可变参数字典,即EpochTaskstep_args参数
      • 返回值为None或字典
        • key:指标名称
        • value:deepepochs.PatchBase子类对象,可用的Patch有(示例18
          • ValuePatch: 根据每个mini-batch指标均值(提前计算好)和batch_size,累积计算Epoch指标均值
          • TensorPatch: 保存每个mini-batch的(preds, targets),Epoch指标利用所有mini-batch的(preds, targets)数据重新计算
          • MeanPatch: 保存每个batch指标均值,Epoch指标值利用每个mini-batch的均值计算
            • 一般MeanPatchTensorPatch结果相同,但占用存储空间更小、运算速度更快
            • 不可用于计算'precision', 'recall', 'f1', 'fbeta'等指标
          • ConfusionPatch:用于计算基于混淆矩阵的指标,包括'accuracy', 'precision', 'recall', 'f1', 'fbeta'等
        • 也可以继承PatchBase定义新的Patch,需要实现如下方法 (示例19)
          • PatchBase.add
            • 用于将两个Patch对象相加得到更大的Patch对象
          • PatchBase.forward
            • 用于计算指标,返回指标值或字典
    • 第2步:调用定制Trainer训练模型。
  • 方法(示例16、17
    • 第1步:继承deepepochs.EpochTask类,在其中定义steptrain_stepval_steptest_stepevaluate_step方法
      • 它们的定义方式与Trainer中的*step方法相同
      • step方法优先级最高,即可用于训练也可用于验证和测试(定义了step方法,其他方法就会失效)
      • val_steptest_step优先级高于evaluate_step方法
      • EpochTask中的*step方法优先级高于Trainer中的*step方法
      • EpochTask__ini__方法的**step_args会被注入*step方法的step_args 参数
    • 第2步:使用新的EpochTask任务训练
      • EpochTask对象作为Trainer.fittrain_tasksval_tasks的参数值,或者Trainer.test方法中tasks的参数值

数据流图

https://github.com/hitlic/deepepochs/blob/main/imgs/data_flow.png

Project details


Release history Release notifications | RSS feed

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

deepepochs-0.6.5.tar.gz (37.7 kB view details)

Uploaded Source

Built Distribution

deepepochs-0.6.5-py3-none-any.whl (42.7 kB view details)

Uploaded Python 3

File details

Details for the file deepepochs-0.6.5.tar.gz.

File metadata

  • Download URL: deepepochs-0.6.5.tar.gz
  • Upload date:
  • Size: 37.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.13

File hashes

Hashes for deepepochs-0.6.5.tar.gz
Algorithm Hash digest
SHA256 1fa3c591d0ffe8af39d2146f0887f46d19b6453563e9992c9dd0c926e454216a
MD5 b69c15e631d2d27e28e8d05f3207dede
BLAKE2b-256 af11f4eb502879660061bc9b78afa3bcab11fc4e7d5e3e4e8dcbbe87c858f35a

See more details on using hashes here.

File details

Details for the file deepepochs-0.6.5-py3-none-any.whl.

File metadata

  • Download URL: deepepochs-0.6.5-py3-none-any.whl
  • Upload date:
  • Size: 42.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.13

File hashes

Hashes for deepepochs-0.6.5-py3-none-any.whl
Algorithm Hash digest
SHA256 8f78366f71ab7a887676ef4b12f5621e709c7d4bfd93a34b6093b0cce26e6ab3
MD5 a8cd1993d9a3cc405b9b570118992d4c
BLAKE2b-256 1f77402894e17040d2e82a88643c8a4607581303b5394520d28a41cc0ab37b7a

See more details on using hashes here.

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page