Easy high-level library for training neural networks in PyTorch.

Argus is an easy-to-use, flexible library for training neural networks in PyTorch.


Installation

From pip:

pip install pytorch-argus

From source:

git clone
cd argus
python setup.py install


Examples

Simple image classification example:

import torch
from torch import nn
import torch.nn.functional as F
from mnist_utils import get_data_loaders

from argus import Model, load_model
from argus.callbacks import MonitorCheckpoint, EarlyStopping, ReduceLROnPlateau

class Net(nn.Module):
    def __init__(self, n_classes, p_dropout=0.5):
        super().__init__()  # required before registering submodules
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d(p=p_dropout)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, n_classes)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # 28x28 input -> 12x12 after conv1+pool -> 4x4 after conv2+pool; 20 * 4 * 4 = 320
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)  # active only in training mode
        x = self.fc2(x)
        return x

class MnistModel(Model):
    nn_module = Net
    optimizer = torch.optim.SGD
    loss = torch.nn.CrossEntropyLoss

if __name__ == "__main__":
    train_loader, val_loader = get_data_loaders()

    params = {
        'nn_module': {'n_classes': 10, 'p_dropout': 0.1},
        'optimizer': {'lr': 0.01},
        'device': 'cpu'
    }

    model = MnistModel(params)

    callbacks = [
        MonitorCheckpoint(dir_path='mnist', monitor='val_accuracy', max_saves=3),
        EarlyStopping(monitor='val_accuracy', patience=9),
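        ReduceLROnPlateau(monitor='val_accuracy', factor=0.5, patience=3)
    ]

    # Train, computing accuracy on the validation set after each epoch
    model.fit(train_loader,
              val_loader=val_loader,
              max_epochs=50,
              metrics=['accuracy'],
              callbacks=callbacks)
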
    del model
    model = load_model('mnist/model-last.pth')
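The callbacks above all watch `val_accuracy`. In the usual sense of these callbacks, `EarlyStopping` ends training after `patience=9` epochs without improvement, and `ReduceLROnPlateau` multiplies the learning rate by `factor=0.5` after `patience=3` such epochs. The sketch below replays that general rule on a made-up accuracy history. It is a simplified illustration of the technique, not argus's implementation: the function `simulate_callbacks` is hypothetical, and argus may count "epochs without improvement" slightly differently.

```python
# Simplified replay of the early-stopping and reduce-LR-on-plateau rules.
# Illustration of the general technique only, NOT argus's source code;
# argus may count "epochs without improvement" slightly differently.


def simulate_callbacks(history, lr, es_patience=9, lr_patience=3, factor=0.5):
    """Replay per-epoch values of a metric where higher is better
    (such as val_accuracy).

    Returns (stop_epoch, lrs): stop_epoch is the 0-based epoch after which
    early stopping ends training (None if it never triggers), and lrs holds
    the learning rate in effect after each epoch's check.
    """
    best = float("-inf")
    es_wait = 0  # epochs since the last improvement (early stopping)
    lr_wait = 0  # epochs since the last improvement or LR reduction
    lrs = []
    for epoch, value in enumerate(history):
        if value > best:
            best = value
            es_wait = 0
            lr_wait = 0
        else:
            es_wait += 1
            lr_wait += 1
            if lr_wait >= lr_patience:
                lr *= factor  # with factor=0.5 this halves the learning rate
                lr_wait = 0
        lrs.append(lr)
        if es_wait >= es_patience:
            return epoch, lrs
    return None, lrs


# Made-up history: accuracy peaks at epoch 1 and then stalls.
stop, lrs = simulate_callbacks([0.90, 0.95] + [0.94] * 9, lr=0.01)
print(stop)     # 10
print(lrs[-1])  # 0.00125 (0.01 halved three times)
```

With these patience values the learning rate is reduced three times before early stopping triggers, which is why the plateau patience is usually set well below the early-stopping patience.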

Use Argus with make_model from pytorch-cnn-finetune.

from cnn_finetune import make_model
from argus import Model

class CnnFinetune(Model):
    nn_module = make_model

params = {
    'nn_module': {
        'model_name': 'resnet18',
        'num_classes': 10,
        'pretrained': False,
        'input_size': (256, 256)
    },
    'optimizer': ('Adam', {'lr': 0.01}),
    'loss': 'CrossEntropyLoss',
    'device': 'cuda'
}

model = CnnFinetune(params)
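The two examples pass component settings in `params` in different shapes: `MnistModel` sets the classes as class attributes and passes only keyword-argument dicts, while `CnnFinetune` names the optimizer and loss with a `(name, kwargs)` tuple or a plain string. How argus resolves names such as `'Adam'` internally is not described here. The sketch below uses a hypothetical `parse_component_spec` function, not part of argus, only to show how those three shapes can be told apart.

```python
# Hypothetical helper, NOT part of argus: it only illustrates how the value
# shapes used in the params dicts above could be told apart.


def parse_component_spec(spec):
    """Normalize one params entry to a (name, kwargs) pair.

    name is None when the entry only carries keyword arguments, i.e. the
    class attribute on the Model subclass (e.g. optimizer = torch.optim.SGD)
    would be used.
    """
    if isinstance(spec, str):
        # 'CrossEntropyLoss': component chosen by name, no arguments
        return spec, {}
    if isinstance(spec, tuple) and len(spec) == 2:
        # ('Adam', {'lr': 0.01}): name plus keyword arguments
        name, kwargs = spec
        return name, dict(kwargs)
    if isinstance(spec, dict):
        # {'lr': 0.01}: keyword arguments only
        return None, dict(spec)
    raise TypeError(f"unsupported component spec: {spec!r}")


params = {
    'optimizer': ('Adam', {'lr': 0.01}),
    'loss': 'CrossEntropyLoss',
}
for key in ('optimizer', 'loss'):
    print(key, parse_component_spec(params[key]))
# optimizer ('Adam', {'lr': 0.01})
# loss ('CrossEntropyLoss', {})
```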

You can find other examples here.

Kaggle solutions

  1. 1st place solution for Freesound Audio Tagging 2019 (mel-spectrograms, mixed precision training with Apex)
  2. 14th place solution for TGS Salt Identification Challenge (segmentation, MeanTeacher)
  3. 50th place solution for Quick, Draw! Doodle Recognition Challenge (gradient accumulation, training on 50M images)
  4. 66th place solution for Kaggle Airbus Ship Detection Challenge (instance segmentation)
  5. Solution for Humpback Whale Identification (metric learning: arcface, center loss)
  6. Solution for VSB Power Line Fault Detection (1d conv)

Files for pytorch-argus, version 0.0.9:

  - pytorch_argus-0.0.9-py3-none-any.whl (18.6 kB), Wheel, Python py3
  - pytorch-argus-0.0.9.tar.gz (14.3 kB), Source