A Toolkit for Training, Tracking and Saving PyTorch Models
# Torch-Scope
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
[![Documentation Status](https://readthedocs.org/projects/tensorboard-wrapper/badge/?version=latest)](http://tensorboard-wrapper.readthedocs.io/en/latest/?badge=latest)
[![Downloads](https://pepy.tech/badge/torch-scope)](https://pepy.tech/project/torch-scope)
[![PyPI version](https://badge.fury.io/py/torch-scope.svg)](https://badge.fury.io/py/torch-scope)
A toolkit for training PyTorch models, with the following features:
- Tracking environments, dependencies, implementations and checkpoints;
- Providing a logger wrapper with two handlers (to `std` and to `file`);
- Supporting automatic device selection;
- Providing a TensorBoard wrapper;
- Providing a spreadsheet writer to automatically summarize notes and results.
We are in an early-release beta. Expect some adventures and rough edges.
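The logger wrapper in the feature list sends each message to two places: the console and a log file. The sketch below is not torch-scope's own code. It only shows how the standard `logging` module does the same thing with one `StreamHandler` and one `FileHandler`. The name `make_logger` and the format string are our own choices.

```python
import logging
import os
import tempfile


def make_logger(name: str, log_path: str, level: int = logging.INFO) -> logging.Logger:
    """Hypothetical helper: a logger that writes to the console (stderr) and to log_path."""
    logger = logging.getLogger(name)
    logger.setLevel(level)  # messages below this level are dropped by both handlers
    logger.propagate = False  # avoid duplicate lines through the root logger
    fmt = logging.Formatter("%(asctime)s %(levelname)s %(message)s")

    console = logging.StreamHandler()  # defaults to sys.stderr
    console.setFormatter(fmt)
    logger.addHandler(console)

    file_handler = logging.FileHandler(log_path)
    file_handler.setFormatter(fmt)
    logger.addHandler(file_handler)
    return logger


if __name__ == "__main__":
    path = os.path.join(tempfile.mkdtemp(), "train.log")
    log = make_logger("demo", path)
    log.info("shown on the console and written to the file")
    log.debug("dropped: below the INFO level")
```

A second call to `make_logger` with the same name returns the same logger and attaches another pair of handlers. A real wrapper would need to guard against that.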
## Quick Links
- [Installation](#installation)
- [Usage](#usage)
## Installation
To install via PyPI:
```
pip install torch-scope
```
To build from source:
```
pip install git+https://github.com/LiyuanLucasLiu/Torch-Scope
```
or
```
git clone https://github.com/LiyuanLucasLiu/Torch-Scope.git
cd Torch-Scope
python setup.py install
```
## Usage
An example is provided below. Please read the documentation for a detailed API explanation.
* Set up git on the server and add all source files to the repository.
* Use TensorBoard to track the model statistics (`tensorboard --logdir PATH/log/ --port ####`).
```
import argparse
import os

import torch
from torch_scope import wrapper
...
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint_path', type=str, ...)
    parser.add_argument('--name', type=str, ...)
    parser.add_argument('--gpu', type=str, ...)
    ...
    args = parser.parse_args()

    pw = wrapper(os.path.join(args.checkpoint_path, args.name), name = args.log_dir, enable_git_track = False)
    # Or, if the current folder is a git repository, turn on git tracking as below
    # pw = wrapper(os.path.join(args.checkpoint_path, args.name), name = args.log_dir, enable_git_track = True)
    # If credential_path points to your credential file and you want to use the spreadsheet writer, turn on sheet tracking as below
    # pw = wrapper(os.path.join(args.checkpoint_path, args.name), name = args.log_dir, \
    #              enable_git_track=args.git_tracking, sheet_track_name=args.spreadsheet_name, \
    #              credential_path="/data/work/jingbo/ll2/Torch-Scope/torch-scope-8acf12bee10f.json")

    # '--gpu auto' lets the wrapper pick a GPU; a negative index falls back to the CPU
    gpu_index = pw.auto_device() if 'auto' == args.gpu else int(args.gpu)
    device = torch.device("cuda:" + str(gpu_index) if gpu_index >= 0 else "cpu")

    pw.save_configue(args)  # dump the config to config.json
    pw.set_level('info')  # or 'debug', etc.
    # if the spreadsheet writer is enabled, you can add a description of the current model
    # pw.add_description(args.description)
    pw.info(str(args))  # written to std & file if the level is 'info' or lower
    ...
    batch_index = 0
    tot_loss = 0
    best_score = float('-inf')
    for index in range(epoch):
        ...
        for instance in ... :
            optimizer.zero_grad()  # clear gradients left over from the previous step
            loss = ...
            tot_loss += loss.detach()
            loss.backward()
            if batch_index % ... == 0:
                # log the averaged loss, then record parameter stats before the step and update stats after it
                pw.add_loss_vs_batch({'loss': tot_loss / ..., ...}, batch_index, False)
                pw.add_model_parameter_stats(model, batch_index, save=True)
                optimizer.step()
                pw.add_model_update_stats(model, batch_index)
                tot_loss = 0
            else:
                optimizer.step()
            batch_index += 1
        dev_score = ...
        pw.add_loss_vs_batch({'dev_score': dev_score, ...}, index, True)
        if dev_score > best_score:
            pw.save_checkpoint(model, optimizer, is_best = True)
            best_score = dev_score
        else:
            pw.save_checkpoint(model, optimizer, is_best = False)
```
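The example depends on two persistence steps. `save_configue` dumps the arguments to `config.json`. `save_checkpoint(..., is_best=...)` keeps the latest checkpoint and also keeps the best-scoring one. We have not checked how torch-scope implements either step. The standard-library sketch below only shows the general pattern, under these assumptions:
- `pickle` stands in for `torch.save`.
- The helpers `save_config` and `save_checkpoint` are hypothetical.
- The file names `checkpoint.pkl` and `best.pkl` are hypothetical.

```python
import argparse
import json
import os
import pickle
import shutil
import tempfile


def save_config(args: argparse.Namespace, folder: str) -> str:
    """Hypothetical helper: dump the parsed arguments to folder/config.json."""
    path = os.path.join(folder, "config.json")
    with open(path, "w") as f:
        json.dump(vars(args), f, indent=2, sort_keys=True)
    return path


def save_checkpoint(state: dict, folder: str, is_best: bool) -> None:
    """Hypothetical helper: always overwrite the latest checkpoint and copy it aside when it is the best.

    pickle stands in for torch.save; the file names are assumptions.
    """
    latest = os.path.join(folder, "checkpoint.pkl")
    with open(latest, "wb") as f:
        pickle.dump(state, f)
    if is_best:
        shutil.copyfile(latest, os.path.join(folder, "best.pkl"))


if __name__ == "__main__":
    folder = tempfile.mkdtemp()
    save_config(argparse.Namespace(name="demo", gpu="auto"), folder)
    best_score = float("-inf")
    for epoch, dev_score in enumerate([0.61, 0.74, 0.70]):  # made-up scores for illustration
        is_best = dev_score > best_score
        best_score = max(best_score, dev_score)
        save_checkpoint({"epoch": epoch, "dev_score": dev_score}, folder, is_best)
    with open(os.path.join(folder, "best.pkl"), "rb") as f:
        print(pickle.load(f))  # {'epoch': 1, 'dev_score': 0.74}
```

The best checkpoint is a copy of the latest file, not a second save. At the moment it is written, it is therefore byte-identical to the latest checkpoint.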
## Advanced Usage
### Auto Device
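In the usage example, `pw.auto_device()` returns a GPU index when `--gpu auto` is passed. We do not know the exact rule torch-scope applies. The sketch below assumes one common heuristic:
- Query free memory with `nvidia-smi --query-gpu=index,memory.free --format=csv,noheader,nounits`.
- Pick the GPU that has the most free memory.
- Return `-1` when no GPU is found, which maps to the CPU under the example's `gpu_index >= 0` check.

`pick_gpu` and this `auto_device` are hypothetical names. The parsing is separated from the subprocess call so it can be tested without a GPU.

```python
import shutil
import subprocess


def pick_gpu(csv_text: str) -> int:
    """Return the index of the GPU with the most free memory, or -1 if none are listed.

    csv_text is expected to hold one "index, free_MiB" pair per line, as printed by
    `nvidia-smi --query-gpu=index,memory.free --format=csv,noheader,nounits`.
    """
    best_index, best_free = -1, -1
    for line in csv_text.strip().splitlines():
        index_str, free_str = (field.strip() for field in line.split(","))
        index, free = int(index_str), int(free_str)
        if free > best_free:  # strict '>' keeps the lowest index on ties
            best_index, best_free = index, free
    return best_index


def auto_device() -> int:
    """Hypothetical stand-in for pw.auto_device(): query nvidia-smi, fall back to CPU (-1)."""
    if shutil.which("nvidia-smi") is None:
        return -1
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=index,memory.free", "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=False,
    )
    if result.returncode != 0:
        return -1
    try:
        return pick_gpu(result.stdout)
    except ValueError:  # unexpected output, e.g. "[N/A]" instead of a number
        return -1


if __name__ == "__main__":
    print(pick_gpu("0, 1024\n1, 10240\n2, 512\n"))  # 1
    print(auto_device())
```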
### Git Tracking
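The usage notes ask you to put the server code under git so that `enable_git_track=True` can tie each run to a specific version of the implementation. We have not checked what torch-scope actually records. The sketch below shows the idea under our own assumptions, using only the standard `git rev-parse HEAD` and `git status --porcelain` commands. It stores two things: the current commit hash, and whether the working tree has uncommitted or untracked changes. `git_snapshot` is a hypothetical helper.

```python
import subprocess
import tempfile
from typing import Optional


def git_snapshot(repo_dir: str) -> Optional[dict]:
    """Return {'commit': <hash>, 'dirty': <bool>} for repo_dir, or None if it is not a git repo.

    Hypothetical helper; 'dirty' is True when there are uncommitted or untracked changes.
    """
    def run(*cmd: str) -> subprocess.CompletedProcess:
        return subprocess.run(["git", *cmd], cwd=repo_dir, capture_output=True, text=True)

    try:
        head = run("rev-parse", "HEAD")
    except FileNotFoundError:  # git is not installed
        return None
    if head.returncode != 0:  # not a repository, or no commits yet
        return None
    status = run("status", "--porcelain")
    return {"commit": head.stdout.strip(), "dirty": bool(status.stdout.strip())}


if __name__ == "__main__":
    print(git_snapshot("."))
    print(git_snapshot(tempfile.mkdtemp()))  # a fresh temporary directory is normally not a repo
```

Saving this dictionary next to each checkpoint is one way to trace a result back to the code that produced it.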
### Spreadsheet Logging
Share the spreadsheet with the account `torch-scope@torch-scope.iam.gserviceaccount.com`, then access the table by its name (the `sheet_track_name` argument of `wrapper`).
## File details
Details for the file `torch-scope-0.4.10.tar.gz`.

File metadata:
- Download URL: torch-scope-0.4.10.tar.gz
- Size: 14.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/1.12.1 pkginfo/1.4.2 requests/2.19.1 setuptools/40.2.0 requests-toolbelt/0.8.0 tqdm/4.26.0 CPython/3.7.0

File hashes:

| Algorithm | Hash digest |
|---|---|
| SHA256 | f7888900d00b39d94422ad4e2e844e5697d4e88fcebee85eadd62a017f741577 |
| MD5 | 388c3c48b4395f2144f0b30373263958 |
| BLAKE2b-256 | 7f76d4d50d3aae32a3b286748b5b54bc435e3b86c9a38886c3e969a2dc2e8988 |