Argus is a lightweight library for training neural networks in PyTorch.
Documentation
https://pytorch-argus.readthedocs.io
Installation
Requirements:
- torch>=2.0.0
From pip:
pip install pytorch-argus
From source:
pip install -U git+https://github.com/lRomul/argus.git@dev
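Since the only listed requirement is torch>=2.0.0, it can be useful to check the installed version before installing Argus. The following is a minimal sketch using only the standard library. The helper names `parse_release`, `meets_requirement`, and `check_torch` are hypothetical and not part of Argus, and the comparison looks only at the numeric release segment, so a pre-release such as `2.0.0rc1` is treated like `2.0.0`.

```python
import re
from importlib import metadata


def parse_release(version: str) -> tuple:
    """Return the leading numeric release of a version string as a tuple of ints."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group().split("."))


def meets_requirement(installed: str, minimum: str = "2.0.0") -> bool:
    """Compare numeric release segments; local tags such as '+cu118' are ignored."""
    a, b = parse_release(installed), parse_release(minimum)
    width = max(len(a), len(b))
    # Pad with zeros so that "2.0" and "2.0.0" compare as equal.
    a = a + (0,) * (width - len(a))
    b = b + (0,) * (width - len(b))
    return a >= b


def check_torch() -> str:
    """Describe whether the installed torch satisfies the minimum version."""
    try:
        installed = metadata.version("torch")
    except metadata.PackageNotFoundError:
        return "torch is not installed"
    status = "ok" if meets_requirement(installed) else "older than 2.0.0"
    return f"torch {installed}: {status}"


if __name__ == "__main__":
    print(check_torch())
```

For stricter version handling (pre-releases, post-releases), the `packaging` library's version parser is likely a better fit than this regular expression.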
Example
Simple image classification example with create_model from pytorch-image-models:
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize

import timm

import argus
from argus.callbacks import MonitorCheckpoint, EarlyStopping, ReduceLROnPlateau


def get_data_loaders(batch_size):
    data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
    train_mnist_dataset = MNIST(download=True, root="mnist_data",
                                transform=data_transform, train=True)
    val_mnist_dataset = MNIST(download=False, root="mnist_data",
                              transform=data_transform, train=False)
    train_loader = DataLoader(train_mnist_dataset,
                              batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_mnist_dataset,
                            batch_size=batch_size * 2, shuffle=False)
    return train_loader, val_loader


class TimmModel(argus.Model):
    nn_module = timm.create_model


if __name__ == "__main__":
    train_loader, val_loader = get_data_loaders(batch_size=256)

    params = {
        'nn_module': {
            'model_name': 'tf_efficientnet_b0_ns',
            'pretrained': False,
            'num_classes': 10,
            'in_chans': 1,
            'drop_rate': 0.2,
            'drop_path_rate': 0.2
        },
        'optimizer': ('Adam', {'lr': 0.01}),
        'loss': 'CrossEntropyLoss',
        'device': 'cuda'
    }

    model = TimmModel(params)

    callbacks = [
        MonitorCheckpoint(dir_path='mnist', monitor='val_accuracy', max_saves=3),
        EarlyStopping(monitor='val_accuracy', patience=9),
        ReduceLROnPlateau(monitor='val_accuracy', factor=0.5, patience=3)
    ]

    model.fit(train_loader,
              val_loader=val_loader,
              num_epochs=50,
              metrics=['accuracy'],
              callbacks=callbacks,
              metrics_on_train=True)
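After training, the checkpoints that MonitorCheckpoint writes to the 'mnist' directory can be loaded back for inference. The following is a minimal sketch, assuming that checkpoints are saved with a `.pth` extension and that `argus.load_model` and `model.predict` behave as described in the Argus documentation. The `newest_checkpoint` helper is hypothetical and simply picks the most recently modified file. The sketch also assumes that the model class has to be defined before loading so that Argus can find it by name.

```python
from pathlib import Path
from typing import Optional


def newest_checkpoint(dir_path: str, pattern: str = "*.pth") -> Optional[Path]:
    """Return the most recently modified file matching pattern, or None."""
    candidates = [p for p in Path(dir_path).glob(pattern) if p.is_file()]
    if not candidates:
        return None
    return max(candidates, key=lambda p: p.stat().st_mtime)


def main() -> None:
    # Imported here so the helper above works without torch, timm, or argus.
    import torch
    import timm
    import argus

    # Assumption: the class used for training must exist under the same name
    # before the checkpoint is loaded.
    class TimmModel(argus.Model):
        nn_module = timm.create_model

    checkpoint = newest_checkpoint("mnist")
    if checkpoint is None:
        raise FileNotFoundError("No checkpoints in 'mnist'; run training first.")

    # device="cpu" is passed so the sketch does not require a GPU.
    model = argus.load_model(str(checkpoint), device="cpu")

    # A dummy batch of four single-channel 28x28 images, matching MNIST input.
    batch = torch.zeros(4, 1, 28, 28)
    prediction = model.predict(batch)
    print(prediction.shape)


if __name__ == "__main__":
    main()
```

Selecting by modification time is only one option; parsing the monitored value out of the file name would also work, but depends on the file name format configured in the callback.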
More examples can be found here. Additional guides on how to customize and use Argus components can be found in the Guides section.
Why this name, Argus?
The library name is a reference to a planet from World of Warcraft. Argus is the original homeworld of the eredar (a race of supremely talented magic-wielders), now located within the Twisting Nether. It was once described as a utopian world whose inhabitants were both vastly intelligent and highly gifted in magic. It has since been twisted by demonic, chaotic energies and has become the stronghold and homeworld of the Burning Legion.
Download files
File details
Details for the file pytorch_argus-1.1.0.tar.gz.
File metadata
- Download URL: pytorch_argus-1.1.0.tar.gz
- Upload date:
- Size: 32.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.12.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 09bea32e826f0e922d362be99d924329a054d30d7fb83207bc6e41d37ac71643 |
| MD5 | 33683fa9e1b1c49e157cfe49cf701bbd |
| BLAKE2b-256 | ca24d87967dcd1f8b0b5d23e16de908f308d94ccb82fd79c1fe36762088472f7 |
File details
Details for the file pytorch_argus-1.1.0-py3-none-any.whl.
File metadata
- Download URL: pytorch_argus-1.1.0-py3-none-any.whl
- Upload date:
- Size: 34.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.12.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 175ef698ecff2794746a90917311280d1a7def478ef83049495678e96d656856 |
| MD5 | 0f79ca88621ba0b0c2ab5477e5f1819b |
| BLAKE2b-256 | 18b3be52f9a87f76d2e1454cc40a523705b2831fc6610ccc227646910e012745 |