Draw flowcharts and more with matplotlib
Project description
Requirements
- Python >= 3.6 (There are f strings)
- matpltolib
- numpy
Motivation
There are other libs that you can use to draw flowcharts, why writing your own?
- PyGraph and NetworkX are great libs for graphs operations, but not so great for flowcharts or computation graphs;
- Graphviz (and its python declinations) would be ok for these tasks but performs poorly when it comes to customization and control;
- You could use dot2tex but it integrates very poorly with jupyter notebooks and even if you make it work on a local machine it loses compatibility with online services (github markdown rendered for example).
- Even using the graphviz python library you are not accessing the full versatility of the underlying tool. For example you can't have complex typographic styles in the labels and also adding super-/sub-script poses a problem. With this lib you can use LateX.
Implementation
Matplotlib is an incredibly versatile instrument for graphics in python and I bent it to my will abusing the Annotation tool far beyond its original purpose of annotating plots and repurposing it to drawing tool.
Everything in this library is drawn with the ax.annotate function, especially exploiting the fact that xycoords
and textcoords
can be both set to bbox
instances, which alleviates me from the effort of finding the coordinates from- and to- which the arrows should point.
The rest of the library is just syntactic sugar and quality-of-life functions and parameters.
Usage
The library focuses on usability: I tried to write the code so that you need to write as little code as possible while maintaining total customization.
Creating a Flow
A new plot always starts by instantiating the Flow
class:
f = Flow(ax=None)
f.ax.axis('on')
(0.0, 1.0, 0.0, 1.0)
Flow()
accepts an Axes instance or instantiates its own. By default the Axis
instance has hidden splines and ticks, and its ranges [0,1] on both axes.
Creating Nodes
You proceed by adding nodes with the Flow().node()
method: the method takes a node_id
and a label
argument that concur to control the display item inside the node:
- If
label != None
, then it is displayed: labels can contain LaTeX (FY Graphviz!) - If
label == None
, thennode_id
is displayed - If also
node_id == None
, a number is displayed. The number increases each time a node is drawn.
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
ax1, ax2, ax3 = axes
f = Flow(ax=ax1)
f.node(node_id=None, label=None)
f = Flow(ax=ax2)
f.node(node_id='a', label=None)
f = Flow(ax=ax3)
f.node(node_id='a', label='$x^2 + y^2 = 0$');
Traveling and connecting nodes
By creating other nodes they are automatically drawn (they travel) to the right of the last node drawn and connected to it by an arrow.
- The default direction can be set with the argument
direction
of theFlow
instance - Direction of single nodes can be specified in the
node()
function with thetravel
argument that indicates where the nodetravel
s with respect to thestartpoint
- Nodes can travel without being connected by an arrow: this is achieved by setting
.node(connect=False)
- Travel directions are like in a compass:
n
,s
,e
,w
,ne
,se
,sw
,nw
.
The .node()
methods returns a Node
instance. Each time a Node
is drawn, the ylim
and xlim
of the Axes
update to frame and center all the nodes.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
ax1, ax2, ax3 = axes
f = Flow(ax=ax1)
f.node(label='first node')
f.node(label='second node')
f.node(label='third node')
f.node(connect=False)
f = Flow(ax=ax2, direction='ns')
f.node(label='first node')
f.node(label='second node')
f.node(label='third node')
f.node(connect=False)
f = Flow(ax=ax3)
f.node(label='first node')
f.node(label='second node', travel='sw')
f.node(label='third node', travel='se')
f.node(connect=False, travel='ne');
Setting a travel startpoint
Nodes can travel from different startpoints
than the last node. This is achieved by specifying the node(startpoint=Node)
argument. A startpoint
can be:
- a
str
: theNode.id
of an instantiatedNode
- a
Node
instance, returned by theFlow.node()
method
f = Flow()
a, _ = f.node('a')
f.node('b', travel='ne')
f.node('c', travel='se', startpoint=a)
f.node('d', startpoint='b');
Drawing an edge between existing nodes
Edges can be drawn between nodes that have already been defined, using the Flow.edge()
method, which returns an Edge
instance
f = Flow()
a, _ = f.node('a')
b, _ = f.node('b', travel='ne')
f.node('c', travel='se', startpoint=a)
f.edge(b, 'c', label='I\'m a label', rotation=90, labelpos=(-1, 0.5))
f.edge('c', 'a', tailport='n', headport='e', arrowprops=dict(connectionstyle='arc3,rad=-0.3'));
Use-cases
Computational graph
Let's say you want to draw a fairly simple computational graph. At the moment, the best option to draw such a thing would have been to use the graphviz python package.
Graphviz
With the python package you ca exploit subragph to have multiple inputs for a single box but you have little more control than that. For example I couldn't find a way to align the square boxes in a single line and the best you can do is playing with the layout
and hope that one will be good enough for you. Furthermore the superscript are ugly and you have little more mathematical notation that that.
dot = Digraph(node_attr={'fontsize':'9'}, edge_attr={'arrowsize': '0.5', 'fontsize':'9'}, engine='dot')
dot.attr(rankdir='LR', packmode='graph')
with dot.subgraph() as sg:
sg.attr(rank='same')
sg.node('x', shape='plaintext', margin='0')
sg.node('w', label='<W<SUP>[1]</SUP>>', shape='plaintext', margin='0')
sg.node('b', label='<b<SUP>[1]</SUP>>', shape='plaintext', margin='0')
dot.node('z', shape='rect', label='<Z<SUP>[1]</SUP> = W<SUP>[1]</SUP> x + b<SUP>[1]</SUP>>', margin='0')
dot.node('h', shape='rect', label='z[2] = W2 a[1] + b[2]', margin='0')
dot.node('y', shape='rect', label='a[2] = g(z[2])', margin='0')
dot.node('l', shape='rect', label='L(a[2], y)', margin='0')
with dot.subgraph() as sg:
sg.attr(rank='same')
sg.node('j', label='<W<SUP>[2]</SUP>>', shape='plaintext', margin='0')
sg.node('k', label='<b<SUP>[2]</SUP>>', shape='plaintext', margin='0')
sg.node('a', shape='rect', label='a[1] = g(z[1])', margin='0')
dot.edges(['xz', 'wz', 'bz', 'za', 'ah', 'hy', 'yl', 'jh', 'kh'])
dot.edge('l', 'y', headport='s', tailport='s', color='red', label='da[2]', fontcolor='red')
dot.edge('y', 'h', headport='s', tailport='s', color='red', label='dz[2]', fontcolor='red')
dot.edge('h', 'k', color='red', label='db[2]', fontcolor='red')
dot.edge('h', 'j', color='red', label='dW[2]', fontcolor='red')
dot
mpl_flow
With this library you have almost complete control over positioning of boxes and labels and you can access the full power of LateX when it comes to mathematical notation.
fig, ax = plt.subplots(figsize=(10, 2))
f = Flow(ax=ax)
f.node('x', label='$x$')
f.node('W1', label='$W^{[1]}$', travel='s', connect=False)
f.node('b1', label='$b^{[1]}$', travel='s', connect=False)
f.node('Z1', label='$Z^{[1]}=W^{[1]}x+b^{[1]}$', startpoint='x')
f.edge('W1', 'Z1')
f.edge('b1', 'Z1')
f.node('a1', label='$a^{[1]}=g(z^{[1]})$')
f.node('W2', label='$W^{[2]}$', travel='s', connect=False)
f.node('b2', label='$b^{[2]}$', travel='s', connect=False)
f.node('Z2', label='$Z^{[2]}=W^{[2]}x+b^{[2]}$', startpoint='a1')
f.edge('W2', 'Z2')
f.edge('b2', 'Z2')
f.node('a2', label='$a^{[2]}=g(z^{[2]})$', startpoint='Z2')
f.node('L', label='$L(a^{[2]}, y)$')
f.edge('L', 'a2', tailport='s', headport='s', arrowprops=dict(connectionstyle='arc3,rad=0.5', color='r'), c='r', label='$da^{[2]}$')
f.edge('a2', 'Z2', tailport='s', headport='s', arrowprops=dict(connectionstyle='arc3,rad=0.5', color='r'), c='r', label='$dZ^{[2]}$')
f.edge('Z2', 'W2', tailport='s', headport='e', arrowprops=dict(connectionstyle='arc3,rad=0.5', color='r'), c='r', label='$dW^{[2]}$')
f.edge('Z2', 'b2', tailport='s', headport='e', arrowprops=dict(connectionstyle='arc3,rad=0.5', color='r'), c='r', label='$db^{[2]}$');
Increasingly complex examples that would be impossible with other libraries and a nightmare with vanilla matplotlib
RNN backpropagation
f = Flow(bbox=dict(boxstyle='square'))
for i in range(6):
lbl = i if i < 5 else 'T_x'
if i != 4:
f.node(f'a{i}', label=f'$a^{{\\langle {lbl} \\rangle}}$', fontsize=13, startpoint=f'a{i-1}')
else:
f.node(f'a{i}', label='$\\cdots$', startpoint=f'a{i-1}', fontsize=13, bbox=dict(ec='none'))
if i != 0:
f.edge(f'a{i}', f'a{i-1}', arrowprops=dict(connectionstyle='arc3,rad=0.4', ec='r'), headport='se', tailport='sw')
if i >0 and i != 4:
f.node(f'x{i}', label=f'$x^{{\\langle {lbl} \\rangle}}$', startpoint=f'a{i}', travel='s', fontsize=13,
edge_kwargs=dict(arrowprops=dict(arrowstyle='->')), bbox=dict(ec='none'))
if i == 5:
lbl = 'T_y'
f.node(f'y{i}', label=f'$y^{{\\langle {lbl} \\rangle}}$', startpoint=f'a{i}', travel='n', fontsize=13)
f.node(f'l{i}', label=f'$\\mathcal{{L}}^{{\\langle {lbl} \\rangle}}$', travel='n', fontsize=13)
f.edge(f'l{i}', f'y{i}', arrowprops=dict(connectionstyle='arc3,rad=0.4', ec='r', shrinkA=4, shrinkB=6), headport='n', tailport='s')
f.edge(f'y{i}', f'a{i}', arrowprops=dict(connectionstyle='arc3,rad=0.4', ec='r', shrinkA=4, shrinkB=6), headport='n', tailport='s')
f.node('l', label='$\\mathcal{L}$', startpoint=f'l5', travel='ne', distance=.5, connect=False)
for i in range(1, 6):
if i != 4:
f.edge(f'l{i}', 'l', tailport='n', headport='w', arrowprops=dict(connectionstyle='angle,angleA=0,angleB=90,rad=2'))
f.edge('l', f'l{i}', tailport='sw', headport='ne', arrowprops=dict(connectionstyle='arc3,rad=-0.05', ec='r'))
ResNet architecture
fig, ax = plt.subplots(2, 1, figsize=(14, 6))
ax1, ax2 = ax
dims = [1, 6, 8, 12, 6]
f = [7, 3, 3, 3, 3]
c = [1, 0, 2, 3, 4]
ch = [64, 64, 128, 256, 512]
sc = range(0, 34, 2)
facecolors = sum([[f'C{color}']*dim for color, dim in zip(c, dims)], [])
filters = sum([['${0} \\times {0}$ $\\mathrm{{conv}}$'.format(filt)]*dim for filt, dim in zip(f, dims)], [])
channels = sum([['${}$'.format(c)]*dim for c, dim in zip(ch, dims)], [])
pools = ['$/2$']+['']*6+['$/2$']+['']*7+['$/2$']+['']*11+['$/2$']+['']*5
ax1.set_title('Plain')
ax2.set_title('ResNet')
labels = list(map(lambda l: ', '.join(l), zip(filters, channels, pools)))
f = Flow(ax=ax1)
for i, (c, l) in enumerate(zip(facecolors, labels)):
f.node(i, label='{:^50}'.format(l.strip(', ')), rotation=90, bbox=dict(boxstyle='square', pad=0.1, fc=c, alpha=.2))
f.node(sum(dims), label='{:^44}'.format('$\\mathrm{FC}$ $1000$'), rotation=90, bbox=dict(boxstyle='square', pad=0.1, fc=c, alpha=.2));
edges = []
f = Flow(ax=ax2)
for i, (c, l) in enumerate(zip(facecolors, labels)):
_, e = f.node(i, label='{:^50}'.format(l.strip(', ')), rotation=90, bbox=dict(boxstyle='square', pad=0.1, fc=c, alpha=.2))
if e is not None:
edges.append(e)
_, e = f.node(sum(dims), label='{:^44}'.format('$\\mathrm{FC}$ $1000$'), rotation=90, bbox=dict(boxstyle='square', pad=0.1, fc=c, alpha=.2))
edges.append(e)
for ii, (i, j) in enumerate(zip(edges[::2], edges[2::2])):
ls = '--'if ii in [3, 7, 13] else '-'
f.edge(i, j, arrowprops=dict(connectionstyle='arc,angleA=90,angleB=90,armA=70,armB=70,rad=25', ls=ls))
U-net architecture
f = Flow(figsize=(12, 6))
vpad=4
hpad=0
nodes = []
fc='C0'
for i in range(27):
c='k'
lw=1
dst=1
drc='e'
if i != 1 and i % 3 == 0:
c='r'
lw=2
drc = 's'
level = -1
if i >= 15:
fc='cyan'
c='lime'
drc = 'n'
level = 1
vpad += level
dst=1/level
hpad -= level
nodes.append(f.node(label=' '*hpad+'\n'*3*vpad,
bbox=dict(boxstyle='square', fc=fc, ec='none'),
edge_kwargs=dict(arrowprops=dict(ec=c, lw=lw)),
travel=drc, distance=0.2, fontsize=5))
if i >= 15:
if i % 3 == 0:
f.edge(nodes[i][0], nodes[((9-(i//3-1))-1)*3-1][0], arrowprops=dict(arrowstyle='->', ec='lightgray', lw=3))
f.node(label='', distance=.12, bbox=dict(boxstyle=None), edge_kwargs=dict(arrowprops=dict(ec='magenta')))
vpad = 0
hpad = 5
for i in range(15, 27, 3):
f.node(label=' '*hpad+'\n'*3*vpad, startpoint=i+1, travel='w', distance=0.008*hpad, connect=False,
bbox=dict(boxstyle='square', ec='none', fc='C0'), fontsize=5)
hpad -=1
vpad +=1
plt.scatter([],[],marker=r'$\rightarrow$', label='Conv. ReLU', c='k', s=100)
plt.scatter([],[],marker=r'$\rightarrow$', label='Skip Connection', c='lightgray', s=100)
plt.scatter([],[],marker=r'$\rightarrow$', label='Conv (1x1)', c='magenta', s=100)
plt.scatter([],[],marker=r'$\downarrow$', label='Max pool', c='r', s=100)
plt.scatter([],[],marker=r'$\uparrow$', label='Max pool', c='lime', s=100)
plt.legend(loc='lower right');
Transformer architecture
def multihead(f, startpoint, travel, nid, distance=1, edge_kwargs=dict()):
a =f.node(f'{nid}.1', label='$\\mathrm{Multi-Head}$\n$\\mathrm{Attention}$',
startpoint=startpoint, travel=travel, distance=distance,
bbox=dict(ec='C0'), edge_kwargs=edge_kwargs)
b =f.node(f'{nid}.2', label='$\\mathrm{Multi-Head}$\n$\\mathrm{Attention}$',
travel='ne', distance=.03, bbox=dict(ec='C3'), connect=False, zorder=-10)
c =f.node(f'{nid}.3', label='$\\mathrm{Multi-Head}$\n$\\mathrm{Attention}$',
travel='ne', distance=.03, bbox=dict(ec='C2'), connect=False, zorder=-20)
return a,b,c
f = Flow(figsize=(6, 6))
f.node('x', label=r'$x^{\langle 1 \rangle}, \dots , x^{\langle T_x \rangle}$',
bbox=dict(ec='none'), fontsize=13)
f.node('peE', label='$+$', bbox=dict(boxstyle='circle'), travel='n', distance=.7)
(_, me),_, _ = multihead(f, 'peE', 'n', 'ME', distance=1,
edge_kwargs=dict(
label='$Q, K, V$', labelpos=(5, 0.5)))
f.node('an1', label='$\\mathrm{Add \\ & \\ Norm}$', travel='n',
bbox=dict(ec='orange'), startpoint='ME.1', distance=.5,
edge_kwargs=dict(arrowprops=dict(arrowstyle='-')))
_, ffe = f.node('ffnnE', label='$\\mathrm{Feed \\ Forward}$\n$\\mathrm{Neural\\ Network}$',
travel='n')
f.node('an2', label='$\\mathrm{Add \\ & \\ Norm}$', travel='n',
bbox=dict(ec='orange'), distance=.5,
edge_kwargs=dict(arrowprops=dict(arrowstyle='-')))
f.node('compound_edge1', label='', distance=.5, travel='n',
edge_kwargs=dict(arrowprops=dict(arrowstyle='-', shrinkA=0),
headport=(0.5, 0.5)))
f.node('compound_edge2', label='', distance=2, travel='e',
edge_kwargs=dict(arrowprops=dict(arrowstyle='-', shrinkA=0, shrinkB=0),
headport=(0.5, 0.5), tailport=(0.5, 0.5)))
(_, md1), _, _ = multihead(f, 'compound_edge2', 'se', 'MD2', distance=(1, 1.5),
edge_kwargs=dict(
label='$K, V$', labelpos=(.75, 0.05),
tailport=(0.5, 0.5), headport=(0.334, 0),
arrowprops=dict(
shrinkB=0,
connectionstyle='bar,fraction=-0.2,angle=0')))
_, an3 = f.node('an3', label='$\\mathrm{Add \\ & \\ Norm}$', travel='s',
bbox=dict(ec='orange'), distance=1, startpoint='MD2.1',
edge_kwargs=dict(
arrowprops=dict(arrowstyle='->'),
label='$Q$', labelpos=(2, 0.5)))
(_, md1), _, _ = multihead(f, 'an3', 's', 'MD1', distance=.5,
edge_kwargs=dict(
arrowprops=dict(arrowstyle='-')))
_, an4 = f.node('an4', label='$\\mathrm{Add \\ & \\ Norm}$', travel='n',
bbox=dict(ec='orange'), distance=.5, startpoint='MD2.1',
edge_kwargs=dict(arrowprops=dict(arrowstyle='-')))
_, ffd= f.node('ffnnD', label='$\\mathrm{Feed \\ Forward}$\n$\\mathrm{Neural\\ Network}$',
travel='n', startpoint='an4')
f.node('an5', label='$\\mathrm{Add \\ & \\ Norm}$', travel='n',
bbox=dict(ec='orange'), distance=.5, startpoint='ffnnD',
edge_kwargs=dict(arrowprops=dict(arrowstyle='-')))
f.node('lin', label='$\\mathrm{Linear}$'.center(30), travel='n', distance=.8)
f.node('soft', label='$\\mathrm{Softmax}$'.center(28), travel='n', distance=.4,
edge_kwargs=dict(arrowprops=dict(arrowstyle='-')))
f.node('y', label='$\\hat{y}_n$', travel='n', distance=.6, bbox=dict(ec='none'))
f.node('compound_edge1', label='', distance=1, travel='e',
edge_kwargs=dict(arrowprops=dict(arrowstyle='-', shrinkA=0),
headport=(0.5, 0.5)))
_, pd = f.node('peD', label='$+$', bbox=dict(boxstyle='circle'), travel='s', startpoint='MD1.1',
distance=1, edge_kwargs=dict(arrowprops=dict(arrowstyle='->')))
f.node('P', label='$P$'.center(10), distance=1.5, startpoint='peE',
edge_kwargs=dict(arrowprops=dict(arrowstyle='->')))
f.edge('P', 'peD', headport='w', tailport='e',
arrowprops=dict(connectionstyle='bar,fraction=-0.7,angle=90'))
f.edge('compound_edge1', 'peD',
arrowprops=dict(
arrowstyle='<-', connectionstyle='bar,fraction=0.03,angle=0',
shrinkA=0, shrinkB=0),
headport='s', tailport=(0.5, 0.5))
f.node('encoder', label=' \n\n\n\n', startpoint='x', travel='n', distance=2.5,
bbox=dict(boxstyle='square', ec='none', fc='gainsboro', pad=5),
zorder=-30, connect=False,
xlabel='$\\mathrm{Encoder}$', xlabel_xy=(0.15, 1.03))
f.node('encoder', label=' \n\n\n\n\n\n\n\n\n\n', startpoint='MD2.1', distance=.05, travel='n',
bbox=dict(boxstyle='square', ec='none', fc='gainsboro', pad=5),
zorder=-30, connect=False,
xlabel='$\\mathrm{Decoder}$', xlabel_xy=(0.15, 1.03))
f.edge(me, 'an1', tailport='w', headport='w', arrowprops=dict(connectionstyle='bar,fraction=-0.2,angle=90'))
f.edge(ffe, 'an2', tailport='w', headport='w', arrowprops=dict(connectionstyle='bar,fraction=-0.2,angle=90'))
f.edge(pd, 'an3', tailport='e', headport='e', arrowprops=dict(connectionstyle='bar,fraction=-0.2,angle=90'))
f.edge(an3, 'an4', tailport=(1, 0.2), headport='e',
arrowprops=dict(connectionstyle='bar,fraction=-0.2,angle=90'))
f.edge(ffd, 'an5', tailport=(1, 0.5), headport='e',
arrowprops=dict(connectionstyle='bar,fraction=-0.2,angle=90'));
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.