StitchNet: Composing Neural Networks from Pre-Trained Fragments
Installation
```bash
pip install stitchnet
```
Usage
```python
import stitchnet

# prepare the stitching data D
from stitchnet import load_hf_dataset
# load the beans dataset from Hugging Face
dataset_train, dataset_val = load_hf_dataset('beans', train_split='validation', val_split='test', label_column='labels', seed=47)

# generate stitchnets
import numpy as np
from tqdm import tqdm
from stitchnet import generate
# use the first 32 training images as the stitching data
stitching_dataset = np.vstack([x['pixel_values'] for x in tqdm(dataset_train.select(range(32)))])
score, net = generate(stitching_dataset, threshold=0.8, totalThreshold=0, maxDepth=10, K=2, sample=True)

# MACs and parameter count
net.get_macs_params()  # {'macs': 4488343528.0, 'params': 25653096}

# save as ONNX
net.save_onnx('./_data/net')  # saves to ./_data/net.onnx

# draw the stitchnet
net.draw_svg('./_data/net')  # saves to ./_data/net.svg

# train a classifier
net.fit(dataset_train, label_column="labels")

# use it for prediction
net.predict_files(['./_results/test.jpg'])  # [{'score': [0.8, 0.2, 0.0], 'label': 0}]

# evaluate on the validation dataset
net.evaluate_dataset(dataset_val, label_column='labels')  # {'accuracy': 0.7421875}
```
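The README does not spell out how `generate` decides whether two fragments are compatible, or what `threshold` and `totalThreshold` are compared against. For intuition only, the sketch below shows one widely used way to compare the activations of two networks on the same stitching inputs, linear centered kernel alignment (CKA), together with a least-squares linear map from one fragment's output to the next fragment's input. Whether stitchnet uses exactly these formulas is an assumption; `linear_cka`, `fit_linear_stitch` and `apply_linear_stitch` are hypothetical helpers, not part of the stitchnet API, and activations are assumed to be flattened to (samples, features) matrices.

```python
import numpy as np


def linear_cka(x: np.ndarray, y: np.ndarray) -> float:
    """Linear CKA between activations x (n, p1) and y (n, p2) for the same n inputs.

    Returns a value in [0, 1]; 1 means the representations match up to an
    orthogonal transform and isotropic scaling.
    """
    x = x - x.mean(axis=0, keepdims=True)  # center each feature column
    y = y - y.mean(axis=0, keepdims=True)
    cross = np.linalg.norm(y.T @ x, ord="fro") ** 2
    norm_x = np.linalg.norm(x.T @ x, ord="fro")
    norm_y = np.linalg.norm(y.T @ y, ord="fro")
    return float(cross / (norm_x * norm_y))


def fit_linear_stitch(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Least-squares weights W (last row is the bias) such that [x, 1] @ W approximates y."""
    x_aug = np.hstack([x, np.ones((x.shape[0], 1))])
    weights, *_ = np.linalg.lstsq(x_aug, y, rcond=None)
    return weights


def apply_linear_stitch(x: np.ndarray, weights: np.ndarray) -> np.ndarray:
    """Map activations x through the fitted linear stitching layer."""
    return np.hstack([x, np.ones((x.shape[0], 1))]) @ weights


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    acts_a = rng.normal(size=(32, 16))  # 32 stitching samples, 16 features
    q, _ = np.linalg.qr(rng.normal(size=(16, 16)))  # random orthogonal matrix
    print(linear_cka(acts_a, 3.0 * acts_a @ q))  # close to 1.0

    acts_b = acts_a @ rng.normal(size=(16, 8)) + 0.5  # exact affine function of acts_a
    w = fit_linear_stitch(acts_a, acts_b)
    print(np.allclose(apply_linear_stitch(acts_a, w), acts_b))  # True
```

In this sketch a CKA near 1 means the two activation spaces are closely related, which is the situation where a simple linear stitch has the best chance of working; 32 samples, as in the usage example, are enough to fit a stitch between small feature spaces but not necessarily large ones.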
CUDA
See https://pytorch.org/get-started/previous-versions/ to install the PyTorch build that matches your CUDA version. For example, for CUDA 11.6:
```bash
# CUDA 11.6
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116
```
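The `+cu116` suffix in the example above is the local version tag that PyTorch's wheel index uses to mark a build for CUDA 11.6. To check which build actually got installed, compare `torch.__version__` with `torch.version.cuda` and `torch.cuda.is_available()`. The helper below is a hypothetical name, not part of stitchnet or PyTorch; it decodes such a tag, assuming the usual pattern of `cu` followed by the major version and a one-digit minor version.

```python
import re
from typing import Optional


def cuda_version_from_tag(version: str) -> Optional[str]:
    """Decode the CUDA version from a PyTorch wheel version such as '1.12.1+cu116'.

    Returns None for CPU-only or untagged builds. Assumes the tag is 'cu' followed
    by the major version and a single-digit minor version (cu116 -> 11.6).
    """
    match = re.search(r"\+cu(\d{2,})$", version)
    if match is None:
        return None
    digits = match.group(1)
    return f"{int(digits[:-1])}.{digits[-1]}"


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        print("PyTorch is not installed")
    else:
        print("wheel tag:", cuda_version_from_tag(torch.__version__))
        print("compiled against CUDA:", torch.version.cuda)
        print("GPU available:", torch.cuda.is_available())
```

If the wheel tag reports a CUDA version but `torch.cuda.is_available()` is False, the usual suspects are a missing NVIDIA driver or one too old for that CUDA version.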
Experiment Notebooks
- Download the Dogs vs. Cats dataset from https://www.kaggle.com/c/dogs-vs-cats/data and put the training images in the _data/dogs_cats/raw/train folder
- See 00_prepare_data.ipynb to split the images into cats and dogs folders
- See 01_download_networks.ipynb to download the pretrained networks from Torchvision
- See 02_generate_fragments.ipynb to generate fragments from the pretrained networks
- See 03_stitchnet.ipynb to generate stitchnets
- See 04_render_graph.ipynb to render SVG images of the network graphs using Netron
- See 05_eval_original_networks.ipynb to evaluate the original pretrained networks
- See 06_finetuning.ipynb to generate the fine-tuning results
- See 07_ensemble.ipynb to generate the ensemble results
- See 08_number_of_samples_for_stitching.ipynb for experiments that vary the number of samples used for stitching
- See 09_plot_results.ipynb to plot the figures of the results for the paper
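The Kaggle training archive names its images cat.<n>.jpg and dog.<n>.jpg, so splitting them into class folders only needs the filename prefix. The sketch below is not the code in 00_prepare_data.ipynb but a hypothetical stand-in; `split_by_class` is an illustrative name, and the output folder passed as `out_dir` is an assumption, since the README only names the raw input folder.

```python
import shutil
from pathlib import Path
from typing import Dict

# Maps the Kaggle filename prefix to a class folder name.
CLASS_FOLDERS = {"cat": "cats", "dog": "dogs"}


def split_by_class(raw_dir: Path, out_dir: Path) -> Dict[str, int]:
    """Copy raw_dir/cat.*.jpg into out_dir/cats and raw_dir/dog.*.jpg into out_dir/dogs.

    Files that do not start with a known prefix are skipped. Returns the number
    of files copied per class folder.
    """
    counts = {folder: 0 for folder in CLASS_FOLDERS.values()}
    for path in sorted(raw_dir.glob("*.jpg")):
        prefix = path.name.split(".", 1)[0]  # 'cat' or 'dog' for Kaggle files
        folder = CLASS_FOLDERS.get(prefix)
        if folder is None:
            continue
        dest = out_dir / folder
        dest.mkdir(parents=True, exist_ok=True)
        shutil.copy2(path, dest / path.name)
        counts[folder] += 1
    return counts
```

For example, `split_by_class(Path("_data/dogs_cats/raw/train"), Path("_data/dogs_cats/split"))`, where the split folder is a hypothetical choice. Copying rather than moving keeps the raw folder intact, so the split can be re-run safely.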
Installation using conda
Create a new conda environment
```bash
conda create -n stitchnet python=3.10
```
Activate the stitchnet conda environment
```bash
conda activate stitchnet
```
When using conda with an NVIDIA GPU, also install cuDNN so that ONNX Runtime can run on CUDA
```bash
conda install -c conda-forge cudnn
```
Install Poetry
```bash
curl -sSL https://install.python-poetry.org | python3 -
```
Install the dependencies using Poetry
```bash
poetry install
```
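After installing cuDNN, one way to confirm that ONNX Runtime can actually use the GPU is to ask which execution providers it offers and which ones a session ends up using. This is a sketch, assuming the onnxruntime-gpu package is installed and a model has been exported; the path ./_data/net.onnx follows from the `save_onnx` call in the usage example and is an assumption, as are the helper names `preferred_providers` and `open_session`.

```python
from typing import List, Sequence

# Try the GPU first, then fall back to the CPU.
PROVIDER_PREFERENCE = ["CUDAExecutionProvider", "CPUExecutionProvider"]


def preferred_providers(available: Sequence[str]) -> List[str]:
    """Return the preferred providers that are actually available, in priority order."""
    return [p for p in PROVIDER_PREFERENCE if p in available]


def open_session(model_path: str):
    """Create an ONNX Runtime session and warn if it is not running on CUDA."""
    import onnxruntime as ort  # imported here so preferred_providers works without it

    providers = preferred_providers(ort.get_available_providers())
    session = ort.InferenceSession(model_path, providers=providers)
    # CUDA can be listed as available yet fail to load (e.g. cuDNN missing);
    # the session then falls back to CPU, which get_providers() reveals.
    if "CUDAExecutionProvider" not in session.get_providers():
        print("Running on CPU: check the CUDA and cuDNN installation")
    return session
```

Usage: `session = open_session("./_data/net.onnx")`. Checking `session.get_providers()` rather than only `get_available_providers()` matters because a provider can be compiled in but still fail to initialize.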
Download files
Source Distribution
- stitchnet-0.2.0.tar.gz (25.6 kB)

Built Distribution
- stitchnet-0.2.0-py3-none-any.whl (29.6 kB)
Hashes for stitchnet-0.2.0-py3-none-any.whl

Algorithm | Hash digest
---|---
SHA256 | 70868b362a4b72613116b4de8e58396ba0c1c4c0197f8d9a2006a7d6a1b4a541
MD5 | 84df58b5008ecea169487c8e347fb7ac
BLAKE2b-256 | 4ca3c50ff162d94869eb9786846b662d466d6adbcd4c3167d79d1e332c64fd21
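To check a downloaded wheel against these digests, hash it locally and compare. A minimal sketch using Python's hashlib, assuming the wheel was saved under its published file name; BLAKE2b-256 corresponds to BLAKE2b with a 32-byte digest. The helper names `new_hasher`, `hash_file` and `verify` are hypothetical.

```python
import hashlib
from pathlib import Path
from typing import Mapping

# Published digests for stitchnet-0.2.0-py3-none-any.whl (from the hash table).
EXPECTED = {
    "sha256": "70868b362a4b72613116b4de8e58396ba0c1c4c0197f8d9a2006a7d6a1b4a541",
    "md5": "84df58b5008ecea169487c8e347fb7ac",
    "blake2b-256": "4ca3c50ff162d94869eb9786846b662d466d6adbcd4c3167d79d1e332c64fd21",
}


def new_hasher(algorithm: str):
    """Return a hashlib object for one of the algorithm names used in EXPECTED."""
    if algorithm == "blake2b-256":
        return hashlib.blake2b(digest_size=32)
    return hashlib.new(algorithm)


def hash_file(path: Path, algorithm: str, chunk_size: int = 1 << 16) -> str:
    """Hash a file in chunks so large files do not have to fit in memory."""
    hasher = new_hasher(algorithm)
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            hasher.update(chunk)
    return hasher.hexdigest()


def verify(path: Path, expected: Mapping[str, str] = EXPECTED) -> bool:
    """True if every listed digest of the file matches."""
    return all(hash_file(path, algo) == digest for algo, digest in expected.items())
```

For example, `verify(Path("stitchnet-0.2.0-py3-none-any.whl"))` should return True for an intact download. These digests cover only the wheel; the source distribution has its own hashes, which are not listed here.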