
An easy-to-use tool for training PyTorch deep learning models


DeepEpochs

A training tool for PyTorch deep learning models.

Installation

pip install deepepochs

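Usage

Data requirements

  • The training, validation, and test sets are torch.utils.data.DataLoader objects
  • Each mini-batch produced by a DataLoader is a tuple or list whose last item is the label
    • If the data has no labels, set the last item to None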
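
The batch layout described above can be sketched with plain PyTorch. This is only an illustration under stated assumptions: the tensors are random placeholder data, and `collate_unlabeled` is a hypothetical helper (not part of deepepochs) that appends None as the last item for unlabeled data.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Placeholder data: 8 samples with 4 features each, 3 classes.
features = torch.randn(8, 4)
labels = torch.randint(0, 3, (8,))

# Labeled data: each mini-batch is a list [inputs, labels], label last.
labeled_dl = DataLoader(TensorDataset(features, labels), batch_size=4)
batch = next(iter(labeled_dl))

def collate_unlabeled(samples):
    """Hypothetical collate function: stack the inputs and put None last."""
    inputs = torch.stack([sample[0] for sample in samples])
    return inputs, None

# Unlabeled data: each mini-batch is a tuple (inputs, None).
unlabeled_dl = DataLoader(TensorDataset(features), batch_size=4,
                          collate_fn=collate_unlabeled)
unlabeled_batch = next(iter(unlabeled_dl))
```

A custom collate function is used for the unlabeled case because PyTorch's default collation cannot batch None values.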

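Metric computation

  • Each metric is a function
    • It takes two arguments: the model's predictions and the labels
    • It returns the metric value on the current mini-batch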
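
A metric with this shape can be written in a few lines of plain PyTorch. The sketch below assumes the predictions are class logits of shape (batch, classes) and the labels are class indices; the function name `accuracy` and the sample tensors are illustrative only.

```python
import torch

def accuracy(preds, targets):
    """Fraction of samples in the mini-batch whose top-scoring class matches the label."""
    return (preds.argmax(dim=1) == targets).float().mean()

# Four samples, two classes; the third sample is misclassified.
logits = torch.tensor([[2.0, 0.1], [0.2, 1.5], [3.0, 0.0], [0.1, 0.2]])
targets = torch.tensor([0, 1, 1, 1])
value = accuracy(logits, targets)  # tensor(0.7500)
```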

Standard training workflow

from deepepochs import Trainer, Checker, rename
import torch
from torch import nn
from torch.nn import functional as F
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, random_split
from torchmetrics import functional as MF

# datasets
data_dir = './dataset'
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
mnist_full = MNIST(data_dir, train=True, transform=transform, download=True)
# keep 5,000 samples each for training and validation; the remaining 50,000 are discarded
train_ds, val_ds, _ = random_split(mnist_full, [5000, 5000, 50000])
test_ds = MNIST(data_dir, train=False, transform=transform, download=True)

# dataloaders
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True)  # reshuffle training data every epoch
val_dl = DataLoader(val_ds, batch_size=32)
test_dl = DataLoader(test_ds, batch_size=32)

# pytorch model
channels, width, height = (1, 28, 28)
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(channels * width * height, 64),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(64, 64),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(64, 10)
)

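# a metric takes (predictions, labels) and returns its value on the current mini-batch
def acc(preds, targets):
    return MF.accuracy(preds, targets, task='multiclass', num_classes=10)

# a metric function may also return a dict holding several values
@rename('')
def multi_metrics(preds, targets):
    return {
        'p': MF.precision(preds, targets, task='multiclass', num_classes=10),
        'r': MF.recall(preds, targets, task='multiclass', num_classes=10)
    }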


checker = Checker('loss', mode='min', patience=2)
opt = torch.optim.Adam(model.parameters(), lr=2e-4)
trainer = Trainer(model, F.cross_entropy, opt=opt, epochs=100, checker=checker, metrics=[acc, multi_metrics])

# train on train_dl while validating on val_dl, then evaluate on the test set
progress = trainer.fit(train_dl, val_dl)
test_rst = trainer.test(test_dl)
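
The README does not describe what Checker does internally. Assuming that Checker('loss', mode='min', patience=2) behaves like a conventional patience-based monitor (stop once the monitored value has failed to improve for `patience` consecutive epochs), the idea can be sketched as follows. `PatienceMonitor` is a hypothetical helper written for this illustration and is not the deepepochs implementation; the loss values are made up.

```python
class PatienceMonitor:
    """Hypothetical patience-based monitor (NOT deepepochs.Checker)."""

    def __init__(self, mode='min', patience=2):
        assert mode in ('min', 'max')
        self.mode = mode
        self.patience = patience
        self.best = None        # best value seen so far
        self.bad_epochs = 0     # consecutive epochs without improvement

    def should_stop(self, value):
        if self.best is None:
            improved = True
        elif self.mode == 'min':
            improved = value < self.best
        else:
            improved = value > self.best
        if improved:
            self.best, self.bad_epochs = value, 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

monitor = PatienceMonitor(mode='min', patience=2)
val_losses = [0.9, 0.7, 0.72, 0.71, 0.65]  # illustrative values only
stopped_at = None
for epoch, loss in enumerate(val_losses):
    if monitor.should_stop(loss):
        stopped_at = epoch
        break
```

With these values the loop stops at epoch 3, because epochs 2 and 3 both fail to beat the best loss of 0.7.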

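Custom training workflows

  • Approach 1:
    • Step 1: Subclass deepepochs.TrainerBase to build a Trainer that fits your needs, implementing the train_step and evaluate_step methods
    • Step 2: Train the model with the custom Trainer.
  • Approach 2:
    • Step 1: Subclass deepepochs.Callback to build a Callback that fits your needs
    • Step 2: Train the model with deepepochs.Learner, passing the custom Callback as an argument to Learner
    • Note: Learner is a Trainer with Callback support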
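
Both approaches follow common patterns that can be illustrated without the library. The sketch below is plain PyTorch: `SketchTrainerBase`, `RegressionTrainer`, `LossRecorder`, and the `on_epoch_end` hook are hypothetical stand-ins invented for this example. They are not the deepepochs classes, and the actual signatures of train_step, evaluate_step, and the Callback hooks should be checked against the deepepochs source.

```python
import torch
from torch import nn

class SketchTrainerBase:
    """Hypothetical stand-in for a trainer base class (NOT deepepochs.TrainerBase)."""

    def __init__(self, model, opt, callbacks=()):
        self.model, self.opt, self.callbacks = model, opt, list(callbacks)

    def train_step(self, batch):      # subclasses define one training step
        raise NotImplementedError

    def evaluate_step(self, batch):   # subclasses define one evaluation step
        raise NotImplementedError

    def fit(self, batches, epochs=1):
        for epoch in range(epochs):
            self.model.train()
            losses = [self.train_step(batch) for batch in batches]
            mean_loss = sum(losses) / len(losses)
            for callback in self.callbacks:   # approach 2: notify callbacks
                callback.on_epoch_end(epoch, mean_loss)

class RegressionTrainer(SketchTrainerBase):
    """Approach 1: customize behavior by overriding the step methods."""

    def train_step(self, batch):
        *inputs, targets = batch      # the label is the last item of the batch
        loss = nn.functional.mse_loss(self.model(*inputs), targets)
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()
        return loss.item()

    @torch.no_grad()
    def evaluate_step(self, batch):
        *inputs, targets = batch
        return nn.functional.mse_loss(self.model(*inputs), targets).item()

class LossRecorder:
    """Hypothetical callback: records the mean training loss of each epoch."""

    def __init__(self):
        self.history = []

    def on_epoch_end(self, epoch, mean_loss):
        self.history.append(mean_loss)

torch.manual_seed(0)
model = nn.Linear(3, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(16, 3)
y = x.sum(dim=1, keepdim=True)        # exactly linear target
batches = [(x[:8], y[:8]), (x[8:], y[8:])]

recorder = LossRecorder()
trainer = RegressionTrainer(model, opt, callbacks=[recorder])
trainer.fit(batches, epochs=20)
```

Overriding step methods suits cases where the forward and backward logic itself changes; callbacks suit side behavior such as logging that should not touch the training step.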


Download files


Source Distribution

deepepochs-0.2.6.tar.gz (18.2 kB)

Built Distribution

deepepochs-0.2.6-py3-none-any.whl (21.3 kB, Python 3)

