Small utilities for PyTorch
Weaver PyTorch 🧶🧵
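from weaver import get_classifier, get_optimizer, get_scheduler, get_transforms
from torchvision.transforms import Compose

# ResNet-50 from torchvision, built by name
model = get_classifier('torchvision', 'resnet50')
# SGD over the model's parameters with lr=1e-3
optim = get_optimizer(model.parameters(), name='SGD', lr=1e-3)
# Cosine annealing of the learning rate over T_max=10 scheduler steps
sched = get_scheduler(optim, name='CosineAnnealingLR', T_max=10)
# Transform pipeline built from a list of config dicts;
# 'cifar10' presumably selects preset CIFAR-10 normalization statistics
transform = Compose(get_transforms([
    {'name': 'RandAugment', 'num_ops': 2, 'magnitude': 10},
    {'name': 'ToTensor'},
    {'name': 'Normalize', 'mean': 'cifar10', 'std': 'cifar10'},
]))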
Install
pip install weaver-pytorch-rnx0dvmdxk
API
get_classifier(src, name, **kwargs)
- weaver: wide_resnet{depth}_{width}, preact_resnet{depth}
- torchvision: https://pytorch.org/vision/stable/models.html
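The `{depth}` and `{width}` placeholders suggest that weaver model names are filled in with integers, for example `wide_resnet28_10` (depth 28, widening factor 10) or `preact_resnet18`. Assuming that naming scheme, the hypothetical helper below (not part of weaver) shows how such a name decomposes; which depths and widths weaver actually accepts is not listed here.

```python
import re

# Hypothetical helper, not part of weaver: decode names that follow the
# wide_resnet{depth}_{width} and preact_resnet{depth} patterns.
_PATTERNS = {
    "wide_resnet": re.compile(r"wide_resnet(\d+)_(\d+)"),
    "preact_resnet": re.compile(r"preact_resnet(\d+)"),
}


def parse_weaver_name(name: str) -> dict:
    """Split a weaver-style model name into its family and integer settings."""
    for family, pattern in _PATTERNS.items():
        match = pattern.fullmatch(name)
        if match is None:
            continue
        if family == "wide_resnet":
            depth, width = map(int, match.groups())
            return {"family": family, "depth": depth, "width": width}
        return {"family": family, "depth": int(match.group(1))}
    raise ValueError(f"unrecognized weaver model name: {name!r}")


print(parse_weaver_name("wide_resnet28_10"))
```

Names from the torchvision source, such as `resnet50`, do not follow these patterns and are rejected by this helper.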
get_optimizer(params, name, **kwargs)
- PyTorch: https://pytorch.org/docs/stable/optim.html#algorithms
- AdaBelief: https://github.com/juntang-zhuang/Adabelief-Optimizer
get_scheduler(optim, name, **kwargs)
- PyTorch: https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
- Custom: HalfCosineAnnealingLR
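PyTorch documents `CosineAnnealingLR` with the closed form η_t = η_min + ½(η_max − η_min)(1 + cos(π · T_cur / T_max)), where η_max is the optimizer's initial lr and η_min defaults to 0. The sketch below evaluates that formula for the quick-start settings (lr=1e-3, T_max=10). It is a plain-Python illustration rather than the scheduler class itself, and it does not cover the custom `HalfCosineAnnealingLR`, whose schedule is not specified here.

```python
import math


def cosine_annealing_lr(step: int, base_lr: float, t_max: int, eta_min: float = 0.0) -> float:
    """Closed-form CosineAnnealingLR value after `step` scheduler steps (no restarts)."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max))


# Quick-start settings: lr=1e-3, T_max=10, default eta_min=0.
schedule = [cosine_annealing_lr(t, base_lr=1e-3, t_max=10) for t in range(11)]
print(schedule)
```

With these settings the rate falls from 1e-3 at step 0 to about 5e-4 at step 5 and to 0 at step 10.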
get_transform(name, **kwargs)
- PyTorch: https://pytorch.org/vision/stable/transforms.html
- Custom: AllRandAugment, Cutout, Contain
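Cutout (DeVries and Taylor, 2017) masks a random square region of the input image. The NumPy sketch below shows the usual form of the technique: a patch centered on a uniformly random pixel and clipped at the border. The parameters, fill value, and input type (PIL image or tensor) of weaver's own `Cutout` are not documented here, so `size`, the zero fill, and the H x W (x C) array input are all assumptions.

```python
import numpy as np


def cutout(image: np.ndarray, size: int, rng: np.random.Generator) -> np.ndarray:
    """Return a copy of `image` (H x W or H x W x C) with one square patch set to 0.

    The patch is centered on a uniformly random pixel and clipped at the image
    border, so near the edges fewer than size * size pixels are masked.
    """
    height, width = image.shape[:2]
    # Random center, drawn uniformly over all pixel positions.
    cy = int(rng.integers(height))
    cx = int(rng.integers(width))
    # Top-left corner of the unclipped patch, then clip to the image bounds.
    top, left = cy - size // 2, cx - size // 2
    y0, y1 = max(top, 0), min(top + size, height)
    x0, x1 = max(left, 0), min(left + size, width)
    out = image.copy()  # leave the caller's array untouched
    out[y0:y1, x0:x1] = 0
    return out


rng = np.random.default_rng(0)
image = np.ones((32, 32, 3), dtype=np.float32)
augmented = cutout(image, size=8, rng=rng)
```

Letting the center fall anywhere, including near the border, means some patches are only partly inside the image, which is the variant the original Cutout paper describes.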
get_transforms(kwargs_list)
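- returns a list of transforms, one per dict in kwargs_list; each dict gives the transform's 'name' plus its keyword arguments, as in the example above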
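Judging from the example at the top, each dict carries a `'name'` key plus keyword arguments. A minimal sketch of that pattern, assuming `get_transforms` resolves each name to a class and passes the remaining keys to its constructor, looks like this; `REGISTRY`, `build_transforms`, and the two toy classes are hypothetical stand-ins, and weaver's own name lookup (including presets such as `'cifar10'` for `Normalize`) is not reproduced.

```python
from typing import Any, Callable, Dict, List


# Toy callables standing in for real transforms (hypothetical).
class ToyScale:
    def __init__(self, factor: float = 1.0):
        self.factor = factor

    def __call__(self, x: float) -> float:
        return x * self.factor


class ToyShift:
    def __init__(self, offset: float = 0.0):
        self.offset = offset

    def __call__(self, x: float) -> float:
        return x + self.offset


# Hypothetical name -> constructor registry.
REGISTRY: Dict[str, Callable[..., Any]] = {"ToyScale": ToyScale, "ToyShift": ToyShift}


def build_transforms(kwargs_list: List[Dict[str, Any]]) -> List[Any]:
    """Instantiate one object per dict: 'name' picks the class, other keys are kwargs."""
    transforms = []
    for spec in kwargs_list:
        spec = dict(spec)  # copy so the caller's config is not mutated
        name = spec.pop("name")
        transforms.append(REGISTRY[name](**spec))
    return transforms


config = [{"name": "ToyScale", "factor": 2.0}, {"name": "ToyShift", "offset": 1.0}]
pipeline = build_transforms(config)
```

Keeping the configuration as plain dicts makes an augmentation pipeline easy to store in a config file; wrapping the returned list in `Compose`, as the quick-start does, applies the items in order.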
Others
weaver.optimizers.exclude_wd(module: Module, skip_list=['bias', 'bn'])
weaver.optimizers.EMAModel(model: Module, alpha: float)
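Neither helper is described further here. A common reading of the signatures is that `exclude_wd` splits a module's parameters into two optimizer parameter groups, turning weight decay off for parameters whose names contain an entry of `skip_list` (biases and batch-norm layers by default), and that `EMAModel` keeps an exponential moving average of the model's weights. The plain-Python sketch below illustrates both ideas under those assumptions. `split_weight_decay` and `ema_update` are hypothetical names, and the convention that `alpha` weights the running average (for example 0.999) is also an assumption.

```python
from typing import Any, Dict, Iterable, List, Sequence, Tuple


def split_weight_decay(
    named_params: Iterable[Tuple[str, Any]],
    weight_decay: float,
    skip_list: Sequence[str] = ("bias", "bn"),
) -> List[Dict[str, Any]]:
    """Build parameter groups, with weight decay disabled for names matching skip_list."""
    decay, no_decay = [], []
    for name, param in named_params:
        # Substring match, e.g. "conv1.bias" or "bn1.weight" are skipped.
        if any(key in name for key in skip_list):
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": no_decay, "weight_decay": 0.0},
        {"params": decay, "weight_decay": weight_decay},
    ]


def ema_update(shadow: Dict[str, float], current: Dict[str, float], alpha: float) -> None:
    """Update in place: shadow <- alpha * shadow + (1 - alpha) * current."""
    for key, value in current.items():
        shadow[key] = alpha * shadow[key] + (1 - alpha) * value


named = [("conv1.weight", "w1"), ("conv1.bias", "b1"), ("bn1.weight", "g1"), ("fc.weight", "w2")]
groups = split_weight_decay(named, weight_decay=5e-4)
shadow = {"w": 0.0}
ema_update(shadow, {"w": 1.0}, alpha=0.9)
```

With PyTorch, `split_weight_decay(model.named_parameters(), 5e-4)` would produce groups in the list-of-dicts form that `torch.optim` optimizers accept as their first argument.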
Download files
Hashes for weaver-pytorch-rnx0dvmdxk-0.0.2.tar.gz
Algorithm | Hash digest
---|---
SHA256 | a8d600b406902360df321a86169b5fbc7ef5cc619dd24afa19eac42e5d337ca9
MD5 | a2434db66af8883286b16e0a93e1c5e2
BLAKE2b-256 | 4766f5a57b499855203d5c3e0e95fcd4374d187ddd379e5316cda1e32685975f
Hashes for weaver_pytorch_rnx0dvmdxk-0.0.2-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | cd1ef52052d7da879e45cf68efd3bb24b67c85f8b5b0f6ac563e06588c058e3e
MD5 | 6cc2b2c3efd5aa48d3c4f0add652682c
BLAKE2b-256 | 454dc2038ad38bc095cac0575641708008b6535a8d11851fd802dcbe62f5f09c