growingnn is a Python package that implements a dynamic neural network architecture learning algorithm. The network adapts its structure during training, so both the weights and the architecture are optimized. The package combines a Stochastic Gradient Descent-based optimizer with a guided Monte Carlo tree search, and it now supports PyTorch for GPU acceleration.

Project description

growingnn - Dynamic Neural Network Architecture Learning

The growingnn project introduces an innovative algorithm for data-driven neural network model construction. This algorithm comprises two fundamental components: the first component focuses on weight adjustment, while the second component acts as an orchestrator, launching a guided procedure to dynamically change the network architecture. This architectural modification occurs at regular intervals, specifically every $K$ epochs, and is driven by the outcome of a Monte Carlo tree search. The algorithm's core, outlined in the accompanying research paper, leverages the principles of Stochastic Gradient Descent (SGD) without relying on advanced tools commonly used in neural network training.

GitHub Repository

Click the link above to navigate to the GitHub repository, where you can find the source code, issues, and other project details.

Documentation

This link will take you to the project documentation. Here you'll find instructions, usage information, and other resources helpful for using the project.

Algorithm Overview

Weight Adjustment Component

The first component of the algorithm is dedicated to weight adjustment. It operates within the framework of Stochastic Gradient Descent (SGD), a foundational optimization algorithm for training neural networks. The simplicity of this approach makes it suitable for educational settings, emphasizing fundamental machine learning principles.
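To make the weight adjustment step concrete, here is a minimal sketch of plain SGD fitting a one-dimensional linear model. It is an illustration only and does not use the growingnn API; the function name `sgd_fit_line`, the learning rate, and the toy data are assumptions chosen for the example.

```python
import random


def sgd_fit_line(points, lr=0.05, epochs=200, seed=0):
    """Fit y = w * x + b with plain SGD on squared error, one sample per update."""
    rng = random.Random(seed)
    w, b = 0.0, 0.0
    data = list(points)
    for _ in range(epochs):
        rng.shuffle(data)  # the "stochastic" part: visit samples in random order
        for x, y in data:
            error = (w * x + b) - y
            # Gradient of 0.5 * error**2 with respect to w and b.
            w -= lr * error * x
            b -= lr * error
    return w, b


# Toy data drawn from y = 2x + 1 (hypothetical, noise-free).
points = [(i / 10.0, 2.0 * (i / 10.0) + 1.0) for i in range(-10, 11)]
w, b = sgd_fit_line(points)
print(round(w, 3), round(b, 3))
```

Each update moves the parameters a small step against the gradient of the loss on a single sample, which is the whole of the method; no momentum or adaptive learning rate is involved.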

Orchestrator and Network Architecture Modification

The second component, the orchestrator, plays a crucial role in initiating a procedure to dynamically change the network architecture. This change occurs systematically at predefined intervals, specifically every $K$ epochs. The decision-making process for architectural changes is facilitated by a guided Monte Carlo tree search. This sophisticated mechanism ensures that architectural modifications are well-informed and contribute to the overall improvement of the neural network model.
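The schedule described above can be sketched as a loop that trains for every epoch and asks for an architecture change after every $K$-th one. This is a hedged illustration, not the package's implementation: `train_with_growth`, `pick_action_by_rollouts`, and the toy scores are hypothetical names and values, and the action chooser is a flat Monte Carlo comparison of candidate actions, a much simpler stand-in for the guided Monte Carlo tree search used by growingnn.

```python
import random


def train_with_growth(epochs_total, k, train_one_epoch, propose_change):
    """Train for epochs_total epochs; request an architecture change every k epochs."""
    history = []
    for epoch in range(1, epochs_total + 1):
        train_one_epoch(epoch)  # weight adjustment step (e.g. SGD)
        if epoch % k == 0:
            history.append((epoch, propose_change(epoch)))
    return history


def pick_action_by_rollouts(actions, simulate, rollouts=20, seed=0):
    """Flat Monte Carlo choice: average simulated scores and keep the best action."""
    rng = random.Random(seed)

    def mean_score(action):
        return sum(simulate(action, rng) for _ in range(rollouts)) / rollouts

    return max(actions, key=mean_score)


# Hypothetical expected gains for the two kinds of change named in the text.
TOY_GAIN = {"add_layer": 0.5, "remove_layer": 0.2}


def toy_simulate(action, rng):
    # Noisy score standing in for "train briefly and measure accuracy".
    return TOY_GAIN[action] + rng.uniform(-0.1, 0.1)


epochs_seen = []
history = train_with_growth(
    epochs_total=10,
    k=5,
    train_one_epoch=epochs_seen.append,
    propose_change=lambda epoch: pick_action_by_rollouts(list(TOY_GAIN), toy_simulate),
)
print(history)
```

Passing the training step and the chooser as callables keeps the schedule separate from both the optimizer and the search, mirroring the two-component split of the algorithm.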

Implementation Details

Model Structure

The model is the main structure that stores layers as nodes in a directed graph. It operates based on layer identifiers, treating each layer as an independent structure that contains information about incoming and outgoing connections. The default starting structure is a simple graph with an input and output layer connected by a single connection. In each generation, the algorithm has the flexibility to add new layers or remove existing ones. As the structure grows, each layer gains more incoming and outgoing connections.
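A minimal sketch of such a structure is shown below: layers are nodes keyed by identifier, each node records its incoming and outgoing connections, and the model starts as an input layer connected to an output layer. The class and method names (`LayerNode`, `GraphModel`, `add_layer`, `remove_layer`) are assumptions made for this example and are not taken from the growingnn source.

```python
class LayerNode:
    """One layer: knows only its own id and its neighbours' ids."""

    def __init__(self, layer_id):
        self.id = layer_id
        self.inputs = []   # ids of layers feeding this layer
        self.outputs = []  # ids of layers fed by this layer


class GraphModel:
    """Directed graph of layers, starting as input -> output."""

    def __init__(self):
        self.layers = {}
        self._next_id = 0
        self.input_id = self._new_layer()
        self.output_id = self._new_layer()
        self.connect(self.input_id, self.output_id)

    def _new_layer(self):
        layer_id = self._next_id
        self._next_id += 1
        self.layers[layer_id] = LayerNode(layer_id)
        return layer_id

    def connect(self, src, dst):
        if dst not in self.layers[src].outputs:
            self.layers[src].outputs.append(dst)
            self.layers[dst].inputs.append(src)

    def add_layer(self, src, dst):
        """Add a new layer on a parallel path src -> new -> dst."""
        new_id = self._new_layer()
        self.connect(src, new_id)
        self.connect(new_id, dst)
        return new_id

    def remove_layer(self, layer_id):
        if layer_id in (self.input_id, self.output_id):
            raise ValueError("the input and output layers cannot be removed")
        node = self.layers.pop(layer_id)
        for src in node.inputs:
            self.layers[src].outputs.remove(layer_id)
        for dst in node.outputs:
            self.layers[dst].inputs.remove(layer_id)
        # Reconnect the neighbours so no path through the graph is lost.
        for src in node.inputs:
            for dst in node.outputs:
                self.connect(src, dst)


model = GraphModel()
hidden = model.add_layer(model.input_id, model.output_id)
print(model.layers[model.output_id].inputs)  # the output now has two incoming connections
```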

Propagation Phase

During the propagation phase, each layer waits until it receives signals from all input layers. Once these signals are received, they are averaged, processed, and propagated through all outgoing connections. This iterative process allows the neural network to dynamically adapt its architecture based on the evolving data and training requirements.
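The wait-average-process-propagate rule can be sketched in a few lines of plain Python. The function `propagate`, the dictionary-based graph encoding, and the per-layer `transforms` are hypothetical choices for this example; the real package operates on its own layer objects and tensors.

```python
from collections import deque


def propagate(inputs_of, outputs_of, transforms, input_id, signal):
    """Forward pass over a layer graph; a layer fires once all its inputs arrived."""
    received = {layer: [] for layer in inputs_of}
    received[input_id] = [signal]
    results = {}
    ready = deque([input_id])
    while ready:
        layer = ready.popleft()
        signals = received[layer]
        # Element-wise average of every incoming signal.
        averaged = [sum(values) / len(signals) for values in zip(*signals)]
        results[layer] = transforms[layer](averaged)
        for nxt in outputs_of[layer]:
            received[nxt].append(results[layer])
            if len(received[nxt]) == len(inputs_of[nxt]):
                ready.append(nxt)  # all inputs present, so the layer can run
    return results


# Graph: 0 -> 1 directly, and 0 -> 2 -> 1 through a hidden layer.
inputs_of = {0: [], 1: [0, 2], 2: [0]}
outputs_of = {0: [1, 2], 1: [], 2: [1]}
transforms = {
    0: lambda v: v,                     # input layer passes the signal through
    1: lambda v: v,                     # output layer, identity for the example
    2: lambda v: [2.0 * x for x in v],  # hidden layer doubles its input
}

results = propagate(inputs_of, outputs_of, transforms, input_id=0, signal=[1.0, 3.0])
print(results[1])  # → [1.5, 4.5]
```

The output layer receives `[1.0, 3.0]` directly and `[2.0, 6.0]` through the hidden layer, and it only runs after both have arrived, so its input is their average.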

Results and Testing

The proposed algorithm has been tested primarily on visual pattern classification problems, where it consistently produced satisfactory results. These results support dynamic architecture learning as a way to improve model performance.

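# Assumption: the package is imported under the alias `gnn` used below.
import growingnn as gnn

# `data_reader` and `mnist_path` are not defined in this snippet; supply your own
# loader and dataset path. 0.9 is passed as the split argument.
x_train, x_test, y_train, y_test, labels = data_reader.read_mnist_data(mnist_path, 0.9)

gnn.trainer.train(
    x_train = x_train,
    y_train = y_train,
    x_test = x_test,
    y_test = y_test,
    labels = labels,
    input_paths = 1,
    path = "./result",           # directory where results are written
    model_name = "GNN_model",
    epochs = 10,
    generations = 10,
    input_size = 28 * 28,        # flattened 28x28 MNIST image
    hidden_size = 28 * 28,
    output_size = 10,            # ten digit classes
    input_shape = (28, 28, 1),
    kernel_size = 3,
    depth = 2
)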

This code trains a simple network on the MNIST dataset.

Credits

Szymon Swiderski, Agnieszka Jastrzebska

Disclosure

This is the first beta version of the growingnn package. We are not liable for the accuracy of the program’s output nor actions performed based upon it.

For more in-depth information on the algorithm, its implementation, and testing results, refer to the accompanying research paper. The provided Python source code is a valuable resource for understanding and implementing the presented method. Feel free to explore, contribute, and adapt it for your specific needs.

Download files

Source Distribution

growingnn-0.4.0.tar.gz (86.3 kB view details)

Uploaded Source

Built Distribution

growingnn-0.4.0-py3-none-any.whl (48.6 kB view details)

Uploaded Python 3

File details

Details for the file growingnn-0.4.0.tar.gz.

File metadata

  • Download URL: growingnn-0.4.0.tar.gz
  • Upload date:
  • Size: 86.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.12

File hashes

Hashes for growingnn-0.4.0.tar.gz
Algorithm Hash digest
SHA256 cb6685e631e16540088c332d4ac61a3b1d41d5b73a0b9a925ec4a87d82e10518
MD5 825c8679cbbbc9c5a6ba8c042c5e9b47
BLAKE2b-256 a7332bb61c89a389968740752c65c83cf04d5901cd53057fdcac4d7883df8543

File details

Details for the file growingnn-0.4.0-py3-none-any.whl.

File metadata

  • Download URL: growingnn-0.4.0-py3-none-any.whl
  • Upload date:
  • Size: 48.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.12

File hashes

Hashes for growingnn-0.4.0-py3-none-any.whl
Algorithm Hash digest
SHA256 c1c5d7ffa05922ad7df601c3d3e3ef0de38a4de0bad799cd22c5f449843d0ae3
MD5 94251bf2f5abaf2ed75cf5ef63e98f60
BLAKE2b-256 d70edd7e3778d121bf8a4d7997163bde8424d85286934089ddc6c4c6726c032c
