A Python package for visualization and storage management in PyTorch AI tasks.
torchtracer
torchtracer is a tool package for visualization and storage management in PyTorch AI tasks.
Getting Started
PyTorch Required
This tool is developed for PyTorch AI tasks, so PyTorch must be installed.
Installing
You can use pip to install torchtracer.
pip install torchtracer
How to use?
Import torchtracer
from torchtracer import Tracer
Create an instance of Tracer
Assume that the root is ./checkpoints and the current task id is lmmnb.
To avoid cluttering the working directory, create the root directory manually.
tracer = Tracer('checkpoints').attach('lmmnb')
After this step, the directory checkpoints contains a directory lmmnb for the current AI task.
You can also call .attach() without a task id, in which case the current datetime is used as the task id.
tracer = Tracer('checkpoints').attach()
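The directory layout described above can be sketched with the standard library alone. This is not torchtracer's implementation: the function name make_task_dir, the timestamp format, and the choice to raise when the root is missing are all assumptions made for illustration.

```python
from datetime import datetime
from pathlib import Path
from typing import Optional


def make_task_dir(root: str, task_id: Optional[str] = None) -> Path:
    """Create <root>/<task_id>, falling back to a timestamp when no id is given."""
    root_path = Path(root)
    if not root_path.is_dir():
        # Assumption of this sketch: the root must already exist,
        # in line with the advice to create it manually.
        raise FileNotFoundError(f"root directory does not exist: {root_path}")
    if task_id is None:
        # Hypothetical format; the datetime format torchtracer uses may differ.
        task_id = datetime.now().strftime("%Y-%m-%dT%H_%M_%S")
    task_path = root_path / task_id
    task_path.mkdir(exist_ok=True)
    return task_path
```

Calling make_task_dir('checkpoints', 'lmmnb') would give checkpoints/lmmnb, and omitting the id would give a timestamped directory instead.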
Saving config
The raw config should be a dict like this:
import torch
from torch import nn
# `net` is a defined nn.Module
args = {'epoch_n': 120,
        'batch_size': 10,
        'criterion': nn.MSELoss(),
        'optimizer': torch.optim.RMSprop(net.parameters(), lr=1e-3)}
The config dict should be wrapped with torchtracer.data.Config:
from torchtracer.data import Config
cfg = Config(args)
tracer.store(cfg)
This step will create config.json in ./checkpoints/lmmnb/, which contains JSON information like this:
{
  "epoch_n": 120,
  "batch_size": 10,
  "criterion": "MSELoss",
  "optimizer": {
    "lr": 0.001,
    "momentum": 0,
    "alpha": 0.99,
    "eps": 1e-08,
    "centered": false,
    "weight_decay": 0,
    "name": "RMSprop"
  }
}
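To see why the criterion appears as the string "MSELoss" in the JSON above, here is a minimal sketch of one way to make such a dict JSON-serializable: keep plain values and replace any other object with its class name. It is an illustration only, not how torchtracer.data.Config is implemented. The helper to_jsonable is hypothetical, and the MSELoss class is a stand-in so the snippet runs without PyTorch. The optimizer entry, which also carries hyperparameters in the real output, is not reproduced here.

```python
import json


class MSELoss:
    """Stand-in for torch.nn.MSELoss so the sketch runs without PyTorch."""


def to_jsonable(value):
    """Keep JSON primitives as-is; replace any other object with its class name."""
    if value is None or isinstance(value, (bool, int, float, str)):
        return value
    if isinstance(value, dict):
        return {key: to_jsonable(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [to_jsonable(item) for item in value]
    return type(value).__name__


args = {'epoch_n': 120, 'batch_size': 10, 'criterion': MSELoss()}
config_json = json.dumps(to_jsonable(args), indent=2)
print(config_json)
```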
Logging
During training, you can log any information you want by using Tracer.log(msg, file).
If file is not specified, msg is written to ./checkpoints/lmmnb/log. Otherwise, it is written to a file named after the given value, for example ./checkpoints/lmmnb/something.log for file='something'.
tracer.log(msg='Epoch #{:03d}\ttrain_loss: {:.4f}\tvalid_loss: {:.4f}'.format(epoch, train_loss, valid_loss),
           file='losses')
This step will create a log file losses.log in ./checkpoints/lmmnb/, which contains logs like:
Epoch #001 train_loss: 18.6356 valid_loss: 21.3882
Epoch #002 train_loss: 19.1731 valid_loss: 17.8482
Epoch #003 train_loss: 19.6756 valid_loss: 19.1418
Epoch #004 train_loss: 20.0638 valid_loss: 18.3875
Epoch #005 train_loss: 18.4679 valid_loss: 19.6304
...
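The file naming rule and the message format can be sketched without torchtracer. The helper append_log below is hypothetical, and appending one line per call is an assumption about the behavior rather than something confirmed by the library.

```python
from pathlib import Path
from typing import Optional


def append_log(task_dir: str, msg: str, file: Optional[str] = None) -> Path:
    """Append msg as one line to <task_dir>/log or <task_dir>/<file>.log."""
    name = 'log' if file is None else f'{file}.log'
    path = Path(task_dir) / name
    with path.open('a', encoding='utf-8') as handle:
        handle.write(msg + '\n')
    return path


# Same format string as in the example above, with the first epoch's values.
line = 'Epoch #{:03d}\ttrain_loss: {:.4f}\tvalid_loss: {:.4f}'.format(1, 18.6356, 21.3882)
```

The {:03d} field zero-pads the epoch number to three digits and {:.4f} keeps four decimal places, which is what produces lines such as the ones listed above.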
Saving model
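The model object should be wrapped with torchtracer.data.Model.
If file is not specified, the generated model files are named after model (for example model.txt). Otherwise, they are named after the given value (for example somename.txt).
from torchtracer.data import Model
tracer.store(Model(model), file='somename')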
This step will create 2 files:
- description: somename.txt
Sequential
Sequential(
  (0): Linear(in_features=1, out_features=6, bias=True)
  (1): ReLU()
  (2): Linear(in_features=6, out_features=12, bias=True)
  (3): ReLU()
  (4): Linear(in_features=12, out_features=12, bias=True)
  (5): ReLU()
  (6): Linear(in_features=12, out_features=1, bias=True)
)
- parameters: somename.pth
Saving matplotlib images
Use tracer.store(figure, file) to save a matplotlib figure in images/.
import matplotlib.pyplot as plt
# assume that `train_losses` and `valid_losses` are lists of losses.
# create the figure manually.
plt.plot(train_losses, label='train loss', c='b')
plt.plot(valid_losses, label='valid loss', c='r')
plt.title('Demo Learning on SQRT')
plt.legend()
# save figure. remember to call `plt.gcf()`
tracer.store(plt.gcf(), 'losses.png')
This step will save a PNG file losses.png showing the loss curves.
Progress bar for epochs
Use tracer.epoch_bar_init(total) to initialize a progress bar.
tracer.epoch_bar_init(epoch_n)
Use tracer.epoch_bar.update(n=1, **params) to update the postfix of the progress bar.
tracer.epoch_bar.update(train_loss=train_loss, valid_loss=valid_loss)
(THIS IS A DEMO)
Tracer start at /home/oidiotlin/projects/torchtracer/checkpoints
Tracer attached with task: rabbit
Epoch: 100%|█████████| 120/120 [00:02<00:00, 41.75it/s, train_loss=0.417, valid_loss=0.417]
Do not forget to call tracer.epoch_bar.close() to finish the bar.
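One way to make sure close() is always called, even when an epoch raises, is a try/finally block around the loop. The sketch below uses a stand-in class with the same update(n=1, **params) and close() calls so that it runs without torchtracer. StandInBar and run_epochs are hypothetical names, and the loss values are placeholders.

```python
class StandInBar:
    """Minimal stand-in exposing the update()/close() calls used above."""

    def __init__(self, total):
        self.total = total
        self.done = 0
        self.postfix = {}
        self.closed = False

    def update(self, n=1, **params):
        self.done += n
        self.postfix = params

    def close(self):
        self.closed = True


def run_epochs(bar, epoch_n):
    """Advance the bar once per epoch and close it even if an epoch raises."""
    try:
        for epoch in range(1, epoch_n + 1):
            # Placeholder values, not real losses.
            train_loss = valid_loss = 1.0 / epoch
            bar.update(train_loss=train_loss, valid_loss=valid_loss)
    finally:
        bar.close()


bar = StandInBar(total=5)
run_epochs(bar, epoch_n=5)
```

In real code, the bar passed to such a loop would be tracer.epoch_bar after calling tracer.epoch_bar_init(epoch_n).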
Contribute
If you like this project, pull requests and issues are welcome.