
Easy Neural Network Experiments with PyTorch


A very lightweight yet fully functional framework on top of PyTorch.



  • Introduces two extra multiprocessing handles for blazing-fast training by extending the easytorch.ETDataset class, as shown below:

    • Multi-threaded data pre-loading.
    • Disk caching for faster access.
from easytorch import ETDataset

class MyDataset(ETDataset):
    def load_index(self, file):
        """(Optional) Load/process something here and add it to the disk
        cache as self.diskcache.add(file, value).
        This method runs in multiple processes by default."""
        self.indices.append([file, 'something_extra'])

    def __getitem__(self, index):
        file, something_extra = self.indices[index]
        # (Optional) Retrieve from the disk cache as self.diskcache.get(file).

        image = ...  # TODO: load the file/image.
        label = ...  # TODO: load the corresponding label.

        # Extra preprocessing, if needed.
        # Apply transforms, if needed.

        return image, label

Installation

  1. pip install --upgrade pip
  2. Install the latest torch and torchvision from PyTorch.
  3. pip install easytorch
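
For example (the exact torch/torchvision command depends on your platform and CUDA version; check pytorch.org for the right one):

pip install --upgrade pip
pip install torch torchvision
pip install easytorch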

Let's start with something simple like MNIST digit classification:

from easytorch import EasyTorch, ETTrainer, ConfusionMatrix, ETMeter
from torchvision import datasets, transforms
import torch.nn.functional as F
import torch
from examples.models import MNISTNet

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])


class MNISTTrainer(ETTrainer):
    def _init_nn_model(self):
        self.nn['model'] = MNISTNet()

    def iteration(self, batch):
        inputs = batch[0].to(self.device['gpu']).float()
        labels = batch[1].to(self.device['gpu']).long()

        out = self.nn['model'](inputs)
        loss = F.nll_loss(out, labels)
        _, pred = torch.max(out, 1)
        
        meter = self.new_meter()
        meter.averages.add(loss.item(), len(inputs))
        meter.metrics['cfm'].add(pred, labels.float())

        return {'loss': loss, 'meter': meter, 'predictions': pred}

    def init_experiment_cache(self):
        self.cache['log_header'] = 'Loss|Accuracy,F1,Precision,Recall'
        self.cache.update(monitor_metric='f1', metric_direction='maximize')

    def new_meter(self):
        return ETMeter(
            cfm=ConfusionMatrix(num_classes=10)
        )


if __name__ == "__main__":
    train_dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)
    val_dataset = datasets.MNIST('../data', train=False, transform=transform)

    dataloader_args = {'train': {'dataset': train_dataset}, 'validation': {'dataset': val_dataset}}
    runner = EasyTorch(phase='train', batch_size=512,
                       epochs=10, gpus=[0], dataloader_args=dataloader_args)
    runner.run(MNISTTrainer)
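
Since every argument is passed directly to the EasyTorch constructor here, the script runs without any command-line flags, e.g. python mnist.py (the file name is only an example).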

General use case:

1. Define your trainer

from easytorch import ETTrainer, Prf1a, ETMeter, AUCROCMetrics


class MyTrainer(ETTrainer):

    def _init_nn_model(self):
        self.nn['model'] = NeuralNetModel(out_size=self.args['num_class'])

    def iteration(self, batch):
        """Handle a single batch.
        The returned dict must include 'loss' and 'meter'."""
        meter = self.new_meter()
        ...
        return {'loss': ..., 'meter': meter, 'predictions': ...}

    def new_meter(self):
        return ETMeter(
            num_averages=1,
            prf1a=Prf1a(),
            auc=AUCROCMetrics()
        )

    def init_cache(self):
        """Plots Loss in one plot, and Accuracy,F1_score in another."""
        self.cache['log_header'] = 'Loss|Accuracy,F1_score'

        # Model selection uses the validation set, if present.
        self.cache.update(monitor_metric='f1', metric_direction='maximize')

  • The new_meter() method returns an ETMeter that accepts any implementation of easytorch.meter.ETMetrics (a combined sketch follows this list). Provided implementations:
    • easytorch.metrics.Prf1a() for binary classification; computes accuracy, F1, precision, recall, and overlap/IoU.
    • easytorch.metrics.ConfusionMatrix(num_classes=...) for multiclass classification; also computes global accuracy, F1, precision, and recall.
    • easytorch.metrics.AUCROCMetrics for the binary ROC-AUC score.
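
As a combined sketch (the averages/metrics calls mirror the MNIST trainer above; anything beyond those is an assumption rather than the authoritative API):

from easytorch import ETMeter, Prf1a, ConfusionMatrix, AUCROCMetrics

meter = ETMeter(
    num_averages=1,                       # one running average, e.g. for the loss
    prf1a=Prf1a(),                        # binary accuracy/F1/precision/recall/IoU
    cfm=ConfusionMatrix(num_classes=10),  # multiclass metrics
    auc=AUCROCMetrics()                   # binary ROC-AUC
)

meter.averages.add(0.42, 32)              # add a batch loss, batch size 32
# meter.metrics['cfm'].add(pred, labels)  # accumulate per-batch predictions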

2. Define the specification for your datasets:

  • EasyTorch automatically splits the training data in data_source as specified by split_ratio (a sketch of each variant follows this list):
    • train, validation, and test as split_ratio=[0.7, 0.15, 0.15], or just train and validation as [0.8, 0.2], OR
    • Custom splits in .txt files:
      • data_source = "/some/path/*.txt", where it tries to load 'train.txt', 'validation.txt', and 'test.txt' when phase is train, and only 'test.txt' when phase is test.
    • data_source = "some/path/split.json", where each split key has a list of files:
      • {'train': [], 'validation': [], 'test': []}
    • A plain glob such as data_source = "some/path/**/*.txt"; split_ratio must also be given when phase is train.
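
A sketch of these variants (the keyword names follow the bullets above; treat the exact signatures as assumptions and verify against the config module):

from easytorch import EasyTorch

# Ratio split into train/validation/test:
runner = EasyTorch(phase='train', data_source='some/path', split_ratio=[0.7, 0.15, 0.15])

# Pre-made split files; loads train.txt, validation.txt, and test.txt when phase is train:
runner = EasyTorch(phase='train', data_source='/some/path/*.txt')

# A JSON file mapping each split key to a list of files:
runner = EasyTorch(phase='train', data_source='some/path/split.json')

# A plain glob; split_ratio is required when phase is train:
runner = EasyTorch(phase='train', data_source='some/path/**/*.txt', split_ratio=[0.8, 0.1, 0.1])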

3. Entry point (say main.py)

from easytorch import EasyTorch

runner = EasyTorch(phase="train", batch_size=4, epochs=21,
                   num_channel=1, num_class=2,
                   split_ratio=[0.6, 0.2, 0.2])

OR pass the same arguments from the command line:

python main.py -ph train -b 4 -e 50 -nc 3 -spl 0.8 0.1 0.1

Note: See easytorch/config/__init__.py for the full list of arguments.

OR load the arguments from a YAML file:

runner = EasyTorch(yaml_config="path/to/yaml/file/with/args/as/in/easytorch/config/default_config.yaml")
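
A minimal sketch of such a YAML file, assuming the keys mirror the constructor arguments shown above (the exact key names and default file are assumptions; check the config shipped with the package):

phase: train
batch_size: 4
epochs: 21
num_channel: 1
num_class: 2
split_ratio: [0.6, 0.2, 0.2]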

if __name__ == "__main__":
    runner.run(MyTrainer, MyDataset)  # Trains an individual model for each dataset.

Run from the command line:

python main.py -ph train -b 4 -e 21 -spl 0.6 0.2 0.2

Note: arguments given directly to the EasyTorch constructor take precedence over command-line arguments. See easytorch/config/__init__.py for the list of default arguments.


All the best! Cheers! 🎉

Cite the following papers if you use this library:

@article{deepdyn_10.3389/fcomp.2020.00035,
	title        = {Dynamic Deep Networks for Retinal Vessel Segmentation},
	author       = {Khanal, Aashis and Estrada, Rolando},
	year         = 2020,
	journal      = {Frontiers in Computer Science},
	volume       = 2,
	pages        = 35,
	doi          = {10.3389/fcomp.2020.00035},
	issn         = {2624-9898}
}

@misc{2202.02382,
	author       = {Aashis Khanal and Saeid Motevali and Rolando Estrada},
	title        = {Fully Automated Tree Topology Estimation and Artery-Vein Classification},
	year         = 2022,
	eprint       = {arXiv:2202.02382}
}

Feature Highlights:

  • A DataHandle (ETDataHandle) that is always available and decoupled from the other modules, enabling easy customization:
    • Use custom and complex data-handling mechanisms.
  • Simple, lightweight logger/plotter:
    • Plot: set log_header = 'Loss,F1,Accuracy' to draw all three in the same plot, or log_header = 'Loss|F1,Accuracy' to plot Loss in one plot and F1,Accuracy in another (see the snippet after this list).
    • Logs: all arguments and generated data are saved to a logs.json file after the experiment finishes.
  • Gradient accumulation, automatic logging/plotting, and model checkpointing.
  • Multiple metrics implemented in easytorch.metrics: precision, recall, accuracy, overlap, F1, ROC-AUC, confusion matrix, and more.
  • Support for advanced training with multiple networks and complex training steps (see the examples in the repository).
  • Custom metrics can be added by implementing easytorch.meter.ETMetrics.
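
For instance, the two log_header styles described above:

self.cache['log_header'] = 'Loss,F1,Accuracy'   # all three curves in one plot
self.cache['log_header'] = 'Loss|F1,Accuracy'   # Loss alone; F1 and Accuracy together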
