
Plot an Artificial Neural Network (ANN) model


Requirements

  • Python >= 3.6 (the code uses f-strings)
  • matplotlib
  • numpy

Motivation

There are other libraries that can plot ANNs, so why write your own?

This implementation is completely based on matplotlib, and doesn't require any additional library like Viz or PyGraph.

This also means that it is natively compatible with Jupyter notebooks.

Furthermore, you can create animations that work natively in HTML pages (this can't be demonstrated here, since GitHub doesn't allow script tags to execute).

%%capture
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation
from IPython.display import HTML
from itertools import chain
from plot_ann import ann

fig, ax = plt.subplots()
ax, nodes, edges = ann([3, 5, 2], ax=ax)
ax.set_aspect('equal')
# Flatten the nested edge structure into a single list of Line2D objects
lines = list(chain.from_iterable(chain.from_iterable(edges)))

def animate(i, lw):
    # Shift the line widths by i positions so they cycle across the edges
    lw = np.roll(lw, i)
    for w, line in zip(lw, lines):
        line.set_linewidth(w)
    return lines

# One random line width per edge
lw = np.random.rand(len(lines))
anim = animation.FuncAnimation(fig, animate, fargs=(lw,),
                               frames=100, interval=60, repeat=True,
                               blit=True)
# HTML(anim.to_jshtml())
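The animate callback relies on np.roll to shift the line widths by one more position on every frame. A pure-Python sketch of that circular shift may make the effect easier to see; roll is a hypothetical helper written for illustration and is not part of plot_ann:

```python
def roll(values, shift):
    """Circularly shift a list to the right, like np.roll on a 1-D array."""
    n = len(values)
    if n == 0:
        return []
    k = shift % n  # shifts larger than the list length wrap around
    return values[n - k:] + values[:n - k]


# Made-up line widths, one per edge
widths = [0.1, 0.5, 0.9, 1.3]
for frame in range(3):
    # Each frame moves every width one edge further along the list
    print(frame, roll(widths, frame))
```

Because the shift is always applied to the original widths, frame i shows the widths moved by exactly i positions.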

Usage

The only import you need is the ann() function from the plot_ann module (the only module in the package):

from plot_ann import ann

To draw a network with default settings, just pass the layer sizes as a List[int]:

ax, *_ = ann([3, 5, 2])
ax.set_aspect('equal')

png

The ann() function returns ax, nodes, edges, where ax is a matplotlib Axes instance, nodes is a nested list of matplotlib.patches.Circle, and edges is a nested list of matplotlib.lines.Line2D.

ann([1, 1])
(<AxesSubplot:>,
 [[<matplotlib.patches.Circle at 0x7f1f261929a0>],
  [<matplotlib.patches.Circle at 0x7f1f261929d0>]],
 [[[<matplotlib.lines.Line2D at 0x7f1f26147130>]]])
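The output above suggests that edges is nested two levels below the top-level list. A minimal sketch with strings standing in for the Line2D objects shows how such a structure can be flattened, the same way the animation example does; the exact grouping inside each level is an assumption made for illustration:

```python
from itertools import chain

# Stand-in for the `edges` value returned by ann(): strings replace the
# Line2D objects. The three-level nesting mirrors the output above; how the
# inner lists are grouped is an assumption.
edges = [
    [["e0", "e1"], ["e2", "e3"]],  # edges between layer 0 and layer 1
    [["e4"], ["e5"]],              # edges between layer 1 and layer 2
]

# Two chain.from_iterable calls peel off the two outer levels of nesting.
flat_edges = list(chain.from_iterable(chain.from_iterable(edges)))
print(flat_edges)
```

A flat list is convenient whenever the same property, such as line width or color, has to be updated on every edge in one loop.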

png

Figure customization

Subplotting

If no ax instance is passed, a new figure and axis will be instantiated (fig, ax = plt.subplots()) and the ax is returned.

Alternatively, one can pass an ax:

import matplotlib.pyplot as plt
fig, [_, ax] = plt.subplots(1, 2, figsize=(8, 4))
ann([3, 5, 2], ax=ax);

png

Plot size

The plot proportions can be changed using the width and height parameters:

fig, axes = plt.subplots(1, 2, figsize=(8, 8))
ax1, ax2 = axes
ann([3, 5, 2], ax=ax1, height=2)
ann([3, 5, 2], ax=ax2, width=3)
for ax in axes:
    ax.set_aspect('equal')

png

If you want to preserve the aspect of the axis, remember to either call ax.set_aspect('equal') or manually set the right figsize. For example, the default plt parameters will produce a non-square plot and affect the aspect of the network.

ann([3, 5, 2]);

png

Since text in plt is not affected by this distortion, you can use it to produce elliptical nodes with enough space to write long text inside them:

fig, ax = plt.subplots(figsize=(12,5))
ax, nodes, _ = ann([3, 5, 2], ax=ax, radius=2)
ax.text(*nodes[0][0].center, 'I\'m long text', zorder=10, va='center', ha='center');

png
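When the distortion is not wanted and ax.set_aspect('equal') is not an option, the figure height can be derived from the extent of the drawing. The sketch below is plain arithmetic: figsize_for_extent is a hypothetical helper, the data extents are made-up numbers, and the result is only approximate because axes margins and tick labels are ignored.

```python
def figsize_for_extent(data_width, data_height, fig_width=6.4):
    """Return (width, height) in inches so that one data unit spans roughly
    the same length on both axes. Margins around the axes are ignored."""
    if data_width <= 0 or data_height <= 0:
        raise ValueError("data extents must be positive")
    return fig_width, fig_width * data_height / data_width


# Hypothetical drawing that is twice as wide as it is tall
figsize = figsize_for_extent(10, 5, fig_width=8)
print(figsize)
```

The resulting tuple can be passed as figsize to plt.subplots().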

Network Customization

The network is fairly customizable

Layers

Set the layer_labels

  • layer_labels=True will set layer labels automatically
  • layer_labels=List will set layer labels from the list from left to right
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
ax1, ax2 = axes
ann([3, 5, 2], ax=ax1, layer_labels=True)
ann([3, 5, 2], ax=ax2, layer_labels=['Input', 'meow']);
for ax in axes:
    ax.set_aspect('equal')

png

Nodes

  1. Change the radius parameter to customize the radius of nodes:
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
ax1, ax2 = axes
ann([3, 5, 2], ax=ax1)
ann([3, 5, 2], ax=ax2, radius=3);
for ax in axes:
    ax.set_aspect('equal')

png

  2. Change the node_lw (linewidth) of nodes:
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
ax1, ax2 = axes
ann([3, 5, 2], ax=ax1)
ann([3, 5, 2], ax=ax2, node_lw=3);
for ax in axes:
    ax.set_aspect('equal')

png

  3. Change the node_colors:
    • node_colors=True will set colors automatically (C0 for the input layer and C1 for everything else)
    • node_colors=List will set the colors specified in the list, starting from the top-most input of the first layer, proceeding towards the bottom and then towards the right (if the list is shorter than the number of nodes, the default color is applied to the remaining nodes)
import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
ax1, ax2 = axes
ann([3, 5, 2], ax=ax1, node_colors=True, node_lw=2)
ann([3, 5, 2], ax=ax2, node_colors=['C0', 'C1', 'C2', 'C3', 'C4'], node_lw=2);
for ax in axes:
    ax.set_aspect('equal')

png
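The fallback described for short color lists can be pictured as padding the list with a default color until there is one entry per node. The sketch below only illustrates that rule: pad_colors is a hypothetical helper, and 'k' as the default color is an assumption, not necessarily what ann() uses internally.

```python
from itertools import chain, islice, repeat


def pad_colors(colors, n_nodes, default="k"):
    """Return exactly n_nodes colors, filling missing entries with default."""
    return list(islice(chain(colors, repeat(default)), n_nodes))


architecture = [3, 5, 2]
# Five explicit colors for ten nodes: the last five fall back to the default
colors = pad_colors(["C0", "C1", "C2", "C3", "C4"], sum(architecture))
print(colors)
```

With the ordering described above, the first three colors would go to the input layer and the next two to the top two nodes of the second layer.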

  4. Set node_labels:
    • node_labels=True will set labels automatically
    • node_labels=List will set the labels specified in the list, starting from the top-most input of the first layer, proceeding towards the bottom and then towards the right (if the list is shorter than the number of nodes, the remaining nodes get no label)
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
ax1, ax2 = axes
architecture = [3, 5, 2]
ann(architecture, ax=ax1, radius=2, node_labels=True)
ann(architecture, ax=ax2, radius=2, node_labels=range(sum(architecture)-3));
for ax in axes:
    ax.set_aspect('equal')

png

Edges

  1. Draw edges from the center of the node with edge_from_center=True or from the edge of the node with edge_from_center=False
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
ax1, ax2 = axes
ann([3, 5, 2], ax=ax1, radius=3)
ann([3, 5, 2], ax=ax2, edge_from_center=False, radius=3);
for ax in axes:
    ax.set_aspect('equal')

png

  2. Change the edge_lw (linewidth):
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
ax1, ax2 = axes
ann([3, 5, 2], ax=ax1)
ann([3, 5, 2], ax=ax2, edge_lw=2);
for ax in axes:
    ax.set_aspect('equal')

png

  3. Change the edge_colors:
    • edge_colors=True will set colors automatically (C0 for the input layer and C1 for everything else)
    • edge_colors=List will set the colors specified in the list, starting from the top-most input of the first layer, proceeding towards the bottom and then towards the right (if the list is shorter than the number of edges, the default color is applied to the remaining edges)
import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
ax1, ax2 = axes
ann([3, 5, 2], ax=ax1, edge_colors=True, edge_lw=1)
ann([3, 5, 2], ax=ax2, edge_colors=['C0', 'C1', 'C2', 'C3', 'C4'], edge_lw=1);
for ax in axes:
    ax.set_aspect('equal')

png

  4. Set edge_labels:
    • edge_labels=True will set labels automatically
    • edge_labels=List will set the labels specified in the list, starting from the top-most input of the first layer, proceeding towards the bottom and then towards the right (if the list is shorter than the number of edges, the remaining edges get no label)
import numpy as np
fig, axes = plt.subplots(1, 2, figsize=(12.5, 5))
ax1, ax2 = axes
architecture = [3, 5, 2]
ann(architecture, ax=ax1, edge_labels=True)
ann(architecture, ax=ax2, edge_labels=range(np.prod(architecture)-10));
for ax in axes:
    ax.set_aspect('equal')

png
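To judge whether a label list covers every edge, it helps to know how many edges a given architecture has. The sketch below assumes fully connected layers without bias handling, which is consistent with the single edge in the ann([1, 1]) output but is still an assumption about how ann() connects layers; count_edges is a hypothetical helper.

```python
def count_edges(architecture):
    """Number of edges when every node connects to every node of the next layer."""
    return sum(a * b for a, b in zip(architecture, architecture[1:]))


architecture = [3, 5, 2]
n_edges = count_edges(architecture)  # 3*5 + 5*2

# A label list that is ten entries short, padded with None to mark the
# edges that would stay unlabeled
labels = list(range(n_edges - 10))
padded = labels + [None] * (n_edges - len(labels))
print(n_edges, len(labels), padded.count(None))
```

Under this assumption the [3, 5, 2] network has 25 edges, so any list with fewer than 25 entries leaves some edges without a label.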

  5. Change edge_label_spacing:
    • edge_label_spacing=0 will place all edge labels of a layer at the center of their edge (if all labels are drawn, some of them will overlap)
    • edge_label_spacing>0 will displace a label to the left if its rotation angle is > 0, otherwise to the right
import numpy as np
fig, axes = plt.subplots(1, 2, figsize=(12.5, 5))
ax1, ax2 = axes
architecture = [3, 5, 2]
ann(architecture, ax=ax1, edge_labels=True, edge_label_spacing=0)
ann(architecture, ax=ax2, edge_labels=True, edge_label_spacing=1);
for ax in axes:
    ax.set_aspect('equal')

png

  6. bias=True will assume that the topmost node of each layer in architecture is the bias (so no edges reach it from the previous layer) and will change the indexing used by node_labels=True and edge_labels=True
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
ax1, ax2 = axes
ann([3, 5, 2], ax=ax1, radius=2, bias=False, node_labels=True)
ann([3, 5, 2], ax=ax2, radius=2, bias=True, node_labels=True)
for ax in axes:
    ax.set_aspect('equal')

png
