
A Python package for visualization and storage management in PyTorch AI tasks.

Project description

torchtracer


torchtracer is a tool package for visualization and storage management in PyTorch AI tasks.

Getting Started

PyTorch Required

This tool is developed for PyTorch AI tasks, so PyTorch must be installed.

Installing

You can use pip to install torchtracer.

pip install torchtracer

How to use?

Import torchtracer

from torchtracer import Tracer

Create an instance of Tracer

Assume that the root directory is ./checkpoints and the current task ID is lmmnb.

To avoid cluttering the working directory, you should create the root directory manually.

tracer = Tracer('checkpoints').attach('lmmnb')

This step results in a directory checkpoints containing a directory lmmnb for the current AI task.

You can also call .attach() without a task ID, in which case the current datetime is used as the task ID.

tracer = Tracer('checkpoints').attach()
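The directory layout described above can be sketched with the standard library alone. This is a minimal illustration, not torchtracer's code: make_task_dir is a hypothetical helper, and the timestamp format used for the default task ID is an assumption (the actual format torchtracer uses is not documented here). Following the advice above, the sketch expects the root directory to exist already.

```python
from datetime import datetime
from pathlib import Path
from typing import Optional


def make_task_dir(root: str, task_id: Optional[str] = None) -> Path:
    """Hypothetical helper: create <root>/<task_id> for one AI task.

    The root directory must already exist (create it manually first).
    If task_id is None, a timestamp is used; the format is an assumption.
    """
    root_path = Path(root)
    if not root_path.is_dir():
        raise FileNotFoundError(f"root directory {root_path} does not exist; create it first")
    if task_id is None:
        task_id = datetime.now().strftime("%Y%m%d-%H%M%S")
    task_dir = root_path / task_id
    task_dir.mkdir(exist_ok=True)  # re-attaching to an existing task is allowed
    return task_dir
```

For example, after creating ./checkpoints, make_task_dir('checkpoints', 'lmmnb') returns the path checkpoints/lmmnb, and make_task_dir('checkpoints') returns a timestamp-named directory.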

Saving config

The raw config should be a dict like this:

import torch
import torch.nn as nn

# `net` is a defined nn.Module
args = {'epoch_n': 120,
        'batch_size': 10,
        'criterion': nn.MSELoss(),
        'optimizer': torch.optim.RMSprop(net.parameters(), lr=1e-3)}

The config dict should be wrapped with torchtracer.data.Config:

from torchtracer.data import Config

cfg = Config(args)
tracer.store(cfg)

This step will create config.json in ./checkpoints/lmmnb/, which contains JSON information like this:

{
  "epoch_n": 120,
  "batch_size": 10,
  "criterion": "MSELoss",
  "optimizer": {
    "lr": 0.001,
    "momentum": 0,
    "alpha": 0.99,
    "eps": 1e-08,
    "centered": false,
    "weight_decay": 0,
    "name": "RMSprop"
  }
}
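The JSON above shows how values that are not JSON-serializable end up in config.json: the loss module is reduced to its class name, and the optimizer to its hyperparameters plus a "name" key. The stdlib-only sketch below reproduces that shape. The conversion rules are inferred from the sample output rather than taken from torchtracer's source, and to_jsonable and dump_config are hypothetical names. It relies on the fact that PyTorch optimizers keep their constructor hyperparameters in a `defaults` dict.

```python
import json
from typing import Any


def to_jsonable(value: Any) -> Any:
    """Convert a config value into something json.dump can write.

    Rules (assumed, inferred from the sample config.json):
    - JSON-native values (None, bool, int, float, str) pass through;
    - lists/tuples and dicts are converted recursively;
    - objects with a `defaults` dict (as PyTorch optimizers have) become
      that dict plus a "name" key holding the class name;
    - any other object (e.g. a loss module) becomes its class name.
    """
    if value is None or isinstance(value, (bool, int, float, str)):
        return value
    if isinstance(value, (list, tuple)):
        return [to_jsonable(v) for v in value]
    if isinstance(value, dict):
        return {str(k): to_jsonable(v) for k, v in value.items()}
    defaults = getattr(value, "defaults", None)
    if isinstance(defaults, dict):
        out = {k: to_jsonable(v) for k, v in defaults.items()}
        out["name"] = type(value).__name__
        return out
    return type(value).__name__


def dump_config(args: dict, path: str) -> None:
    """Hypothetical helper: write the converted config as indented JSON."""
    with open(path, "w") as f:
        json.dump(to_jsonable(args), f, indent=2)
```

Falling back to the class name keeps the file readable but lossy: config.json records which loss was used, not how to rebuild it.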

Logging

During training, you can log any information you want with Tracer.log(msg, file).

If file is not specified, msg is written to ./checkpoints/lmmnb/log. Otherwise, it is written to ./checkpoints/lmmnb/<file>.log.

tracer.log(msg='Epoch #{:03d}\ttrain_loss: {:.4f}\tvalid_loss: {:.4f}'.format(epoch, train_loss, valid_loss),
           file='losses')

This step will create a log file losses.log in ./checkpoints/lmmnb/, which contains logs like:

Epoch #001	train_loss: 18.6356	valid_loss: 21.3882
Epoch #002	train_loss: 19.1731	valid_loss: 17.8482
Epoch #003	train_loss: 19.6756	valid_loss: 19.1418
Epoch #004	train_loss: 20.0638	valid_loss: 18.3875
Epoch #005	train_loss: 18.4679	valid_loss: 19.6304
...
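The file-naming rule for Tracer.log can be illustrated with a small stand-in. This is a sketch, not torchtracer's implementation: log is a hypothetical function taking the task directory explicitly, and appending one line per call (rather than overwriting) is an assumption consistent with the sample losses.log.

```python
from pathlib import Path
from typing import Optional, Union


def log(task_dir: Union[str, Path], msg: str, file: Optional[str] = None) -> Path:
    """Hypothetical stand-in for Tracer.log: append one line to a log file.

    No `file` -> <task_dir>/log; otherwise -> <task_dir>/<file>.log.
    """
    name = "log" if file is None else f"{file}.log"
    path = Path(task_dir) / name
    with path.open("a") as f:  # append so each epoch adds a new line
        f.write(msg + "\n")
    return path
```

Calling log(task_dir, 'Epoch #{:03d}'.format(1), file='losses') once per epoch builds up losses.log line by line, like the sample above.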

Saving model

The model object should be wrapped with torchtracer.data.Model.

If file is not specified, the model files are named model.txt and so on. Otherwise, they are named after file, e.g. somename.txt.

from torchtracer.data import Model

tracer.store(Model(model), file='somename')

This step will create 2 files:

  • description: somename.txt
Sequential
Sequential(
  (0): Linear(in_features=1, out_features=6, bias=True)
  (1): ReLU()
  (2): Linear(in_features=6, out_features=12, bias=True)
  (3): ReLU()
  (4): Linear(in_features=12, out_features=12, bias=True)
  (5): ReLU()
  (6): Linear(in_features=12, out_features=1, bias=True)
)
  • parameters: somename.pth
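The sample somename.txt starts with the model's class name, followed by its printed form. Below is a stdlib-only sketch of writing such a description file; write_model_description is a hypothetical helper, not torchtracer's API. The matching .pth file would presumably be written with torch.save(model.state_dict(), path), but that is an assumption about torchtracer, not something documented here, so it is left out to keep the sketch free of a torch dependency.

```python
from pathlib import Path
from typing import Optional, Union


def write_model_description(task_dir: Union[str, Path], model: object,
                            file: Optional[str] = None) -> Path:
    """Hypothetical helper: write <file>.txt (default model.txt) describing a model.

    Layout mimics the sample: the class name on the first line, then the
    model's printed representation (for an nn.Module, that is str(model)).
    """
    path = Path(task_dir) / f"{file or 'model'}.txt"
    path.write_text(f"{type(model).__name__}\n{model}\n")
    return path
```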

Saving matplotlib images

Use tracer.store(figure, file) to save a matplotlib figure to images/.

import matplotlib.pyplot as plt

# Assume that `train_losses` and `valid_losses` are lists of losses.
# Create the figure manually.
plt.plot(train_losses, label='train loss', c='b')
plt.plot(valid_losses, label='valid loss', c='r')
plt.title('Demo Learning on SQRT')
plt.legend()
# Save the figure; pass the current figure object via `plt.gcf()`.
tracer.store(plt.gcf(), 'losses.png')

This step saves losses.png, a PNG image of the loss curves.
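plt.gcf() is needed because pyplot's plotting calls draw onto an implicit "current" figure, and gcf() returns that Figure object so it can be handed to store. A sketch of the saving step follows, assuming images/ sits inside the task directory (the text above only says "images/"); store_figure is a hypothetical helper. It is duck-typed: any object with a savefig(path) method works, and matplotlib's Figure.savefig accepts a path, so the real plt.gcf() can be passed in.

```python
from pathlib import Path
from typing import Union


def store_figure(task_dir: Union[str, Path], fig, file: str) -> Path:
    """Hypothetical helper: save a figure as <task_dir>/images/<file>.

    `fig` only needs a savefig(path) method, e.g. the result of plt.gcf().
    """
    images = Path(task_dir) / "images"
    images.mkdir(parents=True, exist_ok=True)
    path = images / file
    fig.savefig(path)
    return path
```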

Progress bar for epochs

Use tracer.epoch_bar_init(total) to initialize a progress bar.

tracer.epoch_bar_init(epoch_n)

Use tracer.epoch_bar.update(n=1, **params) to advance the progress bar by n steps and update its postfix with the given key-value pairs.

tracer.epoch_bar.update(train_loss=train_loss, valid_loss=valid_loss)

Example output (demo):
Tracer start at /home/oidiotlin/projects/torchtracer/checkpoints
Tracer attached with task: rabbit
Epoch: 100%|█████████| 120/120 [00:02<00:00, 41.75it/s, train_loss=0.417, valid_loss=0.417]

Do not forget to call tracer.epoch_bar.close() to finish the bar.
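The demo output above has the look of a tqdm bar, though that is an observation rather than something this page documents. To show the update/close pattern without extra dependencies, here is a minimal stdlib stand-in; EpochBar is hypothetical and is not torchtracer's implementation. update(n=1, **params) advances the count and replaces the postfix, and close() ends the line, which is why forgetting close() leaves the terminal cursor stranded on the bar's line.

```python
import sys
from typing import IO, Optional


class EpochBar:
    """Minimal stand-in for an epoch progress bar with a key=value postfix."""

    def __init__(self, total: int, width: int = 20, stream: Optional[IO[str]] = None):
        self.total = total
        self.width = width
        self.n = 0
        self.postfix = {}
        self.stream = stream if stream is not None else sys.stderr

    def _render(self) -> str:
        frac = self.n / self.total if self.total else 1.0
        filled = int(self.width * frac)
        bar = "#" * filled + "-" * (self.width - filled)
        post = ", ".join(f"{k}={v:.3g}" if isinstance(v, float) else f"{k}={v}"
                         for k, v in self.postfix.items())
        return f"Epoch: {frac:4.0%}|{bar}| {self.n}/{self.total} [{post}]"

    def update(self, n: int = 1, **params) -> None:
        self.n = min(self.n + n, self.total)
        if params:  # keep the previous postfix when no new values are given
            self.postfix = params
        # "\r" returns to the start of the line so the bar redraws in place
        self.stream.write("\r" + self._render())
        self.stream.flush()

    def close(self) -> None:
        self.stream.write("\n")  # move past the bar's line
        self.stream.flush()
```

Usage mirrors the tracer calls: create the bar with total=epoch_n before the loop, call update once per epoch with that epoch's losses, and call close() after the loop.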

Contribute

If you like this project, pull requests and issues are welcome.

Download files


Source Distribution

torchtracer-0.2.0.tar.gz (5.6 kB view details)

Uploaded Source

File details

Details for the file torchtracer-0.2.0.tar.gz.

File metadata

  • Download URL: torchtracer-0.2.0.tar.gz
  • Upload date:
  • Size: 5.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.12.1 pkginfo/1.4.2 requests/2.20.1 setuptools/40.6.2 requests-toolbelt/0.8.0 tqdm/4.26.0 CPython/3.5.6

File hashes

Hashes for torchtracer-0.2.0.tar.gz
Algorithm     Hash digest
SHA256        8193ef47f4bcfc15cc91cd54dddc079d9b59f819a5831a763844367ebcd46ae3
MD5           8386af1b6f2c8e5389752ceef4ace959
BLAKE2b-256   cc32148381ba901e60822c5624235d5111bcff529a52cc72ab5d9bcd0281fbd1
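To check a downloaded torchtracer-0.2.0.tar.gz against the SHA256 digest listed above, Python's hashlib is enough. A minimal sketch follows; sha256_of and verify are illustrative names, and the path you pass is wherever you saved the archive. (pip also has a hash-checking mode for requirements files, which serves the same purpose during installation.)

```python
import hashlib
from pathlib import Path
from typing import Union

# SHA256 digest listed above for torchtracer-0.2.0.tar.gz
EXPECTED_SHA256 = "8193ef47f4bcfc15cc91cd54dddc079d9b59f819a5831a763844367ebcd46ae3"


def sha256_of(path: Union[str, Path], chunk_size: int = 1 << 16) -> str:
    """Hash a file in chunks so large files never need to fit in memory."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path: Union[str, Path], expected: str = EXPECTED_SHA256) -> bool:
    """Return True if the file's SHA256 matches the expected hex digest."""
    return sha256_of(path) == expected.lower()
```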

