Group PyTorch Parameters according to Rules
torch-parameter-groups
Installation
Requires Python 3.6 or later.
pip install torch-parameter-groups
Usage
import torch
import torch.nn as nn
import torch_basic_models
import torch_parameter_groups

model = torch_basic_models.MobileNetV2.factory()

# Build an SGD optimizer whose parameter groups are defined by rules.
optimizer = torch_parameter_groups.optimizer_factory(
    model=model,
    config={
        'type': 'SGD',
        # Optimizer-wide keyword arguments.
        'kwargs': {
            'momentum': 0.9,
            'nesterov': True,
            'weight_decay': 0.0001,
        },
        'rules': [
            # Parameters matched by the name 'weight' use weight_decay=0.
            {
                'param_name_list': ['weight'],
                'kwargs': {
                    'weight_decay': 0,
                },
            },
            # Empty rule (no name list, no overrides).
            {},
        ],
    },
)

criterion = nn.CrossEntropyLoss()
output = model(torch.randn(1, 3, 224, 224))  # one dummy 224x224 RGB image
loss = criterion(output, torch.tensor([0]))  # class index 0 as an int64 target
optimizer.zero_grad()
loss.backward()
optimizer.step()
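The rules above map naturally onto PyTorch's own parameter groups: `torch.optim` optimizers accept a list of dicts, each with a `params` entry plus optional per-group keyword arguments that override the constructor defaults. The sketch below shows one way such rules could be turned into groups. The helper `build_param_groups` is hypothetical, and its matching semantics (a parameter matches when its dotted name ends with an entry in `param_name_list`, the first matching rule wins, unmatched parameters keep the optimizer defaults) are assumptions, not the documented behavior of torch-parameter-groups.

```python
from typing import Any, Dict, Iterable, List, Tuple


def build_param_groups(
    named_params: Iterable[Tuple[str, Any]],
    rules: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Hypothetical helper: split (name, param) pairs into optimizer param groups.

    Assumptions (not taken from torch-parameter-groups itself):
    - a parameter matches a rule when its dotted name ends with one of the
      strings in the rule's 'param_name_list';
    - the first matching rule wins;
    - a rule's 'kwargs' become per-group overrides;
    - unmatched parameters go into a final group with no overrides, so they
      use the optimizer's constructor defaults.
    """
    rule_groups = [{"params": [], **rule.get("kwargs", {})} for rule in rules]
    default_group: Dict[str, Any] = {"params": []}
    for name, param in named_params:
        # Skip frozen parameters (torch tensors expose requires_grad).
        if not getattr(param, "requires_grad", True):
            continue
        for rule, group in zip(rules, rule_groups):
            suffixes = rule.get("param_name_list", [])
            if any(name.endswith(suffix) for suffix in suffixes):
                group["params"].append(param)
                break
        else:
            # No rule matched: this parameter keeps the optimizer defaults.
            default_group["params"].append(param)
    # Drop empty groups so the optimizer only receives groups with parameters.
    return [g for g in rule_groups + [default_group] if g["params"]]


if __name__ == "__main__":
    # Stand-in names and values; with PyTorch you would pass model.named_parameters().
    fake_params = [("0.weight", "w0"), ("0.bias", "b0"), ("1.weight", "w1")]
    rules = [{"param_name_list": ["weight"], "kwargs": {"weight_decay": 0}}]
    for group in build_param_groups(fake_params, rules):
        print(group)
    # {'params': ['w0', 'w1'], 'weight_decay': 0}
    # {'params': ['b0']}
```

With a real model you would pass `model.named_parameters()` and hand the result to, for example, `torch.optim.SGD(groups, lr=0.1, momentum=0.9, nesterov=True, weight_decay=1e-4)`, where `lr=0.1` is an arbitrary placeholder. The group that sets `weight_decay` keeps its own value, while the default group inherits `weight_decay=1e-4` from the constructor.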
Source Distribution
Hashes for torch-parameter-groups-0.0.5.post1.tar.gz:
Algorithm | Hash digest
---|---
SHA256 | ea9d3aa3290cd5cde56b5c4df7a24178d62ad0fd196c37cdd6e3ec3f17ae43ae
MD5 | 3c3f28545650d193ede2f3b51ba7aab4
BLAKE2b-256 | 2a4a026bb9d13dd766b8ac35c6f938d01f2b85720923a7caa6d16aabdc59fbf0
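To check a downloaded copy of the source distribution against the published SHA256 digest, a minimal sketch using Python's standard `hashlib` module might look like this. The helper names `file_digest_hex` and `verify_download`, and the assumption that the tarball sits in the current directory, are illustrative only.

```python
import hashlib
import os

# SHA256 digest published for torch-parameter-groups-0.0.5.post1.tar.gz.
EXPECTED_SHA256 = "ea9d3aa3290cd5cde56b5c4df7a24178d62ad0fd196c37cdd6e3ec3f17ae43ae"


def file_digest_hex(path, algorithm="sha256", chunk_size=1 << 16):
    """Return the hex digest of a file, reading it in chunks to bound memory use."""
    h = hashlib.new(algorithm)
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify_download(path, expected_sha256):
    """Return True if the file's SHA256 matches the expected digest (case-insensitive)."""
    return file_digest_hex(path, "sha256") == expected_sha256.strip().lower()


if __name__ == "__main__":
    # Hypothetical location: assumes the tarball was downloaded into the current directory.
    tarball = "torch-parameter-groups-0.0.5.post1.tar.gz"
    if os.path.exists(tarball):
        print(verify_download(tarball, EXPECTED_SHA256))
    else:
        print(f"{tarball} not found")
```

Reading in fixed-size chunks keeps memory use constant regardless of file size; passing `"md5"` or `"blake2b"` as `algorithm` works the same way, though the BLAKE2b-256 digest above would need `hashlib.blake2b(digest_size=32)` rather than the default 64-byte BLAKE2b.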