Simple tool to generate a tikz figure of a pytorch module by executing it.
Pytorch2Tikz
Generate TikZ figures for neural networks implemented in PyTorch. It uses LaTeX snippets from PlotNeuralNet, but you can now just run your network to plot everything automatically. For examples, see ./examples.
Example
import torch
from torchvision.models import vgg16

from pytorch2tikz import Architecture

print('Load model')
model = vgg16(True)  # True requests pretrained weights

print('Load data')
...  # elided in the original; `data_loader` and `device` must be defined here

print('Init architecture')
arch = Architecture(model)

print('Run model')
with torch.inference_mode():
    for image, _ in data_loader:
        image = image.to(device, non_blocking=True)
        output = model(image)

print('Write result to out.tex')
arch.save('out.tex')
Getting Started
pip install pytorch2tikz
Interface
Architecture
Architecture(module: nn.Module,
block_offset=8,
height_depth_factor=0.8,
width_factor=0.8,
linear_factor=0.8,
image_path='./input_{i}.png',
ignore_layers=['batchnorm', 'flatten'],
colors=COLOR_VALUES)
Arguments
Argument | Description |
---|---|
module | the model to plot |
block_offset | offset to the next block; a block is created when the input dimensions change |
height_depth_factor | scale the change of the next layer (last 2 dimensions); typically used to make the network a bit more compact |
width_factor | scale the change of the next layer (first dimension); typically used to make the network a bit more compact |
linear_factor | used when there is a drastic change in the last dimension (e.g. moving from conv to linear layers) |
image_path | output path for recognized input images. {i} gets replaced by the current layer index |
ignore_layers | define layers that should not be plotted. This can be a list of any substring of the type(class) (e.g. torch.nn.modules.batchnorm.BatchNorm) |
colors | enum of colors. For an example check out ./pytorch2tikz/constants |
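The substring matching described for ignore_layers can be pictured with the following minimal sketch. It is not the package's actual code: the helper names `is_ignored` and `filter_layers` are hypothetical, and it assumes a plain, case-sensitive substring test against the fully qualified class name.

```python
from typing import Iterable, List


def is_ignored(type_string: str, ignore_layers: Iterable[str]) -> bool:
    """Return True if any ignore pattern occurs as a substring of the type string."""
    return any(pattern in type_string for pattern in ignore_layers)


def filter_layers(type_strings: Iterable[str], ignore_layers: Iterable[str]) -> List[str]:
    """Keep only the layer type strings that match no ignore pattern."""
    patterns = list(ignore_layers)
    return [t for t in type_strings if not is_ignored(t, patterns)]


# Fully qualified class names as they appear in PyTorch.
layer_types = [
    "torch.nn.modules.conv.Conv2d",
    "torch.nn.modules.batchnorm.BatchNorm2d",
    "torch.nn.modules.flatten.Flatten",
    "torch.nn.modules.linear.Linear",
]

# The default patterns from the Architecture signature.
kept = filter_layers(layer_types, ["batchnorm", "flatten"])
print(kept)
```

Under this assumption, the default patterns match the module path (batchnorm, flatten) rather than the class name, so every BatchNorm variant is skipped with a single entry.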
Methods
def get_block(self, name: str) -> Block:
    ...
Get a specific block to alter its properties.
def get_tex(self) -> str:
    ...
Generate the tex code.
def save(self, file_path: str):
    ...
Generate and save the tex code to the given path.
Block
Block(name,
fill: COLOR = COLOR.LINEAR,
bandfill: COLOR = None,
pictype = PICTYPE.BOX,
opacity = 0.7,
size = (10,40,40),
default_size = DEFAULT_VALUE,
dim = 3,
scale_factor = np.zeros(3),
offset: Tuple[int] = (0,0,0),
to: Union[Tuple[int], Block] = (0,0,0),
caption = " ",
xlabel = True,
ylabel = False,
zlabel = True)
Arguments
Argument | Description |
---|---|
name | arbitrary name of the block. Should be unique; typically the layer's id is used |
fill | filling color as hex string, e.g. #000000 |
bandfill | fill color of the band at the right end of a box. pictype must be PICTYPE.RIGHTBANDEDBOX for it to be displayed |
pictype | one of [PICTYPE.BOX, PICTYPE.RIGHTBANDEDBOX] |
opacity | opacity of the filling |
size | size of the box |
default_size | size used for dimensions which are "flat": e.g. for 1D inputs the size (default, default, size) is used |
dim | dimensionality of the block, e.g. 1 for Linear layers, 3 for conv2d layers (channels x dim1 x dim2) |
scale_factor | scale factors to alter the size when outputting tex to make the figure more compact |
offset | offset to the reference position/block given in to |
to | position tuple or block used for relative positioning |
caption | caption of the block. Use an empty string if no caption is wanted |
xlabel | display label for 1st dimension |
ylabel | display label for 2nd dimension |
zlabel | display label for 3rd dimension |
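To illustrate the default_size rule, here is a small sketch of how a lower-dimensional size could be expanded to three entries. The function `pad_size` is hypothetical and not part of the package. The 1D case follows the description above; padding 2D sizes on the left is an assumption, as is the value 1 used for the default.

```python
from typing import Sequence, Tuple


def pad_size(size: Sequence[int], default_size: int) -> Tuple[int, ...]:
    """Left-pad a 1D, 2D, or 3D size with default_size so it has three entries."""
    if not 1 <= len(size) <= 3:
        raise ValueError("size must have between 1 and 3 entries")
    missing = 3 - len(size)
    return tuple([default_size] * missing + list(size))


# 1D input (e.g. a Linear layer): flat dimensions take the default size.
print(pad_size((4096,), default_size=1))        # (1, 1, 4096)
# 3D input (e.g. a conv2d layer) is left unchanged.
print(pad_size((64, 224, 224), default_size=1))  # (64, 224, 224)
```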
Contributions
Thank you for sharing your improvements to this package!
Layer support
Please don't hesitate to add blocks for unsupported layers under pytorch2tikz/block/D<x>.py, with x being the dimensionality of your layer. If your layer exists for multiple dimensions, choose Dn.py:
- add your block definition under pytorch2tikz/block/D<x>.py
- add a mapping of the type string to pytorch2tikz/mapping.py
- add your color to pytorch2tikz/constants.py (see Colors)
Custom Connection
For custom connections that can be added when postprocessing an architecture, such as residual connections, add your desired connection in pytorch2tikz/block/connections.py. See the examples there for guidance. For existing connections, each block has a number of defined positions. Each position can be combined with (padding-)(near|far)(north|south)(east|west).
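Read literally, the position pattern expands to a small set of names. The sketch below enumerates them, assuming the padding- prefix is optional and each of the other three groups is required. This is an interpretation of the pattern only; the names actually accepted should be checked against pytorch2tikz/block/connections.py.

```python
from itertools import product
from typing import List


def position_names() -> List[str]:
    """Enumerate every name matching (padding-)(near|far)(north|south)(east|west)."""
    names = []
    for prefix, depth, vertical, horizontal in product(
        ("", "padding-"), ("near", "far"), ("north", "south"), ("east", "west")
    ):
        names.append(f"{prefix}{depth}{vertical}{horizontal}")
    return names


names = position_names()
print(len(names))  # 16
print(names[0])    # nearnortheast
```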
Colors
Colors are defined in pytorch2tikz/constants.py. For each color there must exist an entry in the enum COLOR and the defined value in the Dict COLOR_VALUES. Make sure your color is easily distinguishable from other layers.
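The pairing of enum entry and dictionary value can be checked mechanically. The following is a standalone sketch, not the contents of constants.py: only COLOR.LINEAR appears in the signatures above, while the other members, all hex values, the helper `missing_colors`, and the choice to key COLOR_VALUES by enum member are assumptions made for illustration.

```python
from enum import Enum
from typing import Dict, List


class COLOR(Enum):
    # Hypothetical members; the real enum lives in pytorch2tikz/constants.py.
    CONV = "conv"
    LINEAR = "linear"
    MY_LAYER = "my_layer"  # the newly added color


# Assumed layout: one hex string per enum member (values are made up).
COLOR_VALUES: Dict[COLOR, str] = {
    COLOR.CONV: "#ffc107",
    COLOR.LINEAR: "#3f51b5",
    COLOR.MY_LAYER: "#009688",
}


def missing_colors(values: Dict[COLOR, str]) -> List[COLOR]:
    """Return the enum members that have no value defined, in definition order."""
    return [color for color in COLOR if color not in values]


# An empty list means every color in the enum has a defined value.
print(missing_colors(COLOR_VALUES))
```

A check like this could run in a unit test so that a new enum member without a matching value is caught before a figure is generated.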