A plugin for visually displaying the structure of a neural network as a matplotlib diagram.
Project description
# Install the neural-net-drawer package via pip
pip install neural-net-drawer
# Import necessary libraries
import matplotlib.pyplot as plt
from neural_net_drawer.drawer import NeuralNetworkDiagram
# Example usage: draw a network with the default diagram settings.
# One figure with a single axes; the diagram is rendered onto that axes.
fig, ax = plt.subplots(figsize=(9, 9))
ax.set_axis_off()  # hide ticks, spines and frame so only the diagram shows
nn_diagram = NeuralNetworkDiagram()
# Layer sizes, one entry per layer.
nn_diagram.draw_neural_net(ax, [7, 5, 4, 3, 4, 2, 1])
plt.show()
# Example usage with customizations: limit how much of the network is drawn.
fig, ax = plt.subplots(figsize=(9, 9))
ax.set_axis_off()  # no ticks, spines or frame around the diagram
nn_diagram = NeuralNetworkDiagram()
# Display limits, configured as attributes before drawing.
nn_diagram.max_n_layers_size = 3  # maximum number of layers to display
nn_diagram.max_layer_size = 10    # maximum number of neurons per layer to display
# NOTE(review): no layer below has more than 7 neurons, so a per-layer cap of
# 10 presumably never triggers in this example; a smaller value (e.g. 5)
# would likely demonstrate truncation better — confirm against the library.
nn_diagram.draw_neural_net(ax, [7, 5, 4, 3, 4, 2, 1])
plt.show()
# Example usage with further customizations: limit the layers drawn and
# turn off the per-neuron numbering.
fig, ax = plt.subplots(figsize=(9, 9))
ax.set_axis_off()  # the diagram needs no axis decorations
nn_diagram = NeuralNetworkDiagram()
nn_diagram.max_n_layers_size = 3          # maximum number of layers to display
nn_diagram.show_neuron_numbers = False    # do not label neurons with numbers
nn_diagram.draw_neural_net(ax, [7, 5, 4, 3, 4, 2, 1])
plt.show()
Project details
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.