deepepochs

A simple training tool for PyTorch models

Usage

Standard training workflow

from deepepochs import Trainer, Checker, rename
import torch
from torch import nn
from torch.nn import functional as F
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, random_split
from torchmetrics import functional as MF


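# datasets: of the 60,000 MNIST training images, use 5,000 for training
# and 5,000 for validation; the remaining 50,000 are discarded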
data_dir = './dataset'
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
mnist_full = MNIST(data_dir, train=True, transform=transform, download=True)
train_ds, val_ds, _ = random_split(mnist_full, [5000, 5000, 50000])
test_ds = MNIST(data_dir, train=False, transform=transform, download=True)

# dataloaders
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True)  # reshuffle training batches every epoch
val_dl = DataLoader(val_ds, batch_size=32)
test_dl = DataLoader(test_ds, batch_size=32)

# pytorch model
channels, width, height = (1, 28, 28)
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(channels * width * height, 64),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(64, 64),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(64, 10)
)

def acc(preds, targets):
    return MF.accuracy(preds, targets, task='multiclass', num_classes=10)

@rename('m')
def multi_metrics(preds, targets):
    r = MF.recall(preds, targets, task='multiclass', num_classes=10)
    f1 = MF.f1_score(preds, targets, task='multiclass', num_classes=10)
    return {'@r': r, '@f1': f1}


# monitor the loss; lower is better, so use mode='min'
checker = Checker('loss', mode='min', patience=2)
opt = torch.optim.Adam(model.parameters(), lr=2e-4)
trainer = Trainer(model, F.cross_entropy, opt=opt, epochs=100, checker=checker, metrics=[acc, multi_metrics])

progress = trainer.fit(train_dl, val_dl)
test_rst = trainer.test(test_dl)
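
The `mode` and `patience` arguments passed to `Checker` suggest conventional patience-based early stopping. The sketch below illustrates that general technique in plain Python. It is not the deepepochs implementation: `PatienceChecker`, `should_stop`, and `epochs_run` are hypothetical names, and the exact stopping rule used by `Checker` may differ.

```python
class PatienceChecker:
    """Hypothetical stand-in for an early-stopping checker (not deepepochs' Checker)."""

    def __init__(self, mode="min", patience=2):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.mode = mode
        self.patience = patience
        self.best = None       # best monitored value seen so far
        self.bad_epochs = 0    # consecutive epochs without improvement

    def improved(self, value):
        if self.best is None:
            return True
        return value < self.best if self.mode == "min" else value > self.best

    def should_stop(self, value):
        """Record one epoch's monitored value; return True when training should stop."""
        if self.improved(value):
            self.best = value
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


def epochs_run(values, checker):
    """Return how many epochs run before the checker stops training."""
    for epoch, value in enumerate(values, start=1):
        if checker.should_stop(value):
            return epoch
    return len(values)


# Made-up validation losses, for illustration only.
val_losses = [0.90, 0.70, 0.65, 0.66, 0.68, 0.60]
stopped_min = epochs_run(val_losses, PatienceChecker(mode="min", patience=2))
stopped_max = epochs_run(val_losses, PatienceChecker(mode="max", patience=2))
print(stopped_min, stopped_max)
```

Under this rule, `mode="min"` runs until the loss has failed to improve for two epochs in a row, while `mode="max"` treats every decrease in loss as a failure and stops almost immediately. For a loss-like quantity, `'min'` is therefore the usual setting; `'max'` suits metrics such as accuracy.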

Custom training workflow

  • Step 1: Subclass deepepochs.TrainerBase to define a Trainer that fits your needs, implementing the train_step and evaluate_step methods.
  • Step 2: Train the model with the custom Trainer.
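
The two steps above follow the template-method pattern: a base class owns the epoch loop, and a subclass supplies the per-batch logic. The sketch below shows that pattern in plain Python without importing deepepochs. `SketchTrainerBase`, `LinearTrainer`, the `(batch_x, batch_y)` signatures, and the return values are all assumptions made for illustration; the actual `TrainerBase` interface should be checked against the deepepochs source.

```python
from abc import ABC, abstractmethod


class SketchTrainerBase(ABC):
    """Hypothetical stand-in for a trainer base class; NOT deepepochs.TrainerBase."""

    def __init__(self, epochs):
        self.epochs = epochs

    @abstractmethod
    def train_step(self, batch_x, batch_y):
        """Update the model on one batch and return the batch loss."""

    @abstractmethod
    def evaluate_step(self, batch_x, batch_y):
        """Return the batch loss without updating the model."""

    def fit(self, train_batches, val_batches):
        """Run the epoch loop, delegating per-batch work to the subclass."""
        history = []
        for _ in range(self.epochs):
            train_losses = [self.train_step(x, y) for x, y in train_batches]
            val_losses = [self.evaluate_step(x, y) for x, y in val_batches]
            history.append({
                "train_loss": sum(train_losses) / len(train_losses),
                "val_loss": sum(val_losses) / len(val_losses),
            })
        return history


class LinearTrainer(SketchTrainerBase):
    """Step 1: a custom trainer that fits y = w * x by gradient descent on MSE."""

    def __init__(self, epochs, lr=0.05):
        super().__init__(epochs)
        self.w = 0.0
        self.lr = lr

    def _loss(self, xs, ys):
        return sum((self.w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

    def train_step(self, batch_x, batch_y):
        loss = self._loss(batch_x, batch_y)
        # Gradient of the mean squared error with respect to w.
        grad = sum(2 * (self.w * x - y) * x for x, y in zip(batch_x, batch_y)) / len(batch_x)
        self.w -= self.lr * grad
        return loss

    def evaluate_step(self, batch_x, batch_y):
        return self._loss(batch_x, batch_y)


# Step 2: train with the custom trainer on toy data generated from y = 3x.
train_batches = [([1.0, 2.0], [3.0, 6.0]), ([3.0, 4.0], [9.0, 12.0])]
val_batches = [([5.0], [15.0])]
trainer = LinearTrainer(epochs=50)
history = trainer.fit(train_batches, val_batches)
print(trainer.w, history[-1]["val_loss"])
```

Keeping the loop in the base class means a custom trainer only has to define what happens to a single batch, which is presumably the motivation for exposing `train_step` and `evaluate_step` as the extension points.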
