
An easy-to-use tool for training PyTorch deep learning models


DeepEpochs

A tool for training PyTorch deep learning models.

Installation

pip install deepepochs

Usage

Data requirements

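  • The training, validation, and test sets are torch.utils.data.DataLoader objects
  • Each mini-batch yielded by a DataLoader is a tuple or list whose last element is the label
    • If the data has no labels, set the last element to None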
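As a minimal sketch of the batch format described above (plain PyTorch only; the random tensors and the `collate_without_labels` helper are hypothetical illustrations, not part of DeepEpochs), a labeled loader yields `(x, y)`, while an unlabeled loader can be made to yield `(x, None)` with a custom `collate_fn`, because PyTorch's default collate function does not batch `None` values.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Labeled data: each batch is (x, y), with the label as the last element.
x = torch.randn(100, 8)
y = torch.randint(0, 3, (100,))
labeled_dl = DataLoader(TensorDataset(x, y), batch_size=32)

# Unlabeled data: TensorDataset(x) yields 1-tuples (x_i,).
# This hypothetical collate_fn stacks the inputs and appends None as the
# last element, so every batch still has the (inputs, label) layout.
def collate_without_labels(samples):
    xs = torch.stack([s[0] for s in samples])
    return xs, None

unlabeled_dl = DataLoader(TensorDataset(x), batch_size=32,
                          collate_fn=collate_without_labels)

batch_x, batch_y = next(iter(labeled_dl))
print(batch_x.shape, batch_y.shape)  # torch.Size([32, 8]) torch.Size([32])

ux, uy = next(iter(unlabeled_dl))
print(ux.shape, uy)  # torch.Size([32, 8]) None
```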

Metrics

  • Each metric is a function
    • It takes two arguments: the model's predictions and the labels
    • It returns the metric value for the current mini-batch
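A metric that follows this contract can be written in plain PyTorch. The sketch below assumes the predictions are raw logits of shape (N, C) and the labels are class indices; the name `accuracy` and the sample tensors are illustrative only.

```python
import torch

def accuracy(preds, targets):
    """(predictions, labels) -> metric value for the current mini-batch.
    Assumes preds are logits of shape (N, C) and targets are class indices."""
    return (preds.argmax(dim=1) == targets).float().mean()

logits = torch.tensor([[2.0, 0.1, 0.3],
                       [0.2, 1.5, 0.1],
                       [0.9, 0.3, 0.1],
                       [0.1, 0.2, 3.0]])
labels = torch.tensor([0, 1, 2, 2])
print(accuracy(logits, labels))  # tensor(0.7500): 3 of 4 predictions match
```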

Example: a standard training workflow

from deepepochs import Trainer, CheckCallback, EpochTask, rename
import torch
from torch import nn
from torch.nn import functional as F
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, random_split
from torchmetrics import functional as MF

# datasets
data_dir = './dataset'
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
mnist_full = MNIST(data_dir, train=True, transform=transform, download=True)
train_ds, val_ds, _ = random_split(mnist_full, [5000, 5000, 50000])
test_ds = MNIST(data_dir, train=False, transform=transform, download=True)

# dataloaders
train_dl = DataLoader(train_ds, batch_size=32)
val_dl = DataLoader(val_ds, batch_size=32)
test_dl = DataLoader(test_ds, batch_size=32)

# pytorch model
channels, width, height = (1, 28, 28)
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(channels * width * height, 64),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(64, 64),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(64, 10)
)

def acc(preds, targets):
    return MF.accuracy(preds, targets, task='multiclass', num_classes=10)

@rename('')
def multi_metrics(preds, targets):
    return {
        'p': MF.precision(preds, targets, task='multiclass', num_classes=10),
        'r': MF.recall(preds, targets, task='multiclass', num_classes=10)
        }


checker = CheckCallback('loss', on_stage='train', mode='min', patience=2)
opt = torch.optim.Adam(model.parameters(), lr=2e-4)

trainer = Trainer(model, F.cross_entropy, opt=opt, epochs=100, callbacks=checker, metrics=[acc])

# Usage example 1: pass the DataLoaders directly
progress = trainer.fit(train_dl, val_dl, metrics=[multi_metrics])
test_rst = trainer.test(test_dl)

# Usage example 2: wrap the DataLoaders in EpochTask objects
# t1 = EpochTask(train_dl, metrics=[acc])
# t2 = EpochTask(val_dl, metrics=[multi_metrics], do_loss=True)
# progress = trainer.fit(train_tasks=t1, val_tasks=t2)
# test_rst = trainer.test(tasks=t2)

# Usage example 3: train on train_dl and evaluate with several tasks
# t1 = EpochTask(train_dl, metrics=[acc])
# t2 = EpochTask(val_dl, metrics=[acc, multi_metrics], do_loss=True)
# progress = trainer.fit(train_dl, val_tasks=[t1, t2])
# test_rst = trainer.test(tasks=[t1, t2])

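Non-standard training workflows

  • Method 1:
    • Step 1: Subclass deepepochs.Callback to create a callback that fits your needs
    • Step 2: Train the model with deepepochs.Trainer, passing the custom callback object as the callbacks argument of Trainer
  • Method 2:
    • Step 1: Subclass deepepochs.TrainerBase to create a custom Trainer that implements the train_step and evaluate_step methods
      • The return value is a dict: keys are metric names, and values are objects of DeepEpochs.PatchBase subclasses. The available patches are:
        • ValuePatch: accumulates the epoch-level metric mean from each batch's (precomputed) metric mean and its batch_size
        • TensorPatch: stores each batch's model predictions and labels, and computes the epoch-level metric mean with the specified metric function
        • MeanPatch: stores each batch's metric mean, and computes the epoch-level metric mean with the specified metric function
        • ConfusionPatch: accumulates metrics based on the confusion matrix
    • Step 2: Train the model with the custom Trainer.
  • Method 3:
    • Step 1: Subclass deepepochs.EpochTask and define a step, train_step, val_step, test_step, or evaluate_step method in it
      • The parameters are: batch_x, batch_y, metrics, **kwargs
      • The return value is a dict: keys are metric names, and values are objects of DeepEpochs.PatchBase subclasses
    • Step 2: Train with the new EpochTask.
      • Pass the EpochTask object as the train_tasks or val_tasks argument of Trainer.fit, or as the tasks argument of Trainer.test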
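The patch types listed under Method 2 differ mainly in what they keep per batch and how they reduce it at the end of an epoch. The sketch below imitates two of the described strategies in plain PyTorch; `ValueAccumulator` and `TensorAccumulator` are hypothetical names and are not DeepEpochs' implementation or API. It shows why weighting precomputed batch means by batch size (the ValuePatch description) and recomputing the metric over all stored outputs (the TensorPatch description) both recover the true epoch accuracy, while a plain average of batch means does not when batch sizes differ.

```python
import torch

class ValueAccumulator:
    """Hypothetical stand-in for the ValuePatch idea: combine precomputed
    per-batch means, weighted by batch size."""
    def __init__(self):
        self.total = 0.0
        self.count = 0

    def add(self, batch_mean, batch_size):
        self.total += float(batch_mean) * batch_size
        self.count += batch_size

    def value(self):
        return self.total / self.count

class TensorAccumulator:
    """Hypothetical stand-in for the TensorPatch idea: keep every batch's
    predictions and labels, then apply the metric once per epoch."""
    def __init__(self, metric):
        self.metric = metric
        self.preds, self.targets = [], []

    def add(self, preds, targets):
        self.preds.append(preds.detach())
        self.targets.append(targets.detach())

    def value(self):
        return float(self.metric(torch.cat(self.preds), torch.cat(self.targets)))

def accuracy(preds, targets):
    return (preds.argmax(dim=1) == targets).float().mean()

# Two batches of different sizes: 4 samples (3 correct), then 2 samples (0 correct).
batches = [
    (torch.tensor([[2., 0.], [0., 2.], [2., 0.], [0., 2.]]), torch.tensor([0, 1, 0, 0])),
    (torch.tensor([[2., 0.], [0., 2.]]), torch.tensor([1, 0])),
]

value_acc = ValueAccumulator()
tensor_acc = TensorAccumulator(accuracy)
batch_means = []
for preds, targets in batches:
    m = accuracy(preds, targets)
    batch_means.append(float(m))
    value_acc.add(m, len(targets))
    tensor_acc.add(preds, targets)

print(value_acc.value())                    # 0.5
print(tensor_acc.value())                   # 0.5
print(sum(batch_means) / len(batch_means))  # 0.375, unweighted mean of batch means
```

The unweighted average (0.375) overstates the influence of the smaller second batch; the true epoch accuracy is 3 correct out of 6 samples, or 0.5. Storing all outputs costs memory but also works for metrics that are not simple means, which is presumably why both strategies exist.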

Data flow diagram

https://github.com/hitlic/deepepochs/blob/main/imgs/data_flow.png


Download files


Source Distribution

deepepochs-0.3.1.tar.gz (112.1 kB)

Built Distribution

deepepochs-0.3.1-py3-none-any.whl (19.9 kB)

File details

Details for the file deepepochs-0.3.1.tar.gz.

File metadata

  • Download URL: deepepochs-0.3.1.tar.gz
  • Upload date:
  • Size: 112.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.6.0 importlib_metadata/4.8.2 pkginfo/1.8.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.8.12

File hashes

Hashes for deepepochs-0.3.1.tar.gz
  • SHA256: 7600bc96e1176c46d96212f0ba6aad894a9cce6da3b5b1640e2dcfc3771ccde9
  • MD5: da8d7564811b65c841633264f10db09c
  • BLAKE2b-256: da1edefd9255d8054cc867f4af0571b77434af9343eaf19f1539a5a377a36dba


File details

Details for the file deepepochs-0.3.1-py3-none-any.whl.

File metadata

  • Download URL: deepepochs-0.3.1-py3-none-any.whl
  • Upload date:
  • Size: 19.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.6.0 importlib_metadata/4.8.2 pkginfo/1.8.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.8.12

File hashes

Hashes for deepepochs-0.3.1-py3-none-any.whl
  • SHA256: 26f5a7c8b67b208ad27ed09c0526bde7c7365e3a3768a6868a8c469081b0ea35
  • MD5: d805409781b1869f0a0b7fa6af2ad5e3
  • BLAKE2b-256: 9171e6e67393fae2c153ad2af73c8ad7b63362be87f6f06bb467a5c595cb5037

