
A set of interfaces to simplify the usage of PyTorch


# torchpack

[![PyPI Version](https://img.shields.io/pypi/v/torchpack.svg)](https://pypi.python.org/pypi/torchpack)

Torchpack is a set of interfaces to simplify the usage of PyTorch.

Documentation is ongoing.


## Installation

- Install with pip.
```
pip install torchpack
```
- Install from source.
```
git clone https://github.com/hellock/torchpack.git
cd torchpack
python setup.py install
```

**Note**: To visualize training with TensorBoard, install TensorFlow (see the [installation guide](https://www.tensorflow.org/install/install_linux)) and tensorboardX (`pip install tensorboardX`).
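Before enabling TensorBoard logging, it can help to confirm that both optional packages are importable. The sketch below uses only the standard library; `missing_packages` is a hypothetical helper for illustration, not part of torchpack.

```python
import importlib.util

# Import names of the optional TensorBoard dependencies mentioned above
OPTIONAL_PACKAGES = ('tensorflow', 'tensorboardX')


def missing_packages(names=OPTIONAL_PACKAGES):
    """Return the names from `names` that cannot be found in this environment."""
    # find_spec returns None for a top-level module that is not installed
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == '__main__':
    missing = missing_packages()
    if missing:
        print('TensorBoard logging unavailable; install: ' + ', '.join(missing))
    else:
        print('tensorflow and tensorboardX are installed')
```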

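## What can torchpack do?

Torchpack aims to help users start training with less code while staying
flexible and configurable. It provides a `Runner` that drives the training loop and a collection of `Hooks` that plug into it, for example for learning-rate scheduling, checkpointing, and logging.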
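The runner-plus-hooks idea can be illustrated with a minimal, self-contained sketch. `Hook`, `MiniRunner`, and `LossLoggerHook` below are hypothetical names written for illustration; they show the general pattern (the runner owns the loop and calls every registered hook at fixed points), not torchpack's actual classes or method signatures.

```python
class Hook:
    """Base hook: every callback is a no-op unless a subclass overrides it."""

    def before_epoch(self, runner):
        pass

    def after_iter(self, runner):
        pass

    def after_epoch(self, runner):
        pass


class MiniRunner:
    """Owns the training loop and calls each registered hook at fixed points."""

    def __init__(self, batch_processor):
        self.batch_processor = batch_processor
        self.hooks = []
        self.epoch = 0        # number of completed epochs
        self.iter = 0         # number of completed iterations across all epochs
        self.outputs = None   # result of the most recent batch

    def register_hook(self, hook):
        self.hooks.append(hook)

    def call_hook(self, name):
        # Dispatch one event (e.g. 'after_iter') to every hook, in registration order
        for hook in self.hooks:
            getattr(hook, name)(self)

    def train_epoch(self, data_loader):
        self.call_hook('before_epoch')
        for batch in data_loader:
            self.outputs = self.batch_processor(batch)
            self.iter += 1
            self.call_hook('after_iter')
        self.epoch += 1
        self.call_hook('after_epoch')


class LossLoggerHook(Hook):
    """Records (epoch, iter, loss) every `interval` iterations."""

    def __init__(self, interval):
        self.interval = interval
        self.records = []

    def after_iter(self, runner):
        if runner.iter % self.interval == 0:
            self.records.append((runner.epoch, runner.iter, runner.outputs['loss']))


if __name__ == '__main__':
    runner = MiniRunner(lambda batch: {'loss': sum(batch)})
    logger = LossLoggerHook(interval=2)
    runner.register_hook(logger)
    runner.train_epoch([[1, 2], [3, 4], [5, 6], [7, 8]])
    print(logger.records)  # [(0, 2, 7), (0, 4, 15)]
```

Keeping logging, checkpointing, and scheduling in hooks means the loop itself never changes; new behavior is added by registering another hook.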

## Example

```python
######################## file1: config.py #######################
work_dir = './demo'  # directory for log files and checkpoints
optimizer = dict(
    algorithm='SGD', args=dict(lr=0.001, momentum=0.9, weight_decay=5e-4))
# train for 2 epochs, then validate for 1 epoch, and repeat
workflow = [('train', 2), ('val', 1)]
max_epoch = 16
# 'step' policy: decrease the learning rate by a factor of 10 every 12 epochs
lr_policy = dict(policy='step', step=12)
checkpoint_cfg = dict(interval=1)  # save a checkpoint after every epoch
log_cfg = dict(
    # log every 50 iterations
    interval=50,
    # two logging hooks: one prints to the terminal, one writes TensorBoard logs
    hooks=[
        ('TextLoggerHook', {}),
        ('TensorboardLoggerHook', dict(log_dir=work_dir + '/log'))
    ])

######################### file2: main.py ########################
from collections import OrderedDict

import torch
import torch.nn.functional as F
from torchvision.models import resnet18

from torchpack import Config, Runner


# define how to process a batch and return a dict with the loss,
# the values to log, and the number of samples in the batch
def batch_processor(model, data, train_mode):
    img, label = data
    # build the autograd graph only in training mode
    with torch.set_grad_enabled(train_mode):
        pred = model(img)
        loss = F.cross_entropy(pred, label)
    accuracy = (pred.argmax(dim=1) == label).float().mean()
    log_vars = OrderedDict()
    log_vars['loss'] = loss.item()
    log_vars['accuracy'] = accuracy.item()
    outputs = dict(
        loss=loss, log_vars=log_vars, num_samples=img.size(0))
    return outputs


cfg = Config.from_file('config.py')  # or config.yaml/config.json
model = resnet18()
runner = Runner(model, cfg.optimizer, batch_processor, cfg.work_dir)
runner.register_default_hooks(lr_config=cfg.lr_policy,
                              checkpoint_config=cfg.checkpoint_cfg,
                              log_config=cfg.log_cfg)

# train_loader and val_loader are torch.utils.data.DataLoader objects
# that you build for your own dataset
runner.run([train_loader, val_loader], cfg.workflow, cfg.max_epoch)
```
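To make the example configuration concrete, the plain-Python sketch below loads a `config.py` like the one above and expands `workflow`, `max_epoch`, and the `'step'` learning-rate policy into a schedule. The helper names are hypothetical, and the semantics are assumptions inferred from the comments in the example: cycle through the workflow until `max_epoch` training epochs have run, and divide the learning rate by 10 every `step` epochs. Torchpack's own `Config` and `Runner` may differ in details.

```python
import runpy
from types import SimpleNamespace


def load_py_config(path):
    """Execute a Python config file and expose its public top-level names.

    Hypothetical stand-in for illustration; not torchpack's Config class.
    """
    namespace = runpy.run_path(str(path))
    public = {k: v for k, v in namespace.items() if not k.startswith('_')}
    return SimpleNamespace(**public)


def expand_workflow(workflow, max_epoch):
    """Expand e.g. [('train', 2), ('val', 1)] into an ordered list of phases.

    Assumed semantics: repeat the workflow until `max_epoch` training epochs
    have run. Each entry is (mode, number of training epochs completed so far).
    """
    if not any(mode == 'train' and epochs > 0 for mode, epochs in workflow):
        raise ValueError('workflow needs at least one training epoch')
    phases = []
    trained = 0
    while trained < max_epoch:
        for mode, epochs in workflow:
            for _ in range(epochs):
                if mode == 'train':
                    if trained >= max_epoch:
                        break  # training budget used up; skip remaining train epochs
                    trained += 1
                phases.append((mode, trained))
    return phases


def step_lr(base_lr, epoch, step, gamma=0.1):
    """Step decay (gamma=0.1 assumed): scale base_lr by gamma every `step` epochs."""
    return base_lr * gamma ** (epoch // step)
```

Under these assumptions, `expand_workflow(cfg.workflow, cfg.max_epoch)` on the example config yields 16 training epochs interleaved with 8 validation epochs, and `step_lr(cfg.optimizer['args']['lr'], epoch, cfg.lr_policy['step'])` gives 0.001 for epochs 0 to 11 and 0.0001 for epochs 12 to 15.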

## Download files

- `torchpack-0.0.13.tar.gz` (source distribution, 11.9 kB)
  - SHA256: `655a5dd0d08ba75aaf6b6d4d36e3d445bcf7f7d9107875849642487a8ebf8443`
  - MD5: `86aefdf2af0c9d3367352b0a79543312`
  - BLAKE2b-256: `1910e7c829071c606912efccae8041fc5a1c45000e2baae1394ef89867093d69`
- `torchpack-0.0.13-py3-none-any.whl` (built distribution for Python 3, 17.8 kB)
  - SHA256: `e2b64bf13d7c881d16c5b29b963f65f77d196380efb830a0fd1a703d8d20b486`
  - MD5: `192994ae1d3b5742df01d278496a3dfc`
  - BLAKE2b-256: `a4f25af391ba0be1430b518637b49fd25f52d9341f4825b98c10e41d2fe9ba0c`
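To check a downloaded file against the SHA256 digests listed above, a standard-library sketch like the following can be used. `digest_chunks`, `sha256_of`, and `verify_download` are hypothetical helper names; the digests are copied from the file listing for version 0.0.13.

```python
import hashlib
from pathlib import Path

# SHA256 digests copied from the torchpack 0.0.13 file listing
EXPECTED_SHA256 = {
    'torchpack-0.0.13.tar.gz':
        '655a5dd0d08ba75aaf6b6d4d36e3d445bcf7f7d9107875849642487a8ebf8443',
    'torchpack-0.0.13-py3-none-any.whl':
        'e2b64bf13d7c881d16c5b29b963f65f77d196380efb830a0fd1a703d8d20b486',
}


def digest_chunks(chunks):
    """Return the hex SHA256 digest of an iterable of byte chunks."""
    h = hashlib.sha256()
    for chunk in chunks:
        h.update(chunk)
    return h.hexdigest()


def sha256_of(path, chunk_size=1 << 16):
    """Hash a file in fixed-size chunks so large files need not fit in memory."""
    with open(path, 'rb') as f:
        return digest_chunks(iter(lambda: f.read(chunk_size), b''))


def verify_download(path):
    """Return True if the file's SHA256 matches the published digest for its name."""
    name = Path(path).name
    if name not in EXPECTED_SHA256:
        raise KeyError(f'no published digest for {name}')
    return sha256_of(path) == EXPECTED_SHA256[name]
```

For example, `verify_download('torchpack-0.0.13.tar.gz')` returns `True` only for an unmodified download. pip can enforce the same check itself in hash-checking mode: add `torchpack==0.0.13 --hash=sha256:<digest>` to a requirements file and install with `pip install --require-hashes -r requirements.txt`.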
