A Toolkit for Training, Tracking and Saving PyTorch Models
Torch-Scope
A toolkit for training PyTorch models, with the following features:
- Tracking environments, dependencies, implementations and checkpoints;
- Providing a logger wrapper with two handlers (to the console and to a file);
- Supporting automatic device selection;
- Providing a TensorBoard wrapper;
- Providing a spreadsheet writer that automatically summarizes notes and results.
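The two-handler logger in the list above can be pictured with Python's standard `logging` module alone. This is a minimal sketch, not torch-scope's implementation: `make_logger` is a hypothetical helper, and sending console output to stdout (rather than stderr) is an assumption.

```python
import logging
import sys


def make_logger(name: str, log_file: str, level: int = logging.INFO) -> logging.Logger:
    """Return a logger that sends each record both to stdout and to log_file."""
    logger = logging.getLogger(name)
    logger.setLevel(level)
    logger.propagate = False  # don't also pass records up to the root logger
    if logger.handlers:  # already configured; avoid stacking duplicate handlers
        return logger
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    for handler in (logging.StreamHandler(sys.stdout), logging.FileHandler(log_file)):
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger
```

With `logger = make_logger(__name__, "run.log")`, a call such as `logger.info("training started")` appears on the console and in `run.log` at the same time.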
We are in an early-release beta. Expect some adventures and rough edges.
Installation
To install via PyPI:
```bash
pip install torch-scope
```
To build from source:
```bash
pip install git+https://github.com/LiyuanLucasLiu/Torch-Scope
```
or
```bash
git clone https://github.com/LiyuanLucasLiu/Torch-Scope.git
cd Torch-Scope
python setup.py install
```
Usage
An example is provided below; please read the documentation for a detailed API explanation.
- Set up git on the server and add all source files to the repository;
- Use TensorBoard to track the model statistics (tensorboard --logdir PATH/log/ --port ####).
```python
import argparse
import logging
import os

import torch
from torch_scope import wrapper

...
logger = logging.getLogger(__name__)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint_path', type=str, ...)
    parser.add_argument('--name', type=str, ...)
    parser.add_argument('--gpu', type=str, ...)
    ...
    args = parser.parse_args()

    pw = wrapper(os.path.join(args.checkpoint_path, args.name), name=args.log_dir, enable_git_track=False)
    # Or, if the current folder is tracked by git, turn on git tracking:
    # pw = wrapper(os.path.join(args.checkpoint_path, args.name), name=args.log_dir, enable_git_track=True)
    # If credential_path points to a valid credential file and you want the spreadsheet writer,
    # turn on sheet tracking:
    # pw = wrapper(os.path.join(args.checkpoint_path, args.name), name=args.log_dir,
    #              enable_git_track=args.git_tracking, sheet_track_name=args.spreadsheet_name,
    #              credential_path="/data/work/jingbo/ll2/Torch-Scope/torch-scope-8acf12bee10f.json")

    # 'auto' lets the wrapper pick a GPU; a negative index selects the CPU
    gpu_index = pw.auto_device() if args.gpu == 'auto' else int(args.gpu)
    device = torch.device("cuda:" + str(gpu_index) if gpu_index >= 0 else "cpu")

    pw.save_configue(args)  # dump the config to config.json
    # If the spreadsheet writer is enabled, you can add a description of the current model:
    # pw.add_description(args.description)

    logger.info(str(args))  # shown on the console and written to the log file if the level is INFO or lower

    ...
    batch_index = 0
    tot_loss = 0  # loss accumulated since the last logging step
    best_score = float('-inf')
    for index in range(epoch):
        ...
        for instance in ...:
            optimizer.zero_grad()  # clear gradients left over from the previous step
            loss = ...
            tot_loss += loss.detach()
            loss.backward()

            if batch_index % ... == 0:
                pw.add_loss_vs_batch({'loss': tot_loss / ..., ...}, batch_index, False)
                pw.add_model_parameter_stats(model, batch_index, save=True)  # stats before the update
                optimizer.step()
                pw.add_model_update_stats(model, batch_index)  # stats of the update just applied
                tot_loss = 0
            else:
                optimizer.step()

            batch_index += 1

        dev_score = ...
        pw.add_loss_vs_batch({'dev_score': dev_score, ...}, index, True)

        if dev_score > best_score:
            pw.save_checkpoint(model, optimizer, is_best=True)
            best_score = dev_score
        else:
            pw.save_checkpoint(model, optimizer, is_best=False)
```
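The `save_configue` and `save_checkpoint` calls in the example hide two common saving patterns: dumping the parsed arguments to `config.json`, and always writing the latest checkpoint while keeping a separate copy of the best one. The sketch below illustrates those patterns with the standard library only; `save_config`, `save_checkpoint`, the `checkpoint_latest.pkl`/`checkpoint_best.pkl` file names and the use of `pickle` are all assumptions, not torch-scope's actual file layout or format (with PyTorch, `torch.save` on the model and optimizer `state_dict()`s is the usual choice).

```python
import argparse
import json
import os
import pickle
import shutil


def save_config(args: argparse.Namespace, out_dir: str) -> str:
    """Dump parsed command-line arguments to out_dir/config.json and return the path."""
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, "config.json")
    with open(path, "w") as f:
        # default=str turns values JSON can't encode (e.g. paths) into strings
        json.dump(vars(args), f, indent=2, sort_keys=True, default=str)
    return path


def save_checkpoint(state: dict, out_dir: str, is_best: bool) -> None:
    """Always overwrite the latest checkpoint; also copy it to the 'best' file when is_best."""
    os.makedirs(out_dir, exist_ok=True)
    latest = os.path.join(out_dir, "checkpoint_latest.pkl")
    with open(latest, "wb") as f:
        pickle.dump(state, f)  # hypothetical format; torch.save(state, latest) in PyTorch code
    if is_best:
        shutil.copyfile(latest, os.path.join(out_dir, "checkpoint_best.pkl"))
```

Copying the file, rather than saving twice, guarantees the "best" checkpoint is byte-identical to the one just written.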
Advanced Usage
Auto Device
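`pw.auto_device()` picks a GPU index for you; its exact selection rule is not described here. As an illustration of the idea only, the sketch below chooses the GPU with the most free memory, as reported by `nvidia-smi --query-gpu=memory.free --format=csv,noheader,nounits`, and returns `-1` (CPU, matching the `gpu_index >= 0` convention in the usage example) when no GPU qualifies or the query fails. The helper names and the "most free memory" criterion are assumptions.

```python
import subprocess
from typing import List


def parse_free_memory(text: str) -> List[int]:
    """Parse nvidia-smi memory.free output: one MiB value per line, one line per GPU."""
    return [int(line) for line in text.splitlines() if line.strip()]


def pick_gpu(free_mib: List[int], min_free_mib: int = 0) -> int:
    """Index of the GPU with the most free memory, or -1 (CPU) if none has min_free_mib."""
    best_index, best_free = -1, -1
    for index, free in enumerate(free_mib):
        if free >= min_free_mib and free > best_free:  # ties keep the lower index
            best_index, best_free = index, free
    return best_index


def auto_device_index(min_free_mib: int = 0) -> int:
    """Query nvidia-smi and pick a GPU; fall back to -1 (CPU) if the query fails."""
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits"],
            capture_output=True, text=True, check=True,
        )
        free_mib = parse_free_memory(result.stdout)
    except (OSError, subprocess.CalledProcessError, ValueError):
        return -1
    return pick_gpu(free_mib, min_free_mib)
```

The result plugs into the same device line as the usage example: `torch.device("cuda:" + str(i) if i >= 0 else "cpu")` with `i = auto_device_index()`.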
Git Tracking
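Git tracking (`enable_git_track=True`) requires the current folder to be a git repository. What exactly the wrapper records is not described here; as a hypothetical sketch, the snippet below captures the kind of information such tracking typically needs: the `HEAD` commit and whether the working tree has uncommitted changes. `git_snapshot` is an illustrative helper, not part of torch-scope.

```python
import subprocess
from typing import Dict, Optional


def git_snapshot(repo_dir: str = ".") -> Optional[Dict[str, object]]:
    """Return the HEAD commit and a dirty flag, or None if repo_dir is not a
    git work tree with at least one commit (or git is not installed)."""

    def run(*args: str) -> str:
        return subprocess.run(
            ["git", *args], cwd=repo_dir, capture_output=True, text=True, check=True
        ).stdout.strip()

    try:
        if run("rev-parse", "--is-inside-work-tree") != "true":
            return None
        commit = run("rev-parse", "HEAD")
        dirty = bool(run("status", "--porcelain"))  # any output means uncommitted changes
    except (OSError, subprocess.CalledProcessError):
        return None
    return {"commit": commit, "dirty": dirty}
```

Recording the dirty flag matters because a commit hash alone does not reproduce a run whose source files were edited but not committed.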
Spreadsheet Logging
Share the spreadsheet with the following account: torch-scope@torch-scope.iam.gserviceaccount.com. Then access the spreadsheet by its name.
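To see what "summarizing notes and results" into a table amounts to without setting up credentials, here is a local stand-in that appends one row per run to a CSV file. It swaps the shared spreadsheet for Python's `csv` module; `append_result_row` and the column names in the usage line are hypothetical and are not torch-scope's API.

```python
import csv
import os
from typing import Dict


def append_result_row(csv_path: str, row: Dict[str, object]) -> None:
    """Append one run's summary to csv_path, writing a header if the file is new."""
    is_new = not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0
    fieldnames = list(row.keys())
    if not is_new:
        with open(csv_path, newline="") as f:
            fieldnames = next(csv.reader(f))  # keep the existing column order
    with open(csv_path, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)  # unknown keys raise ValueError
        if is_new:
            writer.writeheader()
        writer.writerow(row)
```

For example, `append_result_row("results.csv", {"name": args.name, "description": "baseline", "dev_score": best_score})` at the end of training keeps one comparable line per experiment.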
Download files
Source Distribution
torch-scope-0.5.2.tar.gz (14.0 kB)
File details
Details for the file torch-scope-0.5.2.tar.gz.
File metadata
- Download URL: torch-scope-0.5.2.tar.gz
- Size: 14.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/1.13.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/40.8.0 requests-toolbelt/0.9.1 tqdm/4.31.1 CPython/3.7.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | 7dede788c30254d7ac208c078216f487160877cd98d716dac015eb18a9afb7be
MD5 | adbfbbc437499b2414360d8f0a5f2264
BLAKE2b-256 | 5af76142b17188367652f7fdb7e0ed96b9468cda0a28a031e78cea102d0d6ddd
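To check a downloaded archive against the SHA256 digest listed above, the file can be hashed in chunks with `hashlib`. This is a generic sketch; `sha256_of` is a hypothetical helper, and the path is wherever you saved the tarball.

```python
import hashlib


def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Stream the file in chunks (so large files need little memory) and return its hex SHA256."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
```

For instance, `sha256_of("torch-scope-0.5.2.tar.gz") == "7dede788c30254d7ac208c078216f487160877cd98d716dac015eb18a9afb7be"` should hold for an intact download; `pip hash torch-scope-0.5.2.tar.gz` prints the same SHA256 digest from the command line.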