
# torchpack

[![PyPI Version](https://img.shields.io/pypi/v/torchpack.svg)](https://pypi.python.org/pypi/torchpack)

Torchpack is a set of interfaces to simplify the usage of PyTorch.

Documentation is ongoing.


## Installation

- Install with pip.
```shell
pip install torchpack
```
- Install from source.
```shell
git clone https://github.com/hellock/torchpack.git
cd torchpack
python setup.py install
```

**Note**: To visualize the training process with TensorBoard, you also need to
install TensorFlow ([installation guide](https://www.tensorflow.org/install/install_linux)) and tensorboardX (`pip install tensorboardX`).
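Because tensorboardX is optional, a setup script could enable `TensorboardLoggerHook` only when the package is importable. This is a minimal sketch under assumptions: `build_log_hooks` is a hypothetical helper (not part of torchpack), and the hook names and `log_dir` argument are copied from the example config below. Whether a `.py` config may contain helper functions depends on how `Config.from_file` parses it, so this logic could equally live in `main.py`.

```python
import importlib.util


def build_log_hooks(work_dir):
    """Hypothetical helper: return a log_cfg-style hook list.

    Terminal logging is always enabled; tensorboard logging is added only
    if the optional tensorboardX package can be imported.
    """
    hooks = [('TextLoggerHook', {})]
    # find_spec returns None when no module with that name is installed.
    if importlib.util.find_spec('tensorboardX') is not None:
        hooks.append(('TensorboardLoggerHook', dict(log_dir=work_dir + '/log')))
    return hooks


# Same shape as log_cfg in the example config.
log_cfg = dict(interval=50, hooks=build_log_hooks('./demo'))
```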

## What can torchpack do?

Torchpack aims to help users start training with less code while staying
flexible and configurable. It provides a `Runner` together with a collection of `Hooks`.
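The runner/hook split can be illustrated with a toy version. Everything below is hypothetical: `MiniRunner`, `Hook`, the `before_epoch`/`after_iter`/`after_epoch` method names, and `PrintLossHook` are assumptions chosen for illustration, not torchpack's actual classes or signatures.

```python
class Hook:
    """Hypothetical base class: every callback is a no-op unless overridden."""

    def before_epoch(self, runner):
        pass

    def after_iter(self, runner):
        pass

    def after_epoch(self, runner):
        pass


class MiniRunner:
    """Hypothetical, stripped-down runner that only owns the loop."""

    def __init__(self, batch_processor):
        self.batch_processor = batch_processor
        self.hooks = []
        self.epoch = 0
        self.iter = 0
        self.outputs = None

    def register_hook(self, hook):
        self.hooks.append(hook)

    def call_hook(self, name):
        # Dispatch the named event to every registered hook, in order.
        for hook in self.hooks:
            getattr(hook, name)(self)

    def train(self, data_loader):
        self.call_hook('before_epoch')
        for data in data_loader:
            self.outputs = self.batch_processor(data)
            self.iter += 1
            self.call_hook('after_iter')
        self.call_hook('after_epoch')
        self.epoch += 1


class PrintLossHook(Hook):
    """Records the loss every `interval` iterations, like a text logger."""

    def __init__(self, interval):
        self.interval = interval
        self.lines = []

    def after_iter(self, runner):
        if runner.iter % self.interval == 0:
            self.lines.append(f"iter {runner.iter}: loss={runner.outputs['loss']:.2f}")


# Usage: each "batch" is just a number that the batch processor reports as the loss.
runner = MiniRunner(lambda x: dict(loss=float(x)))
logger = PrintLossHook(interval=2)
runner.register_hook(logger)
runner.train([4, 3, 2, 1])
print(logger.lines)  # ['iter 2: loss=3.00', 'iter 4: loss=1.00']
```

Keeping logging, checkpointing, and learning-rate updates in hooks means the loop itself never changes; new behavior is added by registering another hook.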

## Example

```python
######################## file1: config.py #######################
work_dir = './demo'  # dir to save log files and checkpoints
optimizer = dict(
    algorithm='SGD', args=dict(lr=0.001, momentum=0.9, weight_decay=5e-4))
workflow = [('train', 2), ('val', 1)]  # train for 2 epochs, then validate for 1 epoch, repeatedly
max_epoch = 16
lr_policy = dict(policy='step', step=12)  # decrease the learning rate by a factor of 10 every 12 epochs
checkpoint_cfg = dict(interval=1)  # save a checkpoint every epoch
log_cfg = dict(
    # log every 50 iterations
    interval=50,
    # two logging hooks: one prints to the terminal, one writes tensorboard logs
    hooks=[
        ('TextLoggerHook', {}),
        ('TensorboardLoggerHook', dict(log_dir=work_dir + '/log'))
    ])

######################### file2: main.py ########################
from collections import OrderedDict

import torch
import torch.nn.functional as F
from torchvision.models import resnet18
from torchpack import Config, Runner


def get_accuracy(pred, label):
    # hypothetical helper (not part of torchpack): top-1 accuracy of the batch
    return (pred.argmax(dim=1) == label).float().mean()


# define how to process a batch and return a dict
def batch_processor(model, data, train_mode):
    img, label = data
    # inputs and labels must live on the same device as the model
    img = img.cuda(non_blocking=True)
    label = label.cuda(non_blocking=True)
    pred = model(img)
    loss = F.cross_entropy(pred, label)
    accuracy = get_accuracy(pred, label)
    log_vars = OrderedDict()
    log_vars['loss'] = loss.item()
    log_vars['accuracy'] = accuracy.item()
    outputs = dict(loss=loss, log_vars=log_vars, num_samples=img.size(0))
    return outputs


cfg = Config.from_file('config.py')  # or config.yaml/config.json
model = resnet18().cuda()
runner = Runner(model, cfg.optimizer, batch_processor, cfg.work_dir)
runner.register_default_hooks(lr_config=cfg.lr_policy,
                              checkpoint_config=cfg.checkpoint_cfg,
                              log_config=cfg.log_cfg)

# train_loader and val_loader are torch.utils.data.DataLoader instances
# built elsewhere (see examples/train_imagenet.py)
runner.run([train_loader, val_loader], cfg.workflow, cfg.max_epoch)
```
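The `workflow`, `max_epoch`, and `lr_policy` settings in the config above together define a training schedule. The sketch below expands that schedule under two stated assumptions: `max_epoch` counts training epochs only (validation runs do not advance it), and the `step` policy multiplies the learning rate by 0.1 every `step` training epochs, as the config comment describes. `expand_schedule` and its `gamma` parameter are hypothetical names, not torchpack's API.

```python
def expand_schedule(workflow, max_epoch, base_lr, step, gamma=0.1):
    """Hypothetical helper: list the phases implied by the config.

    Returns (mode, epoch, lr) tuples. For 'train' entries, epoch is the index
    of that training epoch and lr its step-policy learning rate. For 'val'
    entries, epoch is the number of training epochs completed so far and lr
    is None.
    """
    schedule = []
    epoch = 0  # training epochs completed so far
    while epoch < max_epoch:
        for mode, num_epochs in workflow:
            for _ in range(num_epochs):
                if mode == 'train':
                    if epoch >= max_epoch:
                        return schedule
                    # step policy: lr = base_lr * gamma ** (epoch // step)
                    lr = base_lr * gamma ** (epoch // step)
                    schedule.append((mode, epoch, lr))
                    epoch += 1
                else:
                    # assumption: validation does not advance the epoch counter
                    schedule.append((mode, epoch, None))
    return schedule


schedule = expand_schedule([('train', 2), ('val', 1)], max_epoch=16,
                           base_lr=0.001, step=12)
print(len(schedule))  # 24: 16 training epochs interleaved with 8 validation runs
```

Under these assumptions, training epochs 0 to 11 use a learning rate of 0.001 and epochs 12 to 15 use 0.0001.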

For a full example of training on ImageNet, please see `examples/train_imagenet.py`.

```shell
python examples/train_imagenet.py examples/config.py
```
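In the example above, `batch_processor` returns `num_samples` alongside `log_vars`. One plausible use, sketched here as an assumption rather than a description of torchpack's internals, is a sample-weighted average of the logged values over a logging interval, so that a smaller final batch does not count as much as a full one. `LogAverager` is a hypothetical name.

```python
class LogAverager:
    """Hypothetical accumulator: sample-weighted mean of logged scalars."""

    def __init__(self):
        self.sums = {}
        self.count = 0

    def update(self, outputs):
        # outputs has the shape returned by batch_processor:
        # dict(loss=..., log_vars={name: float}, num_samples=int)
        n = outputs['num_samples']
        for key, value in outputs['log_vars'].items():
            # weight each batch's value by the number of samples it covers
            self.sums[key] = self.sums.get(key, 0.0) + value * n
        self.count += n

    def average(self):
        return {key: total / self.count for key, total in self.sums.items()}


avg = LogAverager()
avg.update(dict(log_vars={'loss': 2.0, 'accuracy': 0.5}, num_samples=3))
avg.update(dict(log_vars={'loss': 1.0, 'accuracy': 1.0}, num_samples=1))
print(avg.average())  # {'loss': 1.75, 'accuracy': 0.625}
```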

Download files

- Source distribution: torchpack-0.2.1.tar.gz (12.9 kB)
- Built distribution: torchpack-0.2.1-py3-none-any.whl (18.3 kB, Python 3)

File details

Details for the file torchpack-0.2.1.tar.gz:

- Size: 12.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/1.11.0 pkginfo/1.4.2 requests/2.19.1 setuptools/40.0.0 requests-toolbelt/0.8.0 tqdm/4.24.0 CPython/3.5.5

Hashes for torchpack-0.2.1.tar.gz:

| Algorithm | Hash digest |
| --- | --- |
| SHA256 | 143e65966d1e9ab52f83bbc0e5070f2b2cc2f05eb3eef9196a93691770c1188f |
| MD5 | 7e5ec3cd797ffe8b10ca98bc95c22592 |
| BLAKE2b-256 | 16b23a664c4311dfe575bc8e71ff901abd232e51a09e16193923c6059cbdf9fa |
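The published digests can be checked locally before installing. A minimal sketch using Python's standard `hashlib`; `sha256_of` and `verify` are hypothetical helper names, and the commented-out usage assumes the sdist was downloaded into the current directory.

```python
import hashlib

# SHA256 digest published above for torchpack-0.2.1.tar.gz
EXPECTED_SHA256 = '143e65966d1e9ab52f83bbc0e5070f2b2cc2f05eb3eef9196a93691770c1188f'


def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, 'rb') as f:
        # iter() with a sentinel keeps reading until f.read returns b''
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected):
    """True if the file's SHA256 digest matches the expected hex string."""
    return sha256_of(path) == expected.lower()


# verify('torchpack-0.2.1.tar.gz', EXPECTED_SHA256)
```

pip can also enforce this check itself in hash-checking mode, for example with a requirements line such as `torchpack==0.2.1 --hash=sha256:<digest>`.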


Details for the file torchpack-0.2.1-py3-none-any.whl:

- Size: 18.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/1.11.0 pkginfo/1.4.2 requests/2.19.1 setuptools/40.0.0 requests-toolbelt/0.8.0 tqdm/4.24.0 CPython/3.5.5

Hashes for torchpack-0.2.1-py3-none-any.whl:

| Algorithm | Hash digest |
| --- | --- |
| SHA256 | f003cb92fb4b44d44066fc3e46e42dddc97dc1b3741fb88110ea076f0cf674fb |
| MD5 | 326a3b7f4edddad247c219fb60063cd7 |
| BLAKE2b-256 | 2be33673c1ed8ab5d2ea813166c04bb84672fa19d9ffc761f61add0ae3553191 |

