A frugal and memetic Neural Architecture Search (NAS) framework.

Nas-Torch: Frugal & Memetic Neural Architecture Search

Requirements: Python 3.8+, PyTorch. License: MIT.

nas-torch is a frugal, modular, and "white-box" Neural Architecture Search (NAS) framework.

It was designed to address two common problems in deep learning: choosing hyperparameters by gut feeling, and the tendency to generate bloated networks. Unlike traditional NAS approaches that can require thousands of GPU-days, nas-torch aims to discover high-performing topologies in a few hours on a standard consumer GPU (e.g., RTX 3060).

Main Features

  • Frugality & Accessibility: Search for strong architectures directly on a laptop-class GPU.
  • Hybrid Memetic Approach: Combines an autoregressive controller (Transformer), which provides a topological warm start, with a swarm metaheuristic (Artificial Bee Colony, ABC) for micro-exploitation.
  • DynamicNet Engine: A parser that converts a list of layer configurations into a valid PyTorch model, automatically computing spatial and linear dimensions by passing a dummy tensor through the network.
  • Domain Agnostic: Applies to both computer vision (CIFAR-10) and highly imbalanced tabular data (direct optimization of the F1-score for fraud detection).
  • Fighting Bloat: Integrates a multi-objective reward function that dynamically penalizes unnecessary network depth.
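
To make the "Fighting Bloat" idea concrete, here is a minimal sketch of a depth-penalized reward. The function name `penalized_reward`, the linear penalty, and the weight `lam` are assumptions chosen for illustration; the actual nas-torch reward is described as dynamic and may well differ.

```python
def penalized_reward(score: float, n_layers: int, max_layers: int, lam: float = 0.1) -> float:
    """Task score minus a penalty that grows linearly with relative depth.

    Hypothetical formula for illustration; not taken from nas-torch.
    """
    if max_layers <= 0:
        raise ValueError("max_layers must be positive")
    depth_ratio = min(n_layers, max_layers) / max_layers
    return score - lam * depth_ratio


# Two candidates with the same validation score: the shallower one is preferred.
shallow = penalized_reward(score=0.80, n_layers=5, max_layers=20)
deep = penalized_reward(score=0.80, n_layers=18, max_layers=20)
print(shallow > deep)  # True
```

With a scalar reward of this kind, any single-objective optimizer can trade accuracy against depth without further changes; the weight `lam` controls how strongly depth is discouraged.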

Installation

git clone https://github.com/Romain-Amigon/8INF976.git
cd 8INF976
pip install -r requirements.txt

Quickstart

from nas_torch import TransformerOptimizer, DynamicNet
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

opt = TransformerOptimizer(
    dataset=train_loader,
    max_layers=20,
    d_model=64,
    nhead=4
)

best_arch_config, stats = opt.run(iterations=20)

model = DynamicNet(best_arch_config, input_shape=(3, 32, 32))

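print(model)

# Alternative: refine a hand-written starting architecture with ABC.
# Reuses train_loader from the example above.
import torch.nn as nn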
from nas_torch import ABCOptimizer, Conv2dCfg, LinearCfg, DropoutCfg, DynamicNet

initial_layers = [
    Conv2dCfg(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1, activation=nn.ReLU),
    Conv2dCfg(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1, activation=nn.ReLU),
    DropoutCfg(p=0.25),
    LinearCfg(in_features=None, out_features=128, activation=nn.ReLU),
    LinearCfg(in_features=128, out_features=10, activation=nn.LogSoftmax)
]

opt = ABCOptimizer(
    layers=initial_layers,
    dataset=train_loader,
    pop_size=10,
    limit=5,
    patience=3
)

best_arch_config, stats = opt.run(iterations=20)

optimized_model = DynamicNet(best_arch_config, input_shape=(3, 32, 32))
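
In the ABC example above, the first `LinearCfg` is declared with `in_features=None`, which suggests that DynamicNet fills in the value itself. The feature list says it does this with a dummy tensor; the sketch below swaps in closed-form arithmetic instead so that it runs without PyTorch. `ConvSpec`, `LinearSpec`, and `reconnect` are hypothetical names, not the nas-torch API, and dilation, pooling, and dropout are ignored.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union


@dataclass
class ConvSpec:  # hypothetical stand-in for Conv2dCfg
    out_channels: int
    kernel_size: int = 3
    stride: int = 1
    padding: int = 1
    in_channels: Optional[int] = None


@dataclass
class LinearSpec:  # hypothetical stand-in for LinearCfg
    out_features: int
    in_features: Optional[int] = None


Spec = Union[ConvSpec, LinearSpec]


def reconnect(specs: List[Spec], input_shape: Tuple[int, int, int]) -> List[Spec]:
    """Fill in each layer's input size from the output of the previous layer."""
    channels, height, width = input_shape
    flat: Optional[int] = None  # set once the first linear layer flattens the tensor
    for spec in specs:
        if isinstance(spec, ConvSpec):
            if flat is not None:
                raise ValueError("convolution after a linear layer is not supported here")
            spec.in_channels = channels
            # Standard convolution output size: floor((n + 2p - k) / s) + 1
            height = (height + 2 * spec.padding - spec.kernel_size) // spec.stride + 1
            width = (width + 2 * spec.padding - spec.kernel_size) // spec.stride + 1
            if height <= 0 or width <= 0:
                raise ValueError("feature map collapsed to zero size")
            channels = spec.out_channels
        else:
            if flat is None:
                flat = channels * height * width  # implicit flatten
            spec.in_features = flat
            flat = spec.out_features
    return specs


layers = reconnect(
    [ConvSpec(32), ConvSpec(64), LinearSpec(128), LinearSpec(10)],
    input_shape=(3, 32, 32),
)
print(layers[2].in_features)  # 65536 = 64 * 32 * 32
```

A dummy forward pass reaches the same numbers without hard-coding a formula per layer type, which is presumably why the framework prefers it.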

Benchmarks & Performance

Tests were conducted with strict Train/Test splits and a fast evaluation proxy. Hardware: NVIDIA GeForce RTX 3060 Laptop GPU (6 GB VRAM).

| Task | Algorithm | Final Score | Search Time | Note |
| --- | --- | --- | --- | --- |
| Credit Card Fraud | Transformer + ABC | 0.77 (F1-Score) | ~2 h | Autonomous optimization of F1-Score on imbalanced data. |
| Breast Cancer | ABC Only | 99.56% (Acc) | < 1 min | Dataset absolute limit reached. |
| California Housing | ABC Only | -0.32 (MSE) | < 10 min | Competitive with manual ensemble methods. |
| CIFAR-10 | Transformer + ABC | 85.39% (Acc) | ~4.8 h | Very short final training (100 epochs). |

Ablation Study (CIFAR-10)

The ablation below indicates that the Transformer warm start helps the metaheuristic avoid a "cold start":

  • Simulated Annealing (100 iterations): 73.90% ± 3.46%
  • ABC Only (30 iterations): 79.76% ± 2.37%
  • Transformer + ABC: 83.48% ± 1.98%
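
For readers unfamiliar with the ABC stage compared above, the following is a toy sketch of an Artificial Bee Colony loop on a continuous test function, with an optional `seeds` argument standing in for a warm start. It is a simplified textbook-style variant (for example, every exhausted source is replaced each cycle), not the nas-torch `ABCOptimizer`; all names and defaults here are assumptions, and the cost function is assumed to be non-negative.

```python
import random
from typing import Callable, List, Optional, Sequence, Tuple

Vector = List[float]


def abc_minimize(
    cost: Callable[[Sequence[float]], float],
    dim: int,
    bounds: Tuple[float, float],
    pop_size: int = 10,
    limit: int = 5,
    iterations: int = 50,
    seeds: Optional[List[Vector]] = None,
    rng: Optional[random.Random] = None,
) -> Tuple[Vector, float]:
    """Toy Artificial Bee Colony; `seeds` optionally warm-starts some food sources."""
    rng = rng or random.Random(0)
    lo, hi = bounds

    def random_source() -> Vector:
        return [rng.uniform(lo, hi) for _ in range(dim)]

    sources = [list(s) for s in (seeds or [])][:pop_size]
    sources += [random_source() for _ in range(pop_size - len(sources))]
    costs = [cost(s) for s in sources]
    trials = [0] * pop_size
    best_i = min(range(pop_size), key=costs.__getitem__)
    best, best_cost = list(sources[best_i]), costs[best_i]

    def try_neighbour(i: int) -> None:
        nonlocal best, best_cost
        # Move one coordinate relative to a different, randomly chosen source.
        k = rng.choice([j for j in range(pop_size) if j != i])
        d = rng.randrange(dim)
        cand = list(sources[i])
        cand[d] += rng.uniform(-1.0, 1.0) * (cand[d] - sources[k][d])
        cand[d] = min(hi, max(lo, cand[d]))
        c = cost(cand)
        if c < costs[i]:  # greedy selection keeps the better of the two
            sources[i], costs[i], trials[i] = cand, c, 0
            if c < best_cost:
                best, best_cost = list(cand), c
        else:
            trials[i] += 1

    for _ in range(iterations):
        for i in range(pop_size):  # employed bees: one attempt per source
            try_neighbour(i)
        weights = [1.0 / (1.0 + c) for c in costs]  # onlookers favour low-cost sources
        for i in rng.choices(range(pop_size), weights=weights, k=pop_size):
            try_neighbour(i)
        for i in range(pop_size):  # scouts replace exhausted sources
            if trials[i] > limit:
                sources[i] = random_source()
                costs[i], trials[i] = cost(sources[i]), 0
                if costs[i] < best_cost:
                    best, best_cost = list(sources[i]), costs[i]
    return best, best_cost


def sphere(x: Sequence[float]) -> float:
    return sum(v * v for v in x)


seed = [0.5, -0.5, 0.5, -0.5]  # stands in for a controller-proposed starting point
cold = abc_minimize(sphere, dim=4, bounds=(-5.0, 5.0))
warm = abc_minimize(sphere, dim=4, bounds=(-5.0, 5.0), seeds=[seed])
print(cold[1], warm[1])
```

Because the best-so-far solution is tracked separately, the warm-started run can never finish worse than its seed; whether it beats the cold run on a given problem is an empirical question, which is what the ablation above measures for CIFAR-10.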

Framework Architecture

  1. layer_classes.py: Definition of topological building blocks (Conv2dCfg, LinearCfg, DropoutCfg).
  2. model.py: DynamicNet engine and _reconnect_layers algorithm for the mathematical consistency of graphs.
  3. optimizer.py: Abstract optimization classes and evaluation Proxy integrating dynamic Early Stopping. Implementation of Simulated Annealing, GA, ABC, LSTM, and Transformer.
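
The evaluation proxy mentioned in item 3 relies on early stopping to keep each candidate evaluation cheap. A minimal patience-based sketch of that idea follows; `train_with_early_stopping` and the `run_epoch` callback are hypothetical, and the "dynamic" rule used by nas-torch may be more elaborate than a fixed patience.

```python
from typing import Callable, Tuple


def train_with_early_stopping(
    run_epoch: Callable[[int], float],
    max_epochs: int = 20,
    patience: int = 3,
    min_delta: float = 0.0,
) -> Tuple[float, int]:
    """Run epochs until the validation score stops improving.

    `run_epoch(epoch)` is a hypothetical callback that trains one epoch and
    returns a validation score where higher is better.
    Returns (best score, number of epochs actually run).
    """
    best = float("-inf")
    stale = 0
    epochs_run = 0
    for epoch in range(max_epochs):
        score = run_epoch(epoch)
        epochs_run += 1
        if score > best + min_delta:
            best, stale = score, 0
        else:
            stale += 1
            if stale >= patience:
                break
    return best, epochs_run


# Synthetic curve that plateaus after the third epoch (illustrative numbers only).
curve = [0.50, 0.62, 0.70, 0.70, 0.69, 0.70, 0.71, 0.72]
best, used = train_with_early_stopping(lambda e: curve[e], max_epochs=len(curve))
print(best, used)  # 0.7 6
```

Note that the run stops before the late improvement at epochs 7 and 8: a short patience saves compute but can underestimate slow-starting architectures, which is the usual trade-off of a fast proxy.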

Roadmap / Future Work

  • Implement search space
  • Addition of modern macro-cells (Inverted Residuals, Dense Blocks) to the search space.
  • Integration of Zero-Cost Proxies metrics (e.g., SynFlow) to accelerate initial filtering.
  • Support for strict non-dominated sorting (Pareto Front) for Hardware-Aware NAS (Latency vs. Accuracy).
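
As an illustration of the last roadmap item, the sketch below filters a pool of candidates down to its Pareto front for accuracy versus latency. This feature is not yet part of nas-torch; the `Candidate` type, the function names, and the numbers are all made up for the example.

```python
from typing import List, NamedTuple


class Candidate(NamedTuple):
    name: str
    accuracy: float    # higher is better
    latency_ms: float  # lower is better


def dominates(a: Candidate, b: Candidate) -> bool:
    """a dominates b if it is no worse on both objectives and strictly better on one."""
    no_worse = a.accuracy >= b.accuracy and a.latency_ms <= b.latency_ms
    strictly_better = a.accuracy > b.accuracy or a.latency_ms < b.latency_ms
    return no_worse and strictly_better


def pareto_front(cands: List[Candidate]) -> List[Candidate]:
    """Keep the candidates that no other candidate dominates (O(n^2) scan)."""
    return [c for c in cands if not any(dominates(o, c) for o in cands)]


# Made-up candidates, for illustration only.
pool = [
    Candidate("A", accuracy=0.85, latency_ms=12.0),
    Candidate("B", accuracy=0.83, latency_ms=7.0),
    Candidate("C", accuracy=0.80, latency_ms=9.0),   # dominated by B
    Candidate("D", accuracy=0.85, latency_ms=15.0),  # dominated by A
]
print([c.name for c in pareto_front(pool)])  # ['A', 'B']
```

Unlike a weighted scalar reward, this keeps every non-dominated trade-off, so the choice between a faster and a more accurate network can be deferred until the deployment target is known.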

License

This project is licensed under the MIT License.
