Skip to main content

No project description provided

Project description

StitchNet: Composing Neural Networks from Pre-Trained Fragments

We propose StitchNet, a novel neural network creation paradigm that stitches together fragments (one or more consecutive network layers) from multiple pre-trained neural networks. StitchNet allows the creation of high-performing neural networks without the large compute and data requirements needed under traditional model creation processes via backpropagation training. We leverage Centered Kernel Alignment (CKA) as a compatibility measure to efficiently guide the selection of these fragments in composing a network for a given task tailored to specific accuracy needs and computing resource constraints. We then show that these fragments can be stitched together to create neural networks with comparable accuracy to traditionally trained networks at a fraction of computing resource and data requirements. Finally, we explore a novel on-the-fly personalized model creation and inference application enabled by this new paradigm.

Installation

pip install stitchnet

Usage

import stitchnet

# load the beans dataset from huggingface
from stitchnet import load_hf_dataset
dataset_train, dataset_val = load_hf_dataset('beans', train_split='validation', val_split='test', label_column='labels', seed=47)

# prepare stitching dataset
import numpy as np
from tqdm import tqdm
stitching_dataset = np.vstack([x['pixel_values'] for x in tqdm(dataset_train.select(range(32)))])

# generate stitchnets 1 sample
score,net = generate(stitching_dataset, threshold=0, totalThreshold=0, maxDepth=10, K=2, sample=True)

# generate multiple stitchnets
generator = generate(stitching_dataset, threshold=0.8, totalThreshold=0.8, maxDepth=10, K=2, sample=False)
for score,net in generator:
    print(score,net)

# print macs and params
net.get_macs_params() # {'macs': 4488343528.0, 'params': 25653096}

# save onnx
net.save_onnx('./_data/net') # saving to ./_results/net.onnx
    
# draw the stitchnet
net.draw_svg('./_data/net') # saving to ./_results/net.svg

# train a classifier
net.fit(dataset_train, label_column="labels")

# use it for prediction
net.predict_files(['./_results/test.jpg']) # [{'score': [0.8, 0.2, 0.0], 'label': 0}]

# evaluate with validation dataset
net.evaluate_dataset(dataset_val, label_column='labels') # {'accuracy': 0.7421875}

CUDA

See https://pytorch.org/get-started/previous-versions/ to install appropriate version. For example

# CUDA 11.6
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116

Experiment Notebooks

  1. Download dogs and cats dataset from https://www.kaggle.com/c/dogs-vs-cats/data and put train data in _data/dogs_cats/raw/train folder
  2. See 00_prepare_data.ipynb to split the images into cats and dogs folder
  3. See 01_download_networks.ipynb to download the pretrained networks from Torchvision
  4. See 02_generate_fragments.ipynb to generate fragments from the pretrained networks
  5. See 03_stitchnet.ipynb to generate stitchnets
  6. See 04_render_graph.ipynb to create svg images of the network graphs using netron
  7. See 05_eval_original_networks.ipynb for evaluating the original pretrained networks
  8. See 06_finetuning.ipynb to generate the finetuning result
  9. See 07_ensemble.ipynb to generate the ensemble result
  10. See 08_number_of_samples_for_stitching.ipynb for experimenting with varying number of samples to use when stitching
  11. See 09_plot_results.ipynb plot figures of the results for the paper

Installation using conda

Create a new conda env

conda create -n stitchnet python=3.10

Activate stitchnet conda env

conda activate stitchnet

For conda and NVIDIA gpu, please also install for CUDA runtime on onnx

conda install -c conda-forge cudnn

Install poetry

curl -sSL https://install.python-poetry.org | python3 -

Install dependencies using poetry

poetry install

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

stitchnet-0.2.2.tar.gz (26.5 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

stitchnet-0.2.2-py3-none-any.whl (30.1 kB view details)

Uploaded Python 3

File details

Details for the file stitchnet-0.2.2.tar.gz.

File metadata

  • Download URL: stitchnet-0.2.2.tar.gz
  • Upload date:
  • Size: 26.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.1.14 CPython/3.9.7 Linux/5.13.0-28-generic

File hashes

Hashes for stitchnet-0.2.2.tar.gz
Algorithm Hash digest
SHA256 538e8a1896aa8fca8a3e787cef0c09d041753cb229348a598bd47cfbbcec273b
MD5 a8996fe071f6c0b7729cc77dc71cf5df
BLAKE2b-256 59818588c0ddab755e8358d892c442e3c569be8ee73122f94484c21cb9f62a0e

See more details on using hashes here.

File details

Details for the file stitchnet-0.2.2-py3-none-any.whl.

File metadata

  • Download URL: stitchnet-0.2.2-py3-none-any.whl
  • Upload date:
  • Size: 30.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.1.14 CPython/3.9.7 Linux/5.13.0-28-generic

File hashes

Hashes for stitchnet-0.2.2-py3-none-any.whl
Algorithm Hash digest
SHA256 26f5a9b86b3708d59037bd15e8b2507d28a9d4b3507e8fbac19e9e55a0c6fcfc
MD5 69a0c00564d01a29784186af09c51120
BLAKE2b-256 646f293e88f2808d570f6f9134d4cb4ec2a24ef7eb2301f0ced342adfcb764f6

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page