Easy high-level library for training neural networks in PyTorch.
Project description
___/\\\\\\\\\_______/\\\\\\\\\_________/\\\\\\\\\\\\__/\\\________/\\\_____/\\\\\\\\\\\___
__/\\\\\\\\\\\\\___/\\\///////\\\_____/\\\//////////__\/\\\_______\/\\\___/\\\/////////\\\_
__/\\\/////////\\\_\/\\\_____\/\\\____/\\\_____________\/\\\_______\/\\\__\//\\\______\///__
_\/\\\_______\/\\\_\/\\\\\\\\\\\/____\/\\\____/\\\\\\\_\/\\\_______\/\\\___\////\\\_________
_\/\\\\\\\\\\\\\\\_\/\\\//////\\\____\/\\\___\/////\\\_\/\\\_______\/\\\______\////\\\______
_\/\\\/////////\\\_\/\\\____\//\\\___\/\\\_______\/\\\_\/\\\_______\/\\\_________\////\\\___
_\/\\\_______\/\\\_\/\\\_____\//\\\__\/\\\_______\/\\\_\//\\\______/\\\___/\\\______\//\\\__
_\/\\\_______\/\\\_\/\\\______\//\\\_\//\\\\\\\\\\\\/___\///\\\\\\\\\/___\///\\\\\\\\\\\/__
_\///________\///__\///________\///___\////////////_______\/////////_______\///////////___
Argus is an easy-to-use, flexible library for training neural networks in PyTorch.
Documentation
https://pytorch-argus.readthedocs.io
Installation
From pip:
pip install pytorch-argus
From source:
pip install -U git+https://github.com/lRomul/argus.git
Example
A simple image classification example using `create_model` from pytorch-image-models (timm):
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToTensor, Normalize
import timm
from argus import Model
from argus.callbacks import MonitorCheckpoint, EarlyStopping, ReduceLROnPlateau
def get_data_loaders(batch_size):
    """Create the MNIST training and validation data loaders.

    Both splits are stored under the ``mnist_data`` directory and share one
    preprocessing pipeline: ``ToTensor`` followed by ``Normalize`` with the
    constants ``0.1307`` / ``0.3081`` (presumably the MNIST pixel mean and
    standard deviation).

    Args:
        batch_size: Samples per training batch. The validation loader uses
            twice this batch size and does not shuffle.

    Returns:
        A ``(train_loader, val_loader)`` tuple of ``DataLoader`` objects.
    """
    preprocess = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])

    # Only the training split asks for a download, so it is built first.
    # NOTE(review): the validation split (download=False) presumably relies
    # on that download also fetching the test files — confirm with the
    # torchvision MNIST docs.
    train_split = MNIST(download=True, root="mnist_data",
                        transform=preprocess, train=True)
    val_split = MNIST(download=False, root="mnist_data",
                      transform=preprocess, train=False)

    return (
        DataLoader(train_split, batch_size=batch_size, shuffle=True),
        DataLoader(val_split, batch_size=batch_size * 2, shuffle=False),
    )
class TimmModel(Model):
    """Argus ``Model`` whose network factory is ``timm.create_model``.

    NOTE(review): the ``'nn_module'`` entry of the params dict passed to the
    constructor is presumably forwarded to ``timm.create_model`` as keyword
    arguments — confirm against the argus ``Model`` documentation.
    """

    # A factory callable (not an nn.Module subclass) used to build the network.
    nn_module = timm.create_model
if __name__ == "__main__":
    # Train an EfficientNet-B0 (built via timm) on MNIST with argus.
    # torch is imported locally because it is only needed here, to choose a
    # device.
    import torch

    train_loader, val_loader = get_data_loaders(batch_size=256)

    params = {
        # Keyword arguments for timm.create_model.
        'nn_module': {
            'model_name': 'tf_efficientnet_b0_ns',
            'pretrained': False,
            'num_classes': 10,  # ten digit classes
            'in_chans': 1,      # single-channel (grayscale) input
            'drop_rate': 0.2,
            'drop_path_rate': 0.2
        },
        'optimizer': ('Adam', {'lr': 0.01}),
        'loss': 'CrossEntropyLoss',
        # Fall back to the CPU so the example also runs on machines without
        # CUDA instead of failing on a hard-coded 'cuda' device.
        'device': 'cuda' if torch.cuda.is_available() else 'cpu'
    }
    model = TimmModel(params)

    # Every callback tracks validation accuracy: keep at most 3 checkpoints,
    # stop early after `patience` epochs without improvement, and scale the
    # learning rate by `factor` when progress plateaus.
    callbacks = [
        MonitorCheckpoint(dir_path='mnist', monitor='val_accuracy', max_saves=3),
        EarlyStopping(monitor='val_accuracy', patience=9),
        ReduceLROnPlateau(monitor='val_accuracy', factor=0.5, patience=3)
    ]

    model.fit(train_loader,
              val_loader=val_loader,
              num_epochs=50,
              metrics=['accuracy'],
              callbacks=callbacks,
              metrics_on_train=True)
You can find more examples here.
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
pytorch-argus-0.1.2.tar.gz (34.1 kB) — view hashes
Built Distribution
Close
Hashes for pytorch_argus-0.1.2-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 7e5e331609f541a6f0431b7a870312d9088e262b853d8a1e28ae3a1e605463d4 |
| MD5 | c1def1957d16b4031bb39e3a6368851c |
| BLAKE2b-256 | 3f7c8018cef0efb1a3b1ca5ec8d01d4ebfb46533d8af5c010dc3365b069d3315 |