A lightweight toolkit for neural architecture search experiments.
Project description
A "gym" style toolkit for building lightweight Neural Architecture Search systems. I know, the name is awful.
Installation
pip install gymnastics
If you want to use NAS-Bench-101, follow the instructions here.
Overview
Over the course of the final year of my PhD I worked a lot on Neural Architecture Search (NAS) and built a bunch of tooling to make my life easier. This is an effort to standardise the various features into a single framework and provide a "gym" style toolkit for comparing various algorithms.
The key use cases for this library are:
- test out new predictors on various NAS benchmarks
- visualise the cells/graphs of your architectures
- add new operations to NAS spaces
- add new backbones to NAS spaces
The framework revolves around three key classes:
- Model
- Proxy
- SearchSpace
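To give a feel for how these fit together, here is a minimal sketch of a custom proxy. It assumes, based on the examples below, that a proxy only needs to subclass Proxy and implement score(model, minibatch); treat the base class and exact signature as assumptions rather than the documented API.
import torch
from gymnastics.proxies import Proxy


class ParamCount(Proxy):
    """Hypothetical zero-cost proxy: score a model by its trainable parameter count."""

    def score(self, model: torch.nn.Module, minibatch: torch.Tensor) -> float:
        # Ignores the minibatch entirely; purely a sketch of the interface.
        return float(sum(p.numel() for p in model.parameters() if p.requires_grad))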
Obligatory builder pattern README example
Using gymnastics, we can very easily reconstruct NAS spaces (the goal being that it's easy to define new and exciting ones).
For example, here's how easy it is to redefine the NATS-Bench / NAS-Bench-201 search space:
from gymnastics.searchspace import SearchSpace, CellSpace, Skeleton
from gymnastics.searchspace.ops import Conv3x3, Conv1x1, AvgPool2d, Skip, Zeroize
search_space = SearchSpace(
    CellSpace(
        ops=[Conv3x3, Conv1x1, AvgPool2d, Skip, Zeroize], num_nodes=4, num_edges=6
    ),
    Skeleton(
        style=ResNetCIFAR,
        num_blocks=[5, 5, 5],
        channels_per_stage=[16, 32, 64],
        strides_per_stage=[1, 2, 2],
        block_expansion=1,
    ),
)
# create an accuracy predictor
from gymnastics.proxies import NASWOT
from gymnastics.datasets import CIFAR10Loader
proxy = NASWOT()
dataset = CIFAR10Loader(path="~/datasets/cifar10", download=False)
minibatch, _ = dataset.sample_minibatch()
best_score = 0.0
best_model = None
# try out 10 random architectures and save the best one
for i in range(10):
    model = search_space.sample_random_architecture()
    y = model(minibatch)
    score = proxy.score(model, minibatch)

    if score > best_score:
        best_score = score
        best_model = model
best_model.show_picture()
The final call prints a picture of the best cell found.
Have a look in examples/ for more examples.
NAS-Benchmarks
If you have designed a new proxy for accuracy and want to test its performance, you can use the benchmarks available in benchmarks/.
The interface to the benchmarks is exactly the same as the above example for SearchSpace.
For example, here we score networks from the NDS ResNet space using random input data:
import torch
from gymnastics.benchmarks import NDSSearchSpace
from gymnastics.proxies import Proxy, NASWOT
search_space = NDSSearchSpace(
    "~/nds/data/ResNet.json", searchspace="ResNet"
)
proxy: Proxy = NASWOT()
minibatch: torch.Tensor = torch.rand((10, 3, 32, 32))
scores = []
for _ in range(10):
    model = search_space.sample_random_architecture()
    scores.append(proxy.score(model, minibatch))
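A common way to test a proxy's performance is to check how well its scores rank architectures relative to their benchmark accuracies, e.g. with a Spearman rank correlation. The sketch below assumes a hypothetical get_accuracy(model) accessor; substitute whatever accuracy lookup your chosen benchmark actually provides.
from scipy.stats import spearmanr

# Hedged sketch: collect proxy scores and ground-truth accuracies for the same
# sampled models, then compute their rank correlation.
proxy_scores, true_accuracies = [], []
for _ in range(10):
    model = search_space.sample_random_architecture()
    proxy_scores.append(proxy.score(model, minibatch))
    true_accuracies.append(search_space.get_accuracy(model))  # hypothetical accessor

rho, _ = spearmanr(proxy_scores, true_accuracies)
print(f"Spearman rank correlation: {rho:.3f}")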
Additional supported operations
In addition to the standard NAS operations we include a few more exotic ones, all in various states of completion:
Op | Paper | Notes |
---|---|---|
conv | - | params: kernel size |
gconv | - | + params: group |
depthwise separable | | + no extra params needed |
mixconv | | + params: needs a list of kernel_sizes |
octaveconv | | Don't have a sensible way to include this as a single operation yet |
shift | | no params needed |
ViT | | |
Fused-MBConv | | |
Lambda | | |
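Since CellSpace simply takes a list of op classes (see the search-space example above), adding a new operation should amount to writing another module and appending it to that list. The sketch below is an assumption about how that works: the (in_channels, out_channels, stride) constructor signature is a guessed convention, not gymnastics' documented op interface.
import torch
import torch.nn as nn

from gymnastics.searchspace import CellSpace
from gymnastics.searchspace.ops import Conv3x3, Conv1x1, AvgPool2d, Skip, Zeroize


class Conv5x5(nn.Module):
    """Hypothetical extra op; the constructor signature is an assumed convention."""

    def __init__(self, in_channels: int, out_channels: int, stride: int = 1) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=stride, padding=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


# Drop the new op into a cell space alongside the built-in ones.
cell_space = CellSpace(
    ops=[Conv3x3, Conv1x1, Conv5x5, AvgPool2d, Skip, Zeroize], num_nodes=4, num_edges=6
)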