A library for tracing the execution of PyTorch operations and modules

torchtrail

torchtrail provides an external API to trace PyTorch models and extract the graph of torch functions and modules that were executed. The graphs can then be visualized or used for other purposes.

Installation Instructions

On macOS

brew install graphviz
pip install torchtrail

On Ubuntu

sudo apt-get install graphviz
pip install torchtrail
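The install steps above pull in Graphviz alongside torchtrail, presumably for rendering the .svg files. The sketch below is a quick sanity check that uses only the Python standard library. It assumes that a working Graphviz install puts the dot executable on PATH, and check_torchtrail_install is a hypothetical helper, not part of torchtrail.

```python
import importlib.util
import shutil


def check_torchtrail_install() -> dict:
    """Report whether the pieces used in the examples appear to be installed.

    Hypothetical helper. Assumption: Graphviz provides the `dot` executable
    on PATH. Python packages are located without importing them.
    """
    return {
        # Graphviz command-line renderer
        "graphviz_dot": shutil.which("dot") is not None,
        # Python packages used by the examples
        "torchtrail": importlib.util.find_spec("torchtrail") is not None,
        "torch": importlib.util.find_spec("torch") is not None,
        "networkx": importlib.util.find_spec("networkx") is not None,
    }


if __name__ == "__main__":
    for name, ok in check_torchtrail_install().items():
        print(f"{name:12} {'ok' if ok else 'MISSING'}")
```

Any entry reported as MISSING points at the install step to revisit.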

Examples

Tracing a function

import torch
import torchtrail

with torchtrail.trace():
    input_tensor = torch.rand(1, 64)
    output_tensor = torch.exp(input_tensor)
torchtrail.visualize(output_tensor, file_name="exp.svg")

The graph can be obtained as a networkx.MultiDiGraph using torchtrail.get_graph:

graph: "networkx.MultiDiGraph" = torchtrail.get_graph(output_tensor)
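Because get_graph returns an ordinary networkx.MultiDiGraph, standard networkx functions can be used to inspect it. The sketch below summarizes a graph: its size, whether it is acyclic, one valid execution order, and its source and sink nodes. It is demonstrated on a small hand-built graph standing in for the exp example; the node names are hypothetical, and the identifiers and attributes torchtrail actually stores on its nodes may differ.

```python
import networkx as nx


def describe_graph(graph: nx.MultiDiGraph) -> dict:
    """Summarize a directed graph using standard networkx calls."""
    is_dag = nx.is_directed_acyclic_graph(graph)
    return {
        "num_nodes": graph.number_of_nodes(),
        "num_edges": graph.number_of_edges(),
        "is_dag": is_dag,
        # One valid execution order; a topological sort only exists for acyclic graphs.
        "order": list(nx.topological_sort(graph)) if is_dag else None,
        # Nodes with no incoming edges (where data enters the graph).
        "sources": [n for n, d in graph.in_degree() if d == 0],
        # Nodes with no outgoing edges (where results leave the graph).
        "sinks": [n for n, d in graph.out_degree() if d == 0],
    }


# Hypothetical stand-in for the graph of the exp example above.
toy = nx.MultiDiGraph()
toy.add_edge("rand", "exp")

print(describe_graph(toy))
```

With torchtrail installed, the same helper would be called as describe_graph(torchtrail.get_graph(output_tensor)).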

Tracing a module

import torch
import transformers

import torchtrail

model_name = "google/bert_uncased_L-4_H-256_A-4"
config = transformers.BertConfig.from_pretrained(model_name)
config.num_hidden_layers = 1
model = transformers.BertModel.from_pretrained(model_name, config=config).eval()

with torchtrail.trace():
    input_tensor = torch.randint(0, model.config.vocab_size, (1, 64))
    output = model(input_tensor).last_hidden_state

torchtrail.visualize(output, max_depth=1, file_name="bert_max_depth_1.svg")

torchtrail.visualize(output, max_depth=2, file_name="bert_max_depth_2.svg")

The graph of the full module can be visualized by omitting the max_depth argument:

torchtrail.visualize(output, file_name="bert.svg")
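The visualize calls above differ only in max_depth and file name, so rendering several depths can be wrapped in a loop. render_at_depths is a hypothetical helper, not part of torchtrail. It only passes the max_depth and file_name keyword arguments shown above, and treats a depth of None as "omit max_depth", which renders the full module graph. The visualize function can be injected, so the loop can be exercised without torchtrail installed.

```python
from typing import Callable, Iterable, List, Optional


def render_at_depths(
    output,
    depths: Iterable[Optional[int]],
    prefix: str,
    visualize: Optional[Callable] = None,
) -> List[str]:
    """Render one traced output at several max_depth values (hypothetical helper).

    A depth of None omits max_depth, i.e. renders the full module graph.
    Returns the list of file names passed to visualize.
    """
    if visualize is None:
        import torchtrail  # imported lazily so the helper can be tested without it

        visualize = torchtrail.visualize

    written = []
    for depth in depths:
        if depth is None:
            file_name = f"{prefix}.svg"
            visualize(output, file_name=file_name)
        else:
            file_name = f"{prefix}_max_depth_{depth}.svg"
            visualize(output, max_depth=depth, file_name=file_name)
        written.append(file_name)
    return written
```

For the BERT example, render_at_depths(output, [1, 2, None], prefix="bert") would produce bert_max_depth_1.svg, bert_max_depth_2.svg and bert.svg, matching the calls above.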

The graph can be obtained as a networkx.MultiDiGraph using torchtrail.get_graph:

graph: "networkx.MultiDiGraph" = torchtrail.get_graph(output)
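Unlike a plain DiGraph, a networkx.MultiDiGraph can hold several parallel edges between the same pair of nodes. The sketch below illustrates the difference on a hand-built graph in which one operation reads the same tensor twice, as torch.mul(x, x) would. The node names are hypothetical, and whether torchtrail records such a case as two parallel edges is an assumption here, not something this page states.

```python
import networkx as nx

# Hypothetical graph: "mul" reads tensor "x" twice, so there are two x -> mul edges.
graph = nx.MultiDiGraph()
graph.add_edge("x", "mul", key=0)  # first operand
graph.add_edge("x", "mul", key=1)  # second operand

# A MultiDiGraph keeps both edges; converting to a DiGraph collapses them into one.
parallel = graph.number_of_edges("x", "mul")
collapsed = nx.DiGraph(graph).number_of_edges("x", "mul")

# Iterating with keys=True distinguishes the parallel edges.
for u, v, k in graph.edges(keys=True):
    print(u, "->", v, "key", k)
```

This is why code that walks a traced graph should iterate edges with keys=True (or check number_of_edges(u, v)) rather than assume one edge per node pair.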

Alternatively, visualization of the modules can be turned off completely using show_modules=False:

torchtrail.visualize(output, show_modules=False, file_name="bert_show_modules_False.svg")

The flattened graph can be obtained as a networkx.MultiDiGraph by passing flatten=True to torchtrail.get_graph:

graph: "networkx.MultiDiGraph" = torchtrail.get_graph(output, flatten=True)

Reference

  • torchtrail was inspired by torchview, in which mert-kurttutan did an amazing job of displaying torch graphs. However, one of torchtrail's goals was to produce networkx-compatible graphs, which is why torchtrail was written.
  • The idea of using a persistent MultiDiGraph to trace torch operations was taken from composit.

Download files

Source Distribution

torchtrail-0.0.20.tar.gz (17.1 kB)

File details

Details for the file torchtrail-0.0.20.tar.gz.

File metadata

  • Download URL: torchtrail-0.0.20.tar.gz
  • Upload date:
  • Size: 17.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.7

File hashes

Hashes for torchtrail-0.0.20.tar.gz
  • SHA256: 29ac923cac1bb8eaa750b317e888113a5e1b2953c8963cbde76c9f4f46e4c17c
  • MD5: 6bc3e6372241a0ad96dc5efce1fba1b8
  • BLAKE2b-256: 3cf66f87f3632e165acc75769a81e702adfb4ee5d6ee6b022429db0b5c1bce2c
