Easy Neural Network Experiments with PyTorch

A very lightweight framework on top of PyTorch with full functionality.

  • Introduces two extra multi-processing handles for blazing fast training by extending the easytorch.ETDataset class:

    • Multi-threaded data pre-loading.
    • Disk caching for faster access.
from easytorch import ETDataset


class MyDataset(ETDataset):
    def load_index(self, file):
        """Runs in multiple processes by default.

        (Optional) Load/process something and add it to the disk cache as:
            self.diskcache.add(file, value)
        """
        self.indices.append([file, 'something_extra'])

    def __getitem__(self, index):
        # Each index entry is the [file, extra] pair stored by load_index().
        file, extra = self.indices[index]
        # (Optional) Retrieve cached data with self.diskcache.get(file).

        image = ...  # TODO: load the file/image.
        label = ...  # TODO: load the corresponding label.

        # Extra preprocessing, if needed.
        # Apply transforms, if needed.

        return image, label
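
To make the two ideas above concrete, here is a minimal standard-library sketch of parallel pre-loading combined with a disk cache. It is not easytorch's implementation: `TinyDiskCache`, `expensive_preprocess`, and `build_index` are hypothetical names, and a thread pool is used here for simplicity, whereas the docstring above says `load_index` runs in multiple processes.

```python
import pickle
import tempfile
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path


class TinyDiskCache:
    """Hypothetical stand-in for a disk cache: one pickle file per key."""

    def __init__(self, directory):
        self.directory = Path(directory)
        self.directory.mkdir(parents=True, exist_ok=True)

    def _path(self, key):
        # Keys in this sketch are plain file names, so they are safe on disk.
        return self.directory / f"{key}.pkl"

    def add(self, key, value):
        with open(self._path(key), "wb") as fh:
            pickle.dump(value, fh)

    def get(self, key):
        with open(self._path(key), "rb") as fh:
            return pickle.load(fh)


def expensive_preprocess(file):
    """Placeholder for slow work such as decoding an image."""
    return {"name": file, "length": len(file)}


def build_index(files, cache, workers=4):
    """Preprocess every file in parallel, cache the result, return the index."""

    def load_one(file):
        cache.add(file, expensive_preprocess(file))
        return [file, "something_extra"]

    with ThreadPoolExecutor(max_workers=workers) as pool:
        # map() preserves input order, so the index is deterministic.
        return list(pool.map(load_one, files))


cache = TinyDiskCache(tempfile.mkdtemp())
indices = build_index(["a.png", "bb.png", "ccc.png"], cache)
cached = cache.get("bb.png")
```

The expensive work happens once, during indexing; later item lookups only read the cached result back from disk.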

Installation

  1. pip install --upgrade pip
  2. Install the latest PyTorch and torchvision from the PyTorch website.
  3. pip install easytorch

Let's start with something simple, like MNIST digit classification:

from easytorch import EasyTorch, ETTrainer, ConfusionMatrix, ETMeter
from torchvision import datasets, transforms
import torch.nn.functional as F
import torch
from examples.models import MNISTNet

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])


class MNISTTrainer(ETTrainer):
    def _init_nn_model(self):
        self.nn['model'] = MNISTNet()

    def iteration(self, batch):
        inputs, labels = batch[0].to(self.device['gpu']).float(), batch[1].to(self.device['gpu']).long()

        out = self.nn['model'](inputs)
        loss = F.nll_loss(out, labels)
        _, pred = torch.max(out, 1)
        
        meter = self.new_meter()
        meter.averages.add(loss.item(), len(inputs))
        meter.metrics['cfm'].add(pred, labels.float())

        return {'loss': loss, 'meter': meter, 'predictions': pred}

    def init_experiment_cache(self):
        self.cache['log_header'] = 'Loss|Accuracy,F1,Precision,Recall'
        self.cache.update(monitor_metric='f1', metric_direction='maximize')

    def new_meter(self):
        return ETMeter(
            cfm=ConfusionMatrix(num_classes=10)
        )


if __name__ == "__main__":
    train_dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)
    val_dataset = datasets.MNIST('../data', train=False, transform=transform)

    dataloader_args = {'train': {'dataset': train_dataset}, 'validation': {'dataset': val_dataset}}
    runner = EasyTorch(phase='train', batch_size=512,
                       epochs=10, gpus=[0], dataloader_args=dataloader_args)
    runner.run(MNISTTrainer)
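
The `monitor_metric='f1', metric_direction='maximize'` pair in the trainer above suggests that the best model is chosen by tracking a validation score across epochs. The following is a rough sketch of that selection logic in plain Python, not the library's code; `is_improvement` and `select_best_epoch` are hypothetical helpers, and the scores are made-up illustration values.

```python
def is_improvement(new_score, best_score, direction="maximize"):
    """Return True if new_score beats best_score for the given direction."""
    if best_score is None:
        return True  # The first epoch is always the best so far.
    if direction == "maximize":
        return new_score > best_score
    if direction == "minimize":
        return new_score < best_score
    raise ValueError(f"Unknown direction: {direction!r}")


def select_best_epoch(scores, direction="maximize"):
    """Scan per-epoch validation scores and return (best_epoch, best_score)."""
    best_epoch, best_score = None, None
    for epoch, score in enumerate(scores, start=1):
        if is_improvement(score, best_score, direction):
            # A real trainer would save a checkpoint at this point.
            best_epoch, best_score = epoch, score
    return best_epoch, best_score


# Illustrative validation F1 per epoch (not real results).
val_f1 = [0.61, 0.74, 0.72, 0.80, 0.79]
best_f1 = select_best_epoch(val_f1, direction="maximize")

# The same logic works for a loss-like metric that should go down.
val_loss = [0.9, 0.5, 0.7]
best_loss = select_best_epoch(val_loss, direction="minimize")
```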

General use case:

1. Define your trainer

from easytorch import ETTrainer, Prf1a, ETMeter, AUCROCMetrics


class MyTrainer(ETTrainer):

    def _init_nn_model(self):
        self.nn['model'] = NeuralNetModel(out_size=self.args['num_class'])

    def iteration(self, batch):
        """Handle a single batch. The returned dict must have 'loss' and 'meter'."""
        meter = self.new_meter()
        ...
        return {'loss': ..., 'meter': ..., 'predictions': ...}

    def new_meter(self):
        return ETMeter(
            num_averages=1,
            prf1a=Prf1a(),
            auc=AUCROCMetrics()
        )

    def init_cache(self):
        # Plots Loss in one plot, and Accuracy,F1_score in another.
        self.cache['log_header'] = 'Loss|Accuracy,F1_score'

        # Model selection uses the validation set, if present.
        self.cache.update(monitor_metric='f1', metric_direction='maximize')

  • new_meter() returns an ETMeter, which accepts any implementation of easytorch.meter.ETMetrics. The provided ones are:
    • easytorch.metrics.Prf1a() for binary classification; computes accuracy, F1, precision, recall, and overlap (IoU).
    • easytorch.metrics.ConfusionMatrix(num_classes=...) for multiclass classification; also computes global accuracy, F1, precision, and recall.
    • easytorch.metrics.AUCROCMetrics for the binary ROC-AUC score.
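
For reference, the binary scores listed above follow the standard definitions based on true/false positive and negative counts. The sketch below computes them in plain Python; it only illustrates the formulas and is not the Prf1a implementation. `binary_scores` is a hypothetical helper, and returning 0.0 on an empty denominator is an assumption of this sketch.

```python
def binary_scores(predictions, labels):
    """Compute accuracy, precision, recall, F1, and overlap (IoU) for 0/1 labels."""
    pairs = list(zip(predictions, labels))
    tp = sum(1 for p, y in pairs if p == 1 and y == 1)
    fp = sum(1 for p, y in pairs if p == 1 and y == 0)
    fn = sum(1 for p, y in pairs if p == 0 and y == 1)
    tn = sum(1 for p, y in pairs if p == 0 and y == 0)

    def ratio(numerator, denominator):
        # Assumption: report 0.0 instead of failing when a count is empty.
        return numerator / denominator if denominator else 0.0

    precision = ratio(tp, tp + fp)
    recall = ratio(tp, tp + fn)
    return {
        "accuracy": ratio(tp + tn, tp + fp + fn + tn),
        "precision": precision,
        "recall": recall,
        "f1": ratio(2 * precision * recall, precision + recall),
        "overlap": ratio(tp, tp + fp + fn),  # intersection over union
    }


scores = binary_scores(predictions=[1, 1, 0, 0, 1, 0], labels=[1, 0, 0, 1, 1, 0])
```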

2. Define specification for your datasets:

  • EasyTorch automatically splits the data in data_source as specified by split_ratio:
    • train, validation, and test as split_ratio=[0.7, 0.15, 0.15], or just train and validation as split_ratio=[0.8, 0.2], OR
    • Custom splits in txt files:
      • data_source = "/some/path/*.txt", where it tries to load 'train.txt', 'validation.txt', and 'test.txt' if phase is train, or only 'test.txt' if phase is test.
    • Custom splits in a JSON file: data_source = "some/path/split.json", where each split key maps to a list of files:
      • {'train': [], 'validation': [], 'test': []}
    • Just a glob, as data_source = "some/path/**/*.txt"; split_ratio must also be provided if phase is train.
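
As an illustration of ratio-based splitting, the sketch below divides a list of files into train/validation/test parts and produces the same dictionary shape as the split.json format above. It is a sketch under assumptions, not EasyTorch's own splitting code: `split_by_ratio` is a hypothetical helper, and the seeded shuffle and the rounding rule are choices made for this example.

```python
import random


def split_by_ratio(files, split_ratio, seed=0):
    """Split files into train/validation(/test) according to split_ratio."""
    if not 2 <= len(split_ratio) <= 3:
        raise ValueError("split_ratio needs two or three entries")
    if abs(sum(split_ratio) - 1.0) > 1e-6:
        raise ValueError("split_ratio must sum to 1")

    # Assumption: shuffle with a fixed seed so the split is reproducible.
    files = sorted(files)
    random.Random(seed).shuffle(files)

    names = ["train", "validation", "test"][:len(split_ratio)]
    splits, start = {}, 0
    for i, (name, ratio) in enumerate(zip(names, split_ratio)):
        # The last split takes the remainder so that no file is dropped.
        is_last = i == len(names) - 1
        end = len(files) if is_last else start + round(ratio * len(files))
        splits[name] = files[start:end]
        start = end
    return splits


files = [f"img_{i:02d}.png" for i in range(20)]
three_way = split_by_ratio(files, [0.7, 0.15, 0.15])
two_way = split_by_ratio(files, [0.8, 0.2])
```

The resulting dictionary could be written out with `json.dump` to get a reusable split file.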

3. Entry point (say main.py)

from easytorch import EasyTorch

runner = EasyTorch(phase="train", batch_size=4, epochs=21,
                   num_channel=1, num_class=2,
                   split_ratio=[0.6, 0.2, 0.2])

OR

python main.py -ph train -b 4 -e 50 -nc 3 -spl 0.8 0.1 0.1

Note: See easytorch.config.__init__.py for the full list of args.

OR

runner = EasyTorch(yaml_config="path/toyaml/file/with/args/as/in/easytorch.confi/default_confi.yaml")

if __name__ == "__main__":
    runner.run(MyTrainer, MyDataset)  # Trains an individual model for each dataset.

Run from the command line:

python main.py -ph train -b 4 -e 21 -spl 0.6 0.2 0.2

Note: arguments passed directly to the EasyTorch constructor take precedence over command-line arguments. See easytorch.config for the list of default arguments.
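
The precedence rule can be pictured as parsing the command line first and then overwriting the result with whatever was passed to the constructor. The sketch below shows that idea with argparse. Only the short flags (`-ph`, `-b`, `-e`, `-spl`) come from the commands above; the long option names, the default values, and the `resolve_args` helper are assumptions made for illustration.

```python
import argparse


def build_parser():
    """Parser with the short flags used above; long names and defaults are assumed."""
    parser = argparse.ArgumentParser()
    parser.add_argument("-ph", "--phase", default="train")
    parser.add_argument("-b", "--batch_size", type=int, default=4)
    parser.add_argument("-e", "--epochs", type=int, default=10)
    parser.add_argument("-spl", "--split_ratio", type=float, nargs="*", default=None)
    return parser


def resolve_args(argv, **constructor_args):
    """Merge arguments so that constructor values override the command line."""
    args = vars(build_parser().parse_args(argv))
    args.update(constructor_args)  # Directly given values win.
    return args


# Equivalent of: python main.py -ph train -b 4 -e 21 -spl 0.6 0.2 0.2
argv = ["-ph", "train", "-b", "4", "-e", "21", "-spl", "0.6", "0.2", "0.2"]
args = resolve_args(argv, batch_size=512)
```

Here `batch_size` ends up as 512 even though the command line asked for 4, while `epochs` and `split_ratio` keep their command-line values.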


All the best! Cheers! 🎉

Cite the following papers if you use this library:

@article{deepdyn_10.3389/fcomp.2020.00035,
	title        = {Dynamic Deep Networks for Retinal Vessel Segmentation},
	author       = {Khanal, Aashis and Estrada, Rolando},
	year         = 2020,
	journal      = {Frontiers in Computer Science},
	volume       = 2,
	pages        = 35,
	doi          = {10.3389/fcomp.2020.00035},
	issn         = {2624-9898}
}

@misc{2202.02382,
	title        = {Fully Automated Tree Topology Estimation and Artery-Vein Classification},
	author       = {Khanal, Aashis and Motevali, Saeid and Estrada, Rolando},
	year         = 2022,
	eprint       = {arXiv:2202.02382}
}

Feature Highlights:

  • A data handle (ETDataHandle) that is always available and decoupled from other modules, enabling easy customization.
    • Use custom and complex data handling mechanisms.
  • Simple, lightweight logger/plotter.
    • Plot: set log_header = 'Loss,F1,Accuracy' to draw all three in the same plot, or log_header = 'Loss|F1,Accuracy' to draw Loss in one plot and F1,Accuracy in another.
    • Logs: all arguments and generated data are saved to a logs.json file after the experiment finishes.
  • Gradient accumulation, automatic logging/plotting, and model checkpointing.
  • Multiple metric implementations in easytorch.metrics: Precision, Recall, Accuracy, Overlap, F1, ROC-AUC, Confusion matrix, and more.
  • Advanced training with multiple networks and complex training steps.
  • Custom metric implementations.
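
The log_header convention described in the list above (`|` separates plots, `,` separates series within a plot) can be read with a couple of string splits. The following sketch shows one way to interpret such a header; `parse_log_header` is a hypothetical helper and not part of easytorch.

```python
def parse_log_header(log_header):
    """Turn a header such as 'Loss|F1,Accuracy' into one list of names per plot."""
    return [
        [name.strip() for name in group.split(",") if name.strip()]
        for group in log_header.split("|")
        if group.strip()
    ]


# One plot containing all three series.
same_plot = parse_log_header("Loss,F1,Accuracy")

# Two plots: Loss alone, then F1 and Accuracy together.
two_plots = parse_log_header("Loss|F1,Accuracy")
```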


Version: 3.5.3
