A CNN Channel Pruning System
Project description
Torch Model Compression (tomoco)
tomoco is a deep learning pruning package for PyTorch. It prunes the channels of convolution layers based on the L1 or L2 norm of their weights.
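The sketch below shows what norm-based channel selection usually means for a single `nn.Conv2d`. It is not tomoco's implementation, and `channels_to_prune` is a hypothetical helper. Each output channel $i$ owns one filter $W_i$ of shape `(in_channels, kH, kW)`. The filter is scored by $\|W_i\|_1 = \sum |w|$ (L1) or $\|W_i\|_2 = \sqrt{\sum w^2}$ (L2), and the lowest-scoring fraction `ratio` of channels is marked for removal. The sketch assumes that the package's `0.15` argument is such a ratio.

```python
import torch
import torch.nn as nn


def channels_to_prune(conv: nn.Conv2d, norm: str = "L1", ratio: float = 0.15) -> list[int]:
    """Return indices of the output channels whose filters have the smallest norm.

    Hypothetical helper for illustration only; not part of the tomoco API.
    """
    if norm not in ("L1", "L2"):
        raise ValueError("norm must be 'L1' or 'L2'")
    # conv.weight has shape (out_channels, in_channels, kH, kW);
    # flatten each output channel's filter into one row
    filters = conv.weight.detach().flatten(start_dim=1)
    p = 1 if norm == "L1" else 2
    scores = filters.norm(p=p, dim=1)  # one score per output channel
    n_prune = int(ratio * conv.out_channels)
    # Ascending sort: the weakest (smallest-norm) filters come first
    return sorted(torch.argsort(scores)[:n_prune].tolist())


torch.manual_seed(0)
conv = nn.Conv2d(3, 64, kernel_size=3)
print(channels_to_prune(conv, "L1", 0.15))  # int(0.15 * 64) = 9 indices
```

Ranking filters within each layer keeps the selection independent of how weight scales differ between layers. That is one common design choice. A global ranking across layers is another.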
Package install:
pip install tomoco
Channel Pruning based on Norm:
import torch
import torch.nn as nn
import timm
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from tomoco import pruner

class config:
    lr = 0.001  # Learning rate used for fine-tuning
    n_classes = 10  # Number of output classes
    epochs = 5  # Number of training epochs
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # Pick the available device
    batch_size = 64  # Batch size
    optim = 0  # Placeholder; replaced by an optimizer once the model is created
    training = 1  # Set to 1 to fine-tune the model after pruning
    criterion = nn.CrossEntropyLoss()  # Loss function used during fine-tuning

train_dataset = CIFAR10(root='data/', download=True, transform=transforms.ToTensor())
valid_dataset = CIFAR10(root='data/', download=True, train=False, transform=transforms.ToTensor())

# Define the data loaders
train_loader = DataLoader(dataset=train_dataset, batch_size=config.batch_size, shuffle=True)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=config.batch_size, shuffle=False)

# Use a custom model or pull a pretrained model from a repository such as timm
res50 = timm.create_model("resnet50", pretrained=True).to(config.device)
config.optim = torch.optim.Adam(res50.parameters(), lr=config.lr, amsgrad=True)

# Arguments: model, model name, config, input shape (C, H, W), norm ("L1" or "L2"),
# 0.15 (presumably the fraction of channels to prune), training and validation loaders
pruner(res50, "res50", config, (3, 64, 64), "L1", 0.15, train_loader, valid_loader)
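Removing a channel means shrinking two layers. The first layer loses output channels, and the layer that consumes its output loses the matching input channels. The sketch below shows this for two directly connected convolutions. `prune_conv_pair` is a hypothetical helper, not tomoco's API. It assumes `groups=1` and no BatchNorm or residual connection between the layers. A real ResNet-50 has both, and this page does not document how tomoco handles them.

```python
import torch
import torch.nn as nn


def prune_conv_pair(conv: nn.Conv2d, next_conv: nn.Conv2d, keep: list[int]):
    """Keep only the `keep` output channels of `conv` and the matching input
    channels of `next_conv`. Hypothetical helper; assumes groups=1 and no
    BatchNorm or residual connection between the two layers."""
    idx = torch.tensor(keep, dtype=torch.long, device=conv.weight.device)
    new_conv = nn.Conv2d(conv.in_channels, len(keep), conv.kernel_size,
                         stride=conv.stride, padding=conv.padding,
                         dilation=conv.dilation, bias=conv.bias is not None,
                         device=conv.weight.device)
    new_next = nn.Conv2d(len(keep), next_conv.out_channels, next_conv.kernel_size,
                         stride=next_conv.stride, padding=next_conv.padding,
                         dilation=next_conv.dilation, bias=next_conv.bias is not None,
                         device=next_conv.weight.device)
    with torch.no_grad():
        # Output channels are dimension 0 of the first layer's weight
        new_conv.weight.copy_(conv.weight[idx])
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias[idx])
        # Input channels are dimension 1 of the next layer's weight
        new_next.weight.copy_(next_conv.weight[:, idx])
        if next_conv.bias is not None:
            new_next.bias.copy_(next_conv.bias)
    return new_conv, new_next


torch.manual_seed(0)
c1 = nn.Conv2d(3, 64, 3, padding=1)
c2 = nn.Conv2d(64, 128, 3, padding=1)
p1, p2 = prune_conv_pair(c1, c2, keep=list(range(0, 64, 2)))  # keep every other channel
print(p1, p2)
```

The pruned pair computes the same result as the original pair with the removed channels zeroed out. Accuracy usually drops somewhat anyway, which is why the example above fine-tunes after pruning (`training = 1`).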
Download files
Source Distribution
tomoco-0.0.11.tar.gz (5.9 kB)
Built Distribution
tomoco-0.0.11-py3-none-any.whl (5.9 kB)
File details
Details for the file tomoco-0.0.11.tar.gz.
File metadata
- Download URL: tomoco-0.0.11.tar.gz
- Upload date:
- Size: 5.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.9
File hashes
Algorithm | Hash digest
---|---
SHA256 | d80d4f77b3bcedffbdbc54157d2ea9b56f885395a0ec46a19f4e04eb2e3d747d
MD5 | b97cff9bb6177b04a84cf072e011f2b4
BLAKE2b-256 | 0052b788961a2751088c6bc14f5748d78b37a704a2291d0b000ceb9cbaba5862
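The hashes above let you check that a downloaded file was not corrupted or altered. The following is a minimal sketch using Python's standard `hashlib`. It assumes the source archive was saved as `tomoco-0.0.11.tar.gz` in the working directory; `sha256_of` and `verify` are illustrative names.

```python
import hashlib
from pathlib import Path

# SHA256 digest published above for tomoco-0.0.11.tar.gz
EXPECTED_SHA256 = "d80d4f77b3bcedffbdbc54157d2ea9b56f885395a0ec46a19f4e04eb2e3d747d"


def sha256_of(path, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected: str) -> bool:
    """Compare a file's SHA256 digest with an expected hex string (case-insensitive)."""
    return sha256_of(path) == expected.lower()


if __name__ == "__main__":
    archive = Path("tomoco-0.0.11.tar.gz")  # assumed download location
    if archive.exists():
        print("OK" if verify(archive, EXPECTED_SHA256) else "MISMATCH")
```

pip can also enforce a digest at install time through hash-checking mode. To do this, list the package in a requirements file as `tomoco==0.0.11 --hash=sha256:<digest>` and install with `pip install --require-hashes -r requirements.txt`.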
File details
Details for the file tomoco-0.0.11-py3-none-any.whl.
File metadata
- Download URL: tomoco-0.0.11-py3-none-any.whl
- Upload date:
- Size: 5.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.9
File hashes
Algorithm | Hash digest
---|---
SHA256 | b4601ca505cd241f7e2b62e7decfb42d95db7bbcf78aef0d80ca457734f84b1f
MD5 | 2e696c4dc8ab53cf3ec641e466a39639
BLAKE2b-256 | 82b9d863ab05fe60985204307a9654858da983a3e3aa8119e5ae632cef4fd0e3