
An easy-to-use tool for training PyTorch deep learning models

Project description

DeepEpochs

A tool for training PyTorch deep learning models.

Installation

pip install deepepochs

Usage

Data requirements

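  • The training, validation and test sets are torch.utils.data.DataLoader objects
  • Each mini-batch built by the DataLoader (the return value of collate_fn) is a tuple or list whose last element is the label
    • If training does not need labels, set the last element to None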
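These requirements can be met with plain PyTorch. The sketch below is an illustration only: the toy tensors and the `collate_with_labels` / `collate_without_labels` helpers are hypothetical, and nothing here calls DeepEpochs.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy data: 10 samples with 3 features and binary class labels
features = torch.randn(10, 3)
labels = torch.randint(0, 2, (10,))
dataset = TensorDataset(features, labels)


def collate_with_labels(samples):
    # samples is a list of (x, y) pairs; stack them into one batch
    xs, ys = zip(*samples)
    # The label must be the LAST element of the returned tuple
    return torch.stack(xs), torch.stack(ys)


def collate_without_labels(samples):
    # When training needs no labels, keep the last slot but set it to None
    xs = [x for x, _ in samples]
    return torch.stack(xs), None


train_dl = DataLoader(dataset, batch_size=4, collate_fn=collate_with_labels)
unlabeled_dl = DataLoader(dataset, batch_size=4, collate_fn=collate_without_labels)

batch_x, batch_y = next(iter(train_dl))
print(batch_x.shape, batch_y.shape)
```

With PyTorch's default collate_fn, a TensorDataset of (x, y) pairs already yields a two-element [inputs, labels] batch, so a custom collate_fn is mainly needed when the label slot has to be filled with None.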

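Metric computation

  • Each metric is a function
    • It takes two arguments: the model predictions and the data labels
    • It returns the metric value on the current mini-batch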
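For instance, a metric that follows these rules might look like the sketch below. `error_rate` is a hypothetical function written with plain PyTorch; it assumes the model outputs class logits of shape (batch, num_classes) and that the labels are integer class indices.

```python
import torch


def error_rate(preds, targets):
    # preds: model output logits, shape (batch, num_classes)
    # targets: integer class labels, shape (batch,)
    predicted_classes = preds.argmax(dim=1)
    # Fraction of samples in this mini-batch that are misclassified
    return (predicted_classes != targets).float().mean()


preds = torch.tensor([[2.0, 1.0], [0.0, 3.0], [1.0, 0.5], [0.2, 0.1]])
targets = torch.tensor([0, 1, 1, 0])
print(error_rate(preds, targets))  # only the third sample is wrong
```

Such a function would be passed in a list, the same way `acc` is passed via `metrics=[acc]` in the example that follows.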

Complete example

from deepepochs import Trainer, CheckCallback, rename
import torch
from torch import nn
from torch.nn import functional as F
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, random_split
from torchmetrics import functional as MF

# datasets
data_dir = './datasets'
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
mnist_full = MNIST(data_dir, train=True, transform=transform, download=True)
train_ds, val_ds, _ = random_split(mnist_full, [5000, 5000, 50000])
test_ds = MNIST(data_dir, train=False, transform=transform, download=True)

# dataloaders
train_dl = DataLoader(train_ds, batch_size=32)
val_dl = DataLoader(val_ds, batch_size=32)
test_dl = DataLoader(test_ds, batch_size=32)

# pytorch model
channels, width, height = (1, 28, 28)
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(channels * width * height, 64),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(64, 64),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(64, 10)
)

def acc(preds, targets):
    return MF.accuracy(preds, targets, task='multiclass', num_classes=10)

@rename('')
def multi_metrics(preds, targets):
    return {
        'p': MF.precision(preds, targets, task='multiclass', num_classes=10),
        'r': MF.recall(preds, targets, task='multiclass', num_classes=10)
        }

# Monitor the loss in the validation stage (mode='min': lower is better), with patience=2
checker = CheckCallback('loss', on_stage='val', mode='min', patience=2)
opt = torch.optim.Adam(model.parameters(), lr=2e-4)

trainer = Trainer(model, F.cross_entropy, opt=opt, epochs=5, callbacks=checker, metrics=[acc])

progress = trainer.fit(train_dl, val_dl, metrics=[multi_metrics])
test_rst = trainer.test(test_dl)
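Before calling `trainer.fit`, it can help to confirm that the model produces one logit per class, since both `F.cross_entropy` and the metrics above (`num_classes=10`) rely on that. The check below uses only PyTorch, not DeepEpochs, and random tensors stand in for a real MNIST batch.

```python
import torch
from torch import nn
from torch.nn import functional as F

# Same architecture as in the example above
channels, width, height = (1, 28, 28)
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(channels * width * height, 64),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(64, 64),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(64, 10),
)

# Random stand-in for one MNIST mini-batch: (batch, channels, height, width)
batch_x = torch.randn(32, channels, height, width)
batch_y = torch.randint(0, 10, (32,))  # integer class labels in [0, 10)

model.eval()  # disable dropout for this check
with torch.no_grad():
    logits = model(batch_x)
    loss = F.cross_entropy(logits, batch_y)

print(logits.shape)  # torch.Size([32, 10])
```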

Examples

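| No. | Description | Code |
|---|---|---|
| 1 | Basic usage | examples/1-basic.py |
| 2 | Common parameters of the trainer, the fit method and the test method | examples/2-basic-params.py |
| 3 | Using model performance metrics | examples/3-metrics.py |
| 4 | Checkpoint and EarlyStop | examples/4-checkpoint-earlystop.py |
| 5 | Finding a suitable learning rate | examples/5-lr-find.py |
| 6 | Recording the training process with TensorBoard | examples/6-logger.py |
| 7 | Logging and visualizing hyperparameters with TensorBoard | examples/7-log-hyperparameters.py |
| 8 | Analyzing and interpreting model predictions | examples/8-interprete.py |
| 9 | Learning rate scheduling | examples/9-lr-schedule.py |
| 10 | Using multiple optimizers | examples/10-multi-optimizers.py |
| 11 | Using multiple DataLoaders in training, validation and testing | examples/11-multi-dataloaders.py |
| 12 | Node classification with a graph neural network | examples/12-node-classification.py |
| 13 | Visualizing model forward outputs and gradients | examples/13-weight-grad-visualize.py |
| 14 | Custom Callback | examples/14-costomize-callback.py |
| 15 | Customizing train_step and evaluate_step via TrainerBase | examples/15-customize-steps-1.py |
| 16 | Customizing train_step, eval_step and test_step via EpochTask | examples/16-customize-steps-2.py |
| 17 | Customizing *step via EpochTask | examples/17-costomize-steps-3.py |
| 18 | Using the built-in Patches | examples/18-use_patches.py |
| 19 | Custom Patch | examples/19-customize-patch.py |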

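Customizing the training process

  • Method 1:
    • Step 1: Subclass deepepochs.Callback to create a Callback that fits your needs
    • Step 2: Train the model with deepepochs.Trainer, passing the custom Callback object as the callbacks argument of Trainer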
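  • Method 2:
    • Step 1: Subclass deepepochs.TrainerBase to create a custom Trainer that implements the step, train_step, val_step, test_step and evaluate_step methods
      • Parameters of these methods:
        • batch_x: the model input of one mini-batch
        • batch_y: the labels of one mini-batch
        • **step_args: a dict of keyword arguments such as do_loss and metrics
      • These methods return either a tuple or a dict:
        • Tuple: (loss, model_out)
          • the loss
          • the model's predicted output
        • Dict: {'loss': loss_value, 'model_out': model_out}
    • Step 2: Train the model with the custom Trainer.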
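  • Method 3:
    • Step 1: Subclass deepepochs.EpochTask and define the step, train_step, val_step, test_step and evaluate_step methods in it
      • They are defined the same way as the *step methods of Trainer
      • step has the highest priority and serves training, validation and testing alike (once step is defined, the other methods are ignored)
      • val_step and test_step take priority over evaluate_step
      • The *_step methods of EpochTask take priority over the *_step methods of Trainer
    • Step 2: Train with the new EpochTask
      • Pass the EpochTask object as the value of the train_tasks or val_tasks parameter of Trainer.fit, or of the tasks parameter of Trainer.test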
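The sketch below models two of the conventions just described: which *step method handles a given stage, and how the tuple and dict return forms carry the same information. It is hypothetical and is not DeepEpochs source code. `resolve_step`, `normalize_step_output` and `STAGE_METHODS` are illustrative names, methods are represented as sets of names, and two points the text leaves open are assumptions: the EpochTask is searched in full before falling back to the Trainer, and evaluate_step is only a fallback for validation and testing.

```python
# Hypothetical model of the *step conventions; NOT DeepEpochs source code.

# Candidate method names per stage, highest priority first.
# Assumption: evaluate_step only serves validation and testing.
STAGE_METHODS = {
    'train': ['step', 'train_step'],
    'val': ['step', 'val_step', 'evaluate_step'],
    'test': ['step', 'test_step', 'evaluate_step'],
}


def resolve_step(stage, task_methods, trainer_methods):
    """Return (owner, method_name) of the method that would handle `stage`.

    task_methods / trainer_methods: sets of method names defined on the
    EpochTask and on the Trainer. Assumption: the EpochTask is searched
    in full before falling back to the Trainer.
    """
    for owner, defined in (('EpochTask', task_methods), ('Trainer', trainer_methods)):
        for name in STAGE_METHODS[stage]:
            if name in defined:
                return owner, name
    raise LookupError(f"no step method defined for stage {stage!r}")


def normalize_step_output(output):
    """Turn either documented return form into {'loss': ..., 'model_out': ...}."""
    if isinstance(output, dict):
        # Dict form: {'loss': loss_value, 'model_out': model_out}
        return {'loss': output['loss'], 'model_out': output['model_out']}
    if isinstance(output, tuple) and len(output) == 2:
        # Tuple form: (loss, model_out)
        loss, model_out = output
        return {'loss': loss, 'model_out': model_out}
    raise TypeError("expected (loss, model_out) or {'loss': ..., 'model_out': ...}")


print(resolve_step('val', {'val_step', 'evaluate_step'}, {'step'}))
```

Under the assumed order, the call above returns ('EpochTask', 'val_step'): the task's val_step outranks its evaluate_step and is found before the Trainer's step is ever considered.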

Data flow diagram

https://github.com/hitlic/deepepochs/blob/main/imgs/data_flow.png


Download files


Source Distribution

deepepochs-0.4.13.tar.gz (29.7 kB)

Uploaded Source

Built Distribution

deepepochs-0.4.13-py3-none-any.whl (32.8 kB)

Uploaded Python 3

File details

Details for the file deepepochs-0.4.13.tar.gz.

File metadata

  • Download URL: deepepochs-0.4.13.tar.gz
  • Upload date:
  • Size: 29.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.13

File hashes

Hashes for deepepochs-0.4.13.tar.gz
Algorithm Hash digest
SHA256 1a39631b2e38bbf5024f21fc15267353a0df7d199f2aaa734341bbf0050bc04f
MD5 609e2a4949f1242d7e7683edab304385
BLAKE2b-256 8416db480cf6959cd42db09420c2cca26ff8fb692829d6c10ecfb0aaec1751c8


File details

Details for the file deepepochs-0.4.13-py3-none-any.whl.

File metadata

  • Download URL: deepepochs-0.4.13-py3-none-any.whl
  • Upload date:
  • Size: 32.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.13

File hashes

Hashes for deepepochs-0.4.13-py3-none-any.whl
Algorithm Hash digest
SHA256 e539c734cb91ed3a8023a81ef3a91c9425501dcca576612fac04723c71999bf5
MD5 1a7412e7819c403c200566808e6935a5
BLAKE2b-256 cbaf3559f942fddab161bd02186d6ee01c4e3cd9681fcbdb12af48ecceff0939

