
HrvvI's extension to PyTorch

Project description

Overview

pytorch-hrvvi-ext is my extension to PyTorch. It contains many "out of the box" tools that facilitate my everyday study, and they are easy to use and to integrate into your own projects. Below I refer to the package as hutil, because it is imported as import hutil.

Install

pip3 install -U --no-cache-dir --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple pytorch-hrvvi-ext

Highlights

Trainer

Trainer is built on ignite and provides the following features (a minimal usage sketch follows the list; full examples appear in the Examples section below):

  • Train your network in a few lines without writing training loops explicitly
  • Automatic GPU support, like Keras
  • Metrics for both CV and NLP (Loss, Accuracy, Top-K Accuracy, mAP, BLEU)
  • Checkpointing of the whole trainer, by epoch or by metric
  • Send metric history to WeChat
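
A minimal usage sketch, assuming the network, loss, optimizer, scheduler and data loaders are already defined (the calls mirror the CIFAR10 example below):

trainer = Trainer(net, criterion, optimizer, lr_scheduler,
                  metrics={'loss': Loss(), 'acc': Accuracy()},
                  save_path=gpath("models"), name="my-model")
# one call trains for 100 epochs, evaluates on val_loader and checkpoints by validation loss
trainer.fit(train_loader, 100, val_loader=val_loader, save_by_metric='-val_loss')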

Datasets

hutil contains many datasets that I have wrapped to provide a torchvision.datasets-style API. Some of them are much easier to train on than VOC or COCO, and they are more suitable for beginners in object detection (a short usage sketch follows the list). It currently contains the following datasets:

  • CaptchaDetectionOnline: generates captcha images and per-character bounding boxes on the fly
  • SVHNDetection: the SVHN dataset for object detection
  • CocoDetection: torchvision's not-yet-released dataset, with hutil's transforms
  • VOCDetection: torchvision's not-yet-released dataset, with hutil's transforms
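
These datasets return (image, target) pairs like torchvision datasets. The sketch below reuses the CaptchaDetectionOnline constructor from the full example further down; the hutil import is omitted because its module path is not shown here, and the target layout in the comment is my assumption:

from captcha.image import ImageCaptcha

image = ImageCaptcha(128, 48)  # captcha generator from the captcha package
ds = CaptchaDetectionOnline(image, size=1000, letters="0123456789", rotate=20)
img, target = ds[0]            # the captcha image plus its per-character boxes and labels (assumed layout)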

Transforms

Transforms in hutil transform the inputs and targets of datasets simultaneously, which is more flexible than torchvision.transforms and makes data augmentation for object detection easier, while keeping a torchvision.transforms-style API (a short sketch follows the list). The following transforms are provided now:

  • Resize
  • CenterCrop
  • ToPercentCoords
  • Compose
  • InputTransform
  • TargetTransform
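
A minimal sketch of the joint-transform idea; the call convention (the (image, target) pair is passed through together) and the dummy target layout are my assumptions based on the SSD example below:

from PIL import Image

img = Image.new("RGB", (128, 48))                  # dummy image
target = [{"bbox": [10, 12, 40, 36], "label": 3}]  # hypothetical target layout

transform = Compose([
    ToPercentCoords(),  # absolute pixel coordinates -> fractions of the image size
    ToTensor(),
])
img, target = transform(img, target)  # input and target are transformed together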

Others

  • train_test_split: split a dataset into a train set and a test set with different (or the same) transforms; see the sketch after this list
  • Fullset: wrap your dataset into a hutil-style dataset
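
A minimal sketch of train_test_split, mirroring the CIFAR10 example below (train_transforms and test_transform are placeholders for any hutil transforms):

ds = CIFAR10("datasets/CIFAR10", train=True, download=True)
ds_train, ds_val = train_test_split(
    ds, test_ratio=0.04,            # 4% of the samples go to the held-out split
    transform=train_transforms,     # applied to the train split
    test_transform=test_transform,  # applied to the held-out split
)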

Examples

CIFAR10
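
The example below omits its imports. A plausible set is sketched here: the torch and torchvision imports are standard (Compose may equally come from hutil), while the hutil names (gpath, InputTransform, train_test_split, init_weights, summary, Trainer, Loss, Accuracy, ResNet, WideSEBasicBlock) are left in a comment because their exact module paths inside hutil are not shown in this README.

# Standard PyTorch / torchvision imports used by the example
import torch.nn as nn
from torch.optim import SGD
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, RandomCrop, RandomHorizontalFlip, ToTensor, Normalize

# hutil imports (module paths are an assumption; adjust to the installed package)
# from hutil import gpath, InputTransform, train_test_split, init_weights, summary
# from hutil import Trainer, Loss, Accuracy, ResNet, WideSEBasicBlock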

# Data Preparation

train_transforms = InputTransform(
    Compose([
        RandomCrop(32, padding=4),
        RandomHorizontalFlip(),
        ToTensor(),
        Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
    ])
)

test_transform = InputTransform(
    Compose([
        ToTensor(),
        Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
    ])
)

data_home = gpath("datasets/CIFAR10")
ds = CIFAR10(data_home, train=True, download=True)
ds_train, ds_val = train_test_split(
    ds, test_ratio=0.04,
    transform=train_transforms,
    test_transform=test_transform,
)
ds_test = CIFAR10(data_home, train=False, download=True)


# Define network, loss and optimizer

net = ResNet(WideSEBasicBlock, [4,4,4], k=2)
net.apply(init_weights(nonlinearity='relu'))
criterion = nn.CrossEntropyLoss()
optimizer = SGD(net.parameters(), lr=1e-1, momentum=0.9, dampening=0, weight_decay=5e-4, nesterov=True)
lr_scheduler = MultiStepLR(optimizer, [40, 80, 110], gamma=0.2)


# Define metrics

metrics = {
    'loss': Loss(),
    'acc': Accuracy(),
}

# Put it together with Trainer

trainer = Trainer(net, criterion, optimizer, lr_scheduler, metrics=metrics, save_path=gpath("models"), name="CIFAR10-SE-WRN28-2")

# Show number of parameters

summary(net, (3,32,32))

# Define data loaders

train_loader = DataLoader(ds_train, batch_size=32, shuffle=True, num_workers=1, pin_memory=True)
test_loader = DataLoader(ds_test, batch_size=128)
val_loader = DataLoader(ds_val, batch_size=128)

# Train for 100 epochs and save the best models by validation loss (lower is better) after the first 40 epochs

trainer.fit(train_loader, 100, val_loader=val_loader, save_by_metric='-val_loss', patience=40)

CaptchaDetectionOnline
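
As above, imports are omitted; a plausible set is sketched here. The torch imports and ImageCaptcha (from the captcha package) are standard, while the hutil names used below (gpath, Compose, ToPercentCoords, ToTensor, SSDTransform, compute_scales, compute_default_boxes, CaptchaDetectionOnline, DSOD, SSDLoss, TrainLoss, MeanAveragePrecision, SSDInference, box_collate_fn, cuda, init_weights, summary, Trainer) are again left in a comment because their module paths are not shown here.

# Standard imports used by the example
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader
from captcha.image import ImageCaptcha  # captcha image generator

# hutil imports (module paths are an assumption; adjust to the installed package)
# from hutil import Compose, ToPercentCoords, ToTensor, SSDTransform, CaptchaDetectionOnline, ...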

letters = "0123456789abcdefghijkmnopqrstuvwxyzABDEFGHJKMNRT"
NUM_CLASSES = len(letters) + 1
WIDTH = 128
HEIGHT = 48
LOCATIONS = [
    (8, 3),
    (4, 2),
]
ASPECT_RATIOS = [
    (1, 2, 1/2),
    (1, 2, 1/2),
]
ASPECT_RATIOS = [torch.tensor(ars) for ars in ASPECT_RATIOS]
NUM_FEATURE_MAPS = len(ASPECT_RATIOS)
SCALES = compute_scales(NUM_FEATURE_MAPS, 0.2, 0.9)
DEFAULT_BOXES = [
    compute_default_boxes(lx, ly, scale, ars)
    for (lx, ly), scale, ars in zip(LOCATIONS, SCALES, ASPECT_RATIOS)
]
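
# Assumed behaviour of the helpers above (standard SSD convention): compute_scales spaces
# NUM_FEATURE_MAPS scales evenly between 0.2 and 0.9, and compute_default_boxes builds one
# default box per aspect ratio at every cell of an lx-by-ly feature map, so LOCATIONS are
# the spatial sizes of the two feature maps for a 128x48 input.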


# Define captcha dataset

fonts = [
    gpath("fonts/msyh.ttf"),
    gpath("fonts/sfsl0800.pfb.ttf"),
    gpath("fonts/SimHei.ttf"),
    gpath("fonts/Times New Roman.ttf"),
]

font_sizes = (28, 32, 36, 40, 44, 48)
image = ImageCaptcha(
    WIDTH, HEIGHT, fonts=fonts, font_sizes=font_sizes)

transform = Compose([
    ToPercentCoords(),
    ToTensor(),
    SSDTransform(SCALES, DEFAULT_BOXES, NUM_CLASSES),
])

test_transform = Compose([
    ToTensor(),
])

ds_train = CaptchaDetectionOnline(
    image, size=50000, letters=letters, rotate=20, transform=transform)
ds_val = CaptchaDetectionOnline(
    image, size=1000, letters=letters, rotate=20, transform=test_transform, online=False)


# Define network, loss and optimizer

out_channels = [
    (NUM_CLASSES + 4) * len(ars)  # per default box: NUM_CLASSES class scores + 4 box offsets
    for ars in ASPECT_RATIOS
]
net = DSOD([3, 4, 4, 4], 36, out_channels=out_channels, reduction=1)
net.apply(init_weights(nonlinearity='relu'))
criterion = SSDLoss(NUM_CLASSES)
optimizer = Adam(net.parameters(), lr=3e-4)
lr_scheduler = MultiStepLR(optimizer, [40, 70, 100], gamma=0.1)


# Define metrics for training and testing

metrics = {
    'loss': TrainLoss(),
}
test_metrics = {
    'mAP': MeanAveragePrecision(
        SSDInference(
            width=WIDTH, height=HEIGHT,
            f_default_boxes=[ cuda(d) for d in DEFAULT_BOXES ],
            num_classes=NUM_CLASSES,
        )
    )
}

# Put it together with Trainer

trainer = Trainer(net, criterion, optimizer, lr_scheduler,
                  metrics=metrics, evaluate_metrics=test_metrics,
                  save_path=gpath("models"), name="DSOD-CAPTCHA-48")

# Show number of parameters

summary(net, (3, HEIGHT, WIDTH))


# Define data loaders

train_loader = DataLoader(
    ds_train, batch_size=32, shuffle=True, num_workers=1, pin_memory=True)
val_loader = DataLoader(
    ds_val, batch_size=32, collate_fn=box_collate_fn)

# Train for 15 epochs and save the best models by validation mAP (higher is better) after the first 10 epochs

trainer.fit(train_loader, 15, val_loader=val_loader, save_by_metric='val_mAP', patience=10)


Download files

Download the file for your platform.

Source Distribution

pytorch-hrvvi-ext-1.4.14.tar.gz (31.7 kB)

Uploaded Source

Built Distribution


pytorch_hrvvi_ext-1.4.14-cp37-cp37m-macosx_10_7_x86_64.whl (104.1 kB)

Uploaded CPython 3.7m, macOS 10.7+ x86-64

File details

Details for the file pytorch-hrvvi-ext-1.4.14.tar.gz.

File metadata

  • Download URL: pytorch-hrvvi-ext-1.4.14.tar.gz
  • Upload date:
  • Size: 31.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.12.1 pkginfo/1.5.0.1 requests/2.19.1 setuptools/40.6.3 requests-toolbelt/0.8.0 tqdm/4.28.1 CPython/3.7.2

File hashes

Hashes for pytorch-hrvvi-ext-1.4.14.tar.gz
  • SHA256: 848c2ff56765cd27fac6fdccdf78278f20562cd34dea66326d610787a7fede5a
  • MD5: 6a0a6d971c7d3e97a17f649fe899d594
  • BLAKE2b-256: 0dfef69b0ab91acb3e88f078152ed11c6e2bcb2352d93d7f265d5d2e4df4f03a


File details

Details for the file pytorch_hrvvi_ext-1.4.14-cp37-cp37m-macosx_10_7_x86_64.whl.

File metadata

  • Download URL: pytorch_hrvvi_ext-1.4.14-cp37-cp37m-macosx_10_7_x86_64.whl
  • Upload date:
  • Size: 104.1 kB
  • Tags: CPython 3.7m, macOS 10.7+ x86-64
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.12.1 pkginfo/1.5.0.1 requests/2.19.1 setuptools/40.6.3 requests-toolbelt/0.8.0 tqdm/4.28.1 CPython/3.7.2

File hashes

Hashes for pytorch_hrvvi_ext-1.4.14-cp37-cp37m-macosx_10_7_x86_64.whl
  • SHA256: f19cd754e872b9e2629f52745d544bb05e9d3cffe043d13165c43c0055713e1a
  • MD5: 10203c05adbb94a3e19a65c56aab518b
  • BLAKE2b-256: d77969b6fe09ca3ab7fc0d252214322d8f194630377559d7cd155d68325fb2e7

