
Draw flowcharts and more with matplotlib

Project description

Requirements

  • Python >= 3.6 (the code uses f-strings)
  • matplotlib
  • numpy

Motivation

There are other libraries you can use to draw flowcharts, so why write your own?

  • PyGraph and NetworkX are great libraries for graph operations, but not so great for drawing flowcharts or computation graphs;
  • Graphviz (and its Python bindings) would be fine for these tasks, but it offers limited customization and control;
  • You could use dot2tex, but it integrates very poorly with Jupyter notebooks, and even if you make it work on a local machine, the output is not compatible with online services (the GitHub Markdown renderer, for example);
  • Even with the graphviz Python library you do not get the full versatility of the underlying tool. For example, you can't use complex typographic styles in labels, and adding superscripts or subscripts is a problem. With this library you can use LaTeX.

Implementation

Matplotlib is an incredibly versatile tool for graphics in Python, and I bent it to my will by using the annotation machinery far beyond its original purpose of annotating plots, repurposing it as a drawing tool.

Everything in this library is drawn with the ax.annotate function, exploiting in particular the fact that xycoords and textcoords can both be set to bbox instances, which relieves me of the effort of computing the coordinates that the arrows should point from and to.
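
To make the idea concrete, here is a minimal sketch in plain matplotlib (it does not use this library, and the variable names are only illustrative). It assumes that passing an annotation artist as xycoords / textcoords is an acceptable stand-in for the bbox instances mentioned above: coordinates are then read as fractions of that artist's bounding box, so no arrow endpoint has to be computed by hand.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(4, 2))
ax.set_axis_off()

box = dict(boxstyle="round", fc="w")

# Two "nodes": annotations without arrowprops are just boxed text.
node_a = ax.annotate("a", xy=(0.2, 0.5), ha="center", va="center", bbox=box)
node_b = ax.annotate("b", xy=(0.8, 0.5), ha="center", va="center", bbox=box)

# One "edge": an empty annotation whose arrow runs from xytext to xy.
# With artists as coordinate systems, (1, 0.5) means "middle of the right
# side of node_a" and (0, 0.5) means "middle of the left side of node_b".
edge = ax.annotate(
    "",
    xy=(0, 0.5), xycoords=node_b,
    xytext=(1, 0.5), textcoords=node_a,
    arrowprops=dict(arrowstyle="->", shrinkA=6, shrinkB=6),
)

fig.canvas.draw()  # positions are resolved at draw time
```

Because the arrow is anchored to the artists rather than to fixed numbers, moving either box moves the arrow with it on the next draw.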

The rest of the library is just syntactic sugar and quality-of-life functions and parameters.

Usage

The library focuses on usability: I designed the API so that you write as little code as possible while keeping full control over customization.

Creating a Flow

A new plot always starts by instantiating the Flow class:

f = Flow(ax=None)
f.ax.axis('on')
(0.0, 1.0, 0.0, 1.0)


Flow() accepts an Axes instance or creates its own. By default the Axes instance has hidden spines and ticks, and its range is [0, 1] on both axes.
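
As a rough illustration of that default state, the following plain-matplotlib sketch prepares an equivalent blank canvas. The helper name blank_canvas is hypothetical and this is not how Flow is implemented; it only reproduces the behaviour described above.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt


def blank_canvas(ax=None):
    """Return an Axes spanning [0, 1] on both axes with spines and ticks hidden."""
    if ax is None:
        _, ax = plt.subplots()
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.axis("off")  # hides spines, ticks and tick labels
    return ax


ax = blank_canvas()
# Turning the frame back on also returns the limits (xmin, xmax, ymin, ymax).
limits = ax.axis("on")
```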

Creating Nodes

You proceed by adding nodes with the Flow().node() method. The method takes a node_id and a label argument, which together control what is displayed inside the node:

  • If label != None, then it is displayed: labels can contain LaTeX (FY Graphviz!)
  • If label == None, then node_id is displayed
  • If node_id is also None, a number is displayed. The number increases each time a node is drawn.

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
ax1, ax2, ax3 = axes
f = Flow(ax=ax1)
f.node(node_id=None, label=None)

f = Flow(ax=ax2)
f.node(node_id='a', label=None)

f = Flow(ax=ax3)
f.node(node_id='a', label='$x^2 + y^2 = 0$');
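
The three labelling rules listed before this example can be sketched in plain Python. This is not the library's code: LabelResolver is a hypothetical helper, and starting the counter at 0 is an assumption.

```python
class LabelResolver:
    """Hypothetical helper that picks the text shown inside a node."""

    def __init__(self):
        self.drawn = 0  # number of nodes drawn so far (assumed to start at 0)

    def resolve(self, node_id=None, label=None):
        if label is not None:
            text = str(label)      # an explicit label always wins
        elif node_id is not None:
            text = str(node_id)    # otherwise fall back to the id
        else:
            text = str(self.drawn)  # otherwise show the running number
        self.drawn += 1  # the counter advances for every node, labelled or not
        return text


resolver = LabelResolver()
shown = [
    resolver.resolve(),
    resolver.resolve(node_id='a'),
    resolver.resolve(node_id='a', label='$x^2 + y^2 = 0$'),
    resolver.resolve(),
]
```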


Traveling and connecting nodes

Further nodes are automatically drawn (they travel) to the right of the last node drawn and are connected to it by an arrow.

  • The default direction can be set with the direction argument of the Flow instance
  • The direction of a single node can be set with the travel argument of node(), which indicates where the node travels relative to its startpoint
  • Nodes can travel without being connected by an arrow: this is achieved by setting .node(connect=False)
  • Travel directions are like in a compass: n, s, e, w, ne, se, sw, nw.
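
One way to picture the compass directions above is as offsets from the startpoint. The sketch below is an assumption about the geometry, not the library's implementation: travel_offset is a hypothetical helper, and it gives diagonal moves the same length as straight ones, which the library may or may not do.

```python
import math

# Compass direction -> angle in degrees, counter-clockwise from east.
COMPASS = {
    'e': 0, 'ne': 45, 'n': 90, 'nw': 135,
    'w': 180, 'sw': 225, 's': 270, 'se': 315,
}


def travel_offset(direction, distance=1.0):
    """Return the (dx, dy) displacement for a compass direction."""
    angle = math.radians(COMPASS[direction])
    # Rounding removes floating-point noise such as cos(90 degrees) = 6e-17.
    dx = round(distance * math.cos(angle), 10)
    dy = round(distance * math.sin(angle), 10)
    return dx, dy


east = travel_offset('e')
north = travel_offset('n')
south_west = travel_offset('sw', distance=2)
```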

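The .node() method returns the new Node instance together with the edge that connects it, which is why the examples below unpack two values (in the ResNet example the edge can be None). Each time a Node is drawn, the xlim and ylim of the Axes are updated to frame and center all the nodes.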

fig, axes = plt.subplots(1, 3, figsize=(15, 4))
ax1, ax2, ax3 = axes
f = Flow(ax=ax1)
f.node(label='first node')
f.node(label='second node')
f.node(label='third node')
f.node(connect=False)

f = Flow(ax=ax2, direction='ns')
f.node(label='first node')
f.node(label='second node')
f.node(label='third node')
f.node(connect=False)

f = Flow(ax=ax3)
f.node(label='first node')
f.node(label='second node', travel='sw')
f.node(label='third node', travel='se')
f.node(connect=False, travel='ne');
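
The automatic re-framing of the Axes mentioned before this example can be approximated as follows. This is a sketch under stated assumptions, not the library's code: frame_limits is a hypothetical helper, boxes are given as (x0, y0, x1, y1) in data coordinates, and the 10% margin is arbitrary.

```python
def frame_limits(boxes, margin=0.1):
    """Return ((xmin, xmax), (ymin, ymax)) enclosing every box plus a margin.

    The same padding is added on both sides, so the nodes stay centered.
    """
    x0 = min(b[0] for b in boxes)
    y0 = min(b[1] for b in boxes)
    x1 = max(b[2] for b in boxes)
    y1 = max(b[3] for b in boxes)
    pad_x = (x1 - x0) * margin
    pad_y = (y1 - y0) * margin
    return (x0 - pad_x, x1 + pad_x), (y0 - pad_y, y1 + pad_y)


# Two unit-sized nodes placed side by side.
boxes = [(0.0, 0.0, 1.0, 1.0), (2.0, 0.0, 3.0, 1.0)]
xlim, ylim = frame_limits(boxes)
```

The resulting tuples could then be handed to ax.set_xlim and ax.set_ylim after each node is drawn.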


Setting a travel startpoint

Nodes can travel from a startpoint other than the last node. This is achieved by specifying the node(startpoint=Node) argument. A startpoint can be:

  • a str: the Node.id of an instantiated Node
  • a Node instance, returned by the Flow.node() method

f = Flow()
a, _ = f.node('a')
f.node('b', travel='ne')
f.node('c', travel='se', startpoint=a)
f.node('d', startpoint='b');
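
Accepting either an id or a Node instance amounts to a small lookup step. The sketch below shows one plausible way to do it; SketchNode, resolve_startpoint and the registry dictionary are hypothetical stand-ins, and only the id attribute is taken from the description above.

```python
from dataclasses import dataclass


@dataclass
class SketchNode:
    """Stand-in for the library's Node; only the id attribute is modelled."""
    id: object


def resolve_startpoint(startpoint, registry, last_node):
    """Turn a startpoint argument into a node object.

    None      -> the last node drawn (the default behaviour)
    node      -> returned unchanged
    otherwise -> treated as an id and looked up in the registry
    """
    if startpoint is None:
        return last_node
    if isinstance(startpoint, SketchNode):
        return startpoint
    return registry[startpoint]


a, b = SketchNode('a'), SketchNode('b')
registry = {'a': a, 'b': b}

from_default = resolve_startpoint(None, registry, last_node=b)
from_instance = resolve_startpoint(a, registry, last_node=b)
from_id = resolve_startpoint('b', registry, last_node=a)
```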


Drawing an edge between existing nodes

Edges can be drawn between nodes that have already been defined, using the Flow.edge() method, which returns an Edge instance. Nodes can be passed either as Node instances or by id.

f = Flow()
a, _ = f.node('a')
b, _ = f.node('b', travel='ne')
f.node('c', travel='se', startpoint=a)
f.edge(b, 'c', label='I\'m a label', rotation=90, labelpos=(-1, 0.5))
f.edge('c', 'a', tailport='n', headport='e', arrowprops=dict(connectionstyle='arc3,rad=-0.3'));
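
The tailport and headport arguments in this document are given either as compass strings or as (x, y) tuples such as (0.5, 0.5). A plausible reading, stated here as an assumption rather than as the library's behaviour, is that both forms describe a point in fractional coordinates of the node's bounding box. PORTS and port_fraction below are hypothetical names.

```python
# Compass port -> fractional position on a bounding box,
# where (0, 0) is the lower-left corner and (1, 1) the upper-right corner.
PORTS = {
    'n': (0.5, 1.0), 's': (0.5, 0.0),
    'e': (1.0, 0.5), 'w': (0.0, 0.5),
    'ne': (1.0, 1.0), 'nw': (0.0, 1.0),
    'se': (1.0, 0.0), 'sw': (0.0, 0.0),
}


def port_fraction(port, default=(0.5, 0.5)):
    """Normalise a port given as None, a compass string or an (x, y) tuple."""
    if port is None:
        return default  # assumed default: the centre of the box
    if isinstance(port, str):
        return PORTS[port]
    x, y = port
    return float(x), float(y)


tail = port_fraction('n')
head = port_fraction('e')
custom = port_fraction((1, 0.2))
centre = port_fraction(None)
```

Fractions like these are exactly what ax.annotate expects for xy and xytext when the coordinate system is a bounding box, so a port name never has to be converted to data coordinates.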


Use-cases

Computational graph

Let's say you want to draw a fairly simple computational graph. Until now, the best option for drawing such a thing would have been the graphviz Python package.

Graphviz

With the Python package you can exploit subgraphs to have multiple inputs for a single box, but you have little more control than that. For example, I couldn't find a way to align the square boxes on a single line; the best you can do is play with the layout engines and hope that one of them is good enough. Furthermore, the superscripts are ugly, and little mathematical notation is available beyond that.

from graphviz import Digraph

dot = Digraph(node_attr={'fontsize':'9'}, edge_attr={'arrowsize': '0.5', 'fontsize':'9'}, engine='dot')
dot.attr(rankdir='LR', packmode='graph')

with dot.subgraph() as sg:
    sg.attr(rank='same')
    sg.node('x', shape='plaintext', margin='0')
    sg.node('w', label='<W<SUP>[1]</SUP>>', shape='plaintext', margin='0')
    sg.node('b', label='<b<SUP>[1]</SUP>>', shape='plaintext', margin='0')

dot.node('z', shape='rect', label='<Z<SUP>[1]</SUP>  = W<SUP>[1]</SUP> x + b<SUP>[1]</SUP>>', margin='0')
dot.node('h', shape='rect', label='z[2] = W2 a[1] + b[2]', margin='0')
dot.node('y', shape='rect', label='a[2] = g(z[2])', margin='0')
dot.node('l', shape='rect', label='L(a[2], y)', margin='0')

with dot.subgraph() as sg:
    sg.attr(rank='same')
    sg.node('j', label='<W<SUP>[2]</SUP>>', shape='plaintext', margin='0')
    sg.node('k', label='<b<SUP>[2]</SUP>>', shape='plaintext', margin='0')
    sg.node('a', shape='rect', label='a[1] = g(z[1])', margin='0')

dot.edges(['xz', 'wz', 'bz', 'za', 'ah', 'hy', 'yl', 'jh', 'kh'])
dot.edge('l', 'y', headport='s', tailport='s', color='red', label='da[2]', fontcolor='red')
dot.edge('y', 'h', headport='s', tailport='s', color='red', label='dz[2]', fontcolor='red')
dot.edge('h', 'k', color='red', label='db[2]', fontcolor='red')
dot.edge('h', 'j', color='red', label='dW[2]', fontcolor='red')
dot


mpl_flow

With this library you have almost complete control over the positioning of boxes and labels, and you can access the full power of LaTeX when it comes to mathematical notation.

fig, ax = plt.subplots(figsize=(10, 2))
f = Flow(ax=ax)
f.node('x', label='$x$')
f.node('W1', label='$W^{[1]}$', travel='s', connect=False)
f.node('b1', label='$b^{[1]}$', travel='s', connect=False)
f.node('Z1', label='$Z^{[1]}=W^{[1]}x+b^{[1]}$', startpoint='x')
f.edge('W1', 'Z1')
f.edge('b1', 'Z1')
f.node('a1', label='$a^{[1]}=g(z^{[1]})$')
f.node('W2', label='$W^{[2]}$', travel='s', connect=False)
f.node('b2', label='$b^{[2]}$', travel='s', connect=False)
f.node('Z2', label='$Z^{[2]}=W^{[2]}a^{[1]}+b^{[2]}$', startpoint='a1')
f.edge('W2', 'Z2')
f.edge('b2', 'Z2')
f.node('a2', label='$a^{[2]}=g(z^{[2]})$', startpoint='Z2')
f.node('L', label='$L(a^{[2]}, y)$')
f.edge('L', 'a2', tailport='s', headport='s', arrowprops=dict(connectionstyle='arc3,rad=0.5', color='r'), c='r', label='$da^{[2]}$')
f.edge('a2', 'Z2', tailport='s', headport='s', arrowprops=dict(connectionstyle='arc3,rad=0.5', color='r'), c='r', label='$dZ^{[2]}$')
f.edge('Z2', 'W2', tailport='s', headport='e', arrowprops=dict(connectionstyle='arc3,rad=0.5', color='r'), c='r', label='$dW^{[2]}$')
f.edge('Z2', 'b2', tailport='s', headport='e', arrowprops=dict(connectionstyle='arc3,rad=0.5', color='r'), c='r', label='$db^{[2]}$');


Increasingly complex examples that would be impossible with other libraries and a nightmare with vanilla matplotlib

RNN backpropagation

f = Flow(bbox=dict(boxstyle='square'))

for i in range(6):
    lbl = i if i < 5 else 'T_x'

    if i != 4:
        f.node(f'a{i}', label=f'$a^{{\\langle {lbl} \\rangle}}$', fontsize=13, startpoint=f'a{i-1}')
    else:
        f.node(f'a{i}', label='$\\cdots$', startpoint=f'a{i-1}', fontsize=13, bbox=dict(ec='none'))
    if i != 0:
        f.edge(f'a{i}', f'a{i-1}', arrowprops=dict(connectionstyle='arc3,rad=0.4', ec='r'), headport='se', tailport='sw')

    if i >0 and i != 4:
        f.node(f'x{i}', label=f'$x^{{\\langle {lbl} \\rangle}}$', startpoint=f'a{i}', travel='s', fontsize=13, 
               edge_kwargs=dict(arrowprops=dict(arrowstyle='->')), bbox=dict(ec='none'))
        if i == 5:
            lbl = 'T_y'
        f.node(f'y{i}', label=f'$y^{{\\langle {lbl} \\rangle}}$', startpoint=f'a{i}', travel='n', fontsize=13)
        f.node(f'l{i}', label=f'$\\mathcal{{L}}^{{\\langle {lbl} \\rangle}}$', travel='n', fontsize=13)
        f.edge(f'l{i}', f'y{i}', arrowprops=dict(connectionstyle='arc3,rad=0.4', ec='r', shrinkA=4, shrinkB=6), headport='n', tailport='s')
        f.edge(f'y{i}', f'a{i}', arrowprops=dict(connectionstyle='arc3,rad=0.4', ec='r', shrinkA=4, shrinkB=6), headport='n', tailport='s')

f.node('l', label='$\\mathcal{L}$', startpoint='l5', travel='ne', distance=.5, connect=False)

for i in range(1, 6):
    if i != 4:
        f.edge(f'l{i}', 'l', tailport='n', headport='w', arrowprops=dict(connectionstyle='angle,angleA=0,angleB=90,rad=2'))
        f.edge('l', f'l{i}', tailport='sw', headport='ne', arrowprops=dict(connectionstyle='arc3,rad=-0.05', ec='r'))        


ResNet architecture

fig, ax = plt.subplots(2, 1, figsize=(14, 6))
ax1, ax2 = ax
dims = [1, 6, 8, 12, 6]
f = [7, 3, 3, 3, 3]
c = [1, 0, 2, 3, 4]
ch = [64, 64, 128, 256, 512]
sc = range(0, 34, 2)

facecolors = sum([[f'C{color}']*dim for color, dim in zip(c, dims)], [])
filters = sum([['${0} \\times {0}$ $\\mathrm{{conv}}$'.format(filt)]*dim for filt, dim in zip(f, dims)], [])
channels = sum([['${}$'.format(c)]*dim for c, dim in zip(ch, dims)], [])
pools = ['$/2$']+['']*6+['$/2$']+['']*7+['$/2$']+['']*11+['$/2$']+['']*5
ax1.set_title('Plain')
ax2.set_title('ResNet')
labels = list(map(lambda l: ', '.join(l), zip(filters, channels, pools)))

f = Flow(ax=ax1)
for i, (c, l) in enumerate(zip(facecolors, labels)):
    f.node(i, label='{:^50}'.format(l.strip(', ')), rotation=90, bbox=dict(boxstyle='square', pad=0.1, fc=c, alpha=.2))
f.node(sum(dims), label='{:^44}'.format('$\\mathrm{FC}$ $1000$'), rotation=90, bbox=dict(boxstyle='square', pad=0.1, fc=c, alpha=.2));

edges = []
f = Flow(ax=ax2)
for i, (c, l) in enumerate(zip(facecolors, labels)):
    _, e = f.node(i, label='{:^50}'.format(l.strip(', ')), rotation=90, bbox=dict(boxstyle='square', pad=0.1, fc=c, alpha=.2))
    if e is not None:
        edges.append(e)
_, e = f.node(sum(dims), label='{:^44}'.format('$\\mathrm{FC}$ $1000$'), rotation=90, bbox=dict(boxstyle='square', pad=0.1, fc=c, alpha=.2))
edges.append(e)

for ii, (i, j) in enumerate(zip(edges[::2], edges[2::2])):
    ls = '--' if ii in [3, 7, 13] else '-'
    f.edge(i, j, arrowprops=dict(connectionstyle='arc,angleA=90,angleB=90,armA=70,armB=70,rad=25', ls=ls))


U-net architecture

f = Flow(figsize=(12, 6))
vpad=4
hpad=0

nodes = []
fc='C0'
for i in range(27):
    c='k'
    lw=1
    dst=1
    drc='e'
    if i != 1 and i % 3 == 0:
        c='r'
        lw=2
        drc = 's'
        level = -1
        if i >= 15:
            fc='cyan'
            c='lime'
            drc = 'n'
            level = 1
        vpad += level
        dst=1/level

    hpad -= level
    nodes.append(f.node(label=' '*hpad+'\n'*3*vpad,  
                        bbox=dict(boxstyle='square', fc=fc, ec='none'),
                        edge_kwargs=dict(arrowprops=dict(ec=c, lw=lw)),
                        travel=drc, distance=0.2, fontsize=5))

    if i >= 15:
        if i % 3 == 0:
            f.edge(nodes[i][0], nodes[((9-(i//3-1))-1)*3-1][0], arrowprops=dict(arrowstyle='->', ec='lightgray', lw=3))
f.node(label='', distance=.12, bbox=dict(boxstyle=None), edge_kwargs=dict(arrowprops=dict(ec='magenta')))
vpad = 0
hpad = 5
for i in range(15, 27, 3):    
    f.node(label=' '*hpad+'\n'*3*vpad, startpoint=i+1, travel='w', distance=0.008*hpad, connect=False,
          bbox=dict(boxstyle='square', ec='none', fc='C0'), fontsize=5)
    hpad -=1 
    vpad +=1


plt.scatter([],[],marker=r'$\rightarrow$', label='Conv. ReLU', c='k', s=100) 
plt.scatter([],[],marker=r'$\rightarrow$', label='Skip Connection', c='lightgray', s=100)
plt.scatter([],[],marker=r'$\rightarrow$', label='Conv (1x1)', c='magenta', s=100)
plt.scatter([],[],marker=r'$\downarrow$', label='Max pool', c='r', s=100) 
plt.scatter([],[],marker=r'$\uparrow$', label='Up-conv', c='lime', s=100)
plt.legend(loc='lower right');


Transformer architecture

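def multihead(f, startpoint, travel, nid, distance=1, edge_kwargs=None):
    # Draw three stacked, slightly offset boxes to suggest multiple attention heads.
    edge_kwargs = {} if edge_kwargs is None else edge_kwargs  # avoid a shared mutable default
    a = f.node(f'{nid}.1', label='$\\mathrm{Multi-Head}$\n$\\mathrm{Attention}$',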
           startpoint=startpoint, travel=travel, distance=distance,
           bbox=dict(ec='C0'), edge_kwargs=edge_kwargs)
    b =f.node(f'{nid}.2', label='$\\mathrm{Multi-Head}$\n$\\mathrm{Attention}$',
           travel='ne', distance=.03, bbox=dict(ec='C3'), connect=False, zorder=-10)
    c =f.node(f'{nid}.3', label='$\\mathrm{Multi-Head}$\n$\\mathrm{Attention}$',
           travel='ne', distance=.03, bbox=dict(ec='C2'), connect=False, zorder=-20)
    return a,b,c

f = Flow(figsize=(6, 6))
f.node('x', label=r'$x^{\langle 1 \rangle}, \dots , x^{\langle T_x \rangle}$', 
       bbox=dict(ec='none'), fontsize=13)
f.node('peE', label='$+$', bbox=dict(boxstyle='circle'), travel='n', distance=.7)
(_, me),_, _ = multihead(f, 'peE', 'n', 'ME', distance=1,
          edge_kwargs=dict(
              label='$Q, K, V$', labelpos=(5, 0.5)))
f.node('an1', label='$\\mathrm{Add \\ & \\ Norm}$', travel='n', 
       bbox=dict(ec='orange'), startpoint='ME.1', distance=.5,
       edge_kwargs=dict(arrowprops=dict(arrowstyle='-')))
_, ffe = f.node('ffnnE', label='$\\mathrm{Feed \\ Forward}$\n$\\mathrm{Neural\\ Network}$',
       travel='n')
f.node('an2', label='$\\mathrm{Add \\ & \\ Norm}$', travel='n', 
       bbox=dict(ec='orange'), distance=.5,
       edge_kwargs=dict(arrowprops=dict(arrowstyle='-')))
f.node('compound_edge1', label='', distance=.5, travel='n',
       edge_kwargs=dict(arrowprops=dict(arrowstyle='-', shrinkA=0), 
                        headport=(0.5, 0.5)))
f.node('compound_edge2', label='', distance=2, travel='e',
   edge_kwargs=dict(arrowprops=dict(arrowstyle='-', shrinkA=0, shrinkB=0), 
                    headport=(0.5, 0.5), tailport=(0.5, 0.5)))

(_, md1), _, _ = multihead(f, 'compound_edge2', 'se', 'MD2', distance=(1, 1.5), 
          edge_kwargs=dict(
              label='$K, V$', labelpos=(.75, 0.05), 
              tailport=(0.5, 0.5), headport=(0.334, 0), 
              arrowprops=dict( 
                  shrinkB=0,
                  connectionstyle='bar,fraction=-0.2,angle=0')))
_, an3 = f.node('an3', label='$\\mathrm{Add \\ & \\ Norm}$', travel='s', 
       bbox=dict(ec='orange'), distance=1, startpoint='MD2.1',
       edge_kwargs=dict(
           arrowprops=dict(arrowstyle='->'), 
           label='$Q$', labelpos=(2, 0.5)))
(_, md1), _, _ = multihead(f, 'an3', 's', 'MD1', distance=.5,
          edge_kwargs=dict(
              arrowprops=dict(arrowstyle='-')))

_, an4 = f.node('an4', label='$\\mathrm{Add \\ & \\ Norm}$', travel='n', 
       bbox=dict(ec='orange'), distance=.5, startpoint='MD2.1',
       edge_kwargs=dict(arrowprops=dict(arrowstyle='-')))

_, ffd= f.node('ffnnD', label='$\\mathrm{Feed \\ Forward}$\n$\\mathrm{Neural\\ Network}$',
   travel='n', startpoint='an4')
f.node('an5', label='$\\mathrm{Add \\ & \\ Norm}$', travel='n', 
       bbox=dict(ec='orange'), distance=.5, startpoint='ffnnD',
       edge_kwargs=dict(arrowprops=dict(arrowstyle='-')))
f.node('lin', label='$\\mathrm{Linear}$'.center(30), travel='n', distance=.8)
f.node('soft', label='$\\mathrm{Softmax}$'.center(28), travel='n', distance=.4, 
       edge_kwargs=dict(arrowprops=dict(arrowstyle='-')))
f.node('y', label='$\\hat{y}_n$', travel='n', distance=.6, bbox=dict(ec='none'))
f.node('compound_edge1', label='', distance=1, travel='e',
   edge_kwargs=dict(arrowprops=dict(arrowstyle='-', shrinkA=0), 
                    headport=(0.5, 0.5)))
_, pd = f.node('peD', label='$+$', bbox=dict(boxstyle='circle'), travel='s', startpoint='MD1.1', 
       distance=1, edge_kwargs=dict(arrowprops=dict(arrowstyle='->')))
f.node('P', label='$P$'.center(10), distance=1.5, startpoint='peE', 
       edge_kwargs=dict(arrowprops=dict(arrowstyle='->')))
f.edge('P', 'peD', headport='w', tailport='e', 
       arrowprops=dict(connectionstyle='bar,fraction=-0.7,angle=90'))
f.edge('compound_edge1', 'peD', 
       arrowprops=dict(
           arrowstyle='<-', connectionstyle='bar,fraction=0.03,angle=0',
           shrinkA=0, shrinkB=0), 
       headport='s', tailport=(0.5, 0.5))
f.node('encoder', label='  \n\n\n\n', startpoint='x', travel='n', distance=2.5,
       bbox=dict(boxstyle='square', ec='none', fc='gainsboro', pad=5), 
       zorder=-30, connect=False, 
       xlabel='$\\mathrm{Encoder}$', xlabel_xy=(0.15, 1.03))
f.node('decoder', label='   \n\n\n\n\n\n\n\n\n\n', startpoint='MD2.1', distance=.05, travel='n',
   bbox=dict(boxstyle='square', ec='none', fc='gainsboro', pad=5), 
   zorder=-30, connect=False, 
   xlabel='$\\mathrm{Decoder}$', xlabel_xy=(0.15, 1.03))
f.edge(me, 'an1', tailport='w', headport='w', arrowprops=dict(connectionstyle='bar,fraction=-0.2,angle=90'))
f.edge(ffe, 'an2', tailport='w', headport='w', arrowprops=dict(connectionstyle='bar,fraction=-0.2,angle=90'))
f.edge(pd, 'an3', tailport='e', headport='e', arrowprops=dict(connectionstyle='bar,fraction=-0.2,angle=90'))
f.edge(an3, 'an4', tailport=(1, 0.2), headport='e',  
       arrowprops=dict(connectionstyle='bar,fraction=-0.2,angle=90'))
f.edge(ffd, 'an5', tailport=(1, 0.5), headport='e',  
       arrowprops=dict(connectionstyle='bar,fraction=-0.2,angle=90'));

