
Easy Neural Network Experiments with PyTorch


A very lightweight framework on top of PyTorch with full functionality.

Just one way of doing things means no learning curve.


Installation

  1. pip install --upgrade pip
  2. Install the latest PyTorch and torchvision from the official PyTorch website (pytorch.org)
  3. pip install easytorch

Let's start with something simple like MNIST digit classification:

from easytorch import EasyTorch, ETRunner, ConfusionMatrix, ETMeter
from torchvision import datasets, transforms
import torch.nn.functional as F
import torch
from examples.models import MNISTNet  # MNISTNet lives in the examples/ folder of the easytorch repository; it is not installed by pip.

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])


class MNISTTrainer(ETRunner):
    def _init_nn_model(self):
        self.nn['model'] = MNISTNet()

    def iteration(self, batch):
        inputs, labels = batch[0].to(self.device['gpu']).float(), batch[1].to(self.device['gpu']).long()

        out = self.nn['model'](inputs)
        loss = F.nll_loss(out, labels)
        _, pred = torch.max(out, 1)

        meter = self.new_meter()
        meter.averages.add(loss.item(), len(inputs))
        meter.metrics['cfm'].add(pred, labels.float())

        return {'loss': loss, 'meter': meter, 'predictions': pred}

    def init_experiment_cache(self):
        self.cache['log_header'] = 'Loss|Accuracy,F1,Precision,Recall'
        self.cache.update(monitor_metric='f1', metric_direction='maximize')

    def new_meter(self):
        return ETMeter(
            cfm=ConfusionMatrix(num_classes=10)
        )


if __name__ == "__main__":
    train_dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)
    val_dataset = datasets.MNIST('../data', train=False, transform=transform)

    dataloader_args = {'train': {'dataset': train_dataset}, 'validation': {'dataset': val_dataset}}
    runner = EasyTorch(phase='train', batch_size=512,
                       epochs=10, gpus=[0], dataloader_args=dataloader_args)
    runner.run(MNISTTrainer)
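
The call meter.averages.add(loss.item(), len(inputs)) passes the batch size along with the batch's mean loss. A plausible reason is to weight each batch by its size, so that a smaller final batch does not skew the epoch average (with batch_size=512 and the last incomplete batch kept, MNIST's 60,000 training images end in a final batch of 96). The sketch below illustrates that idea with a hypothetical RunningAverage class; it is not easytorch's actual averaging code.

```python
class RunningAverage:
    """Hypothetical helper: a sample-weighted running mean of per-batch values."""

    def __init__(self):
        self.total = 0.0  # running sum of value * batch_size
        self.count = 0    # total number of samples seen so far

    def add(self, value, n=1):
        # `value` is a per-batch mean (e.g. loss.item()), so weight it by the batch size n.
        self.total += value * n
        self.count += n

    @property
    def average(self):
        # Guard against division by zero before any batch has been added.
        return self.total / self.count if self.count else 0.0


avg = RunningAverage()
avg.add(0.5, 512)  # a full batch
avg.add(0.2, 96)   # the smaller final batch
print(round(avg.average, 4))  # 0.4526
```

An unweighted mean of the two batch losses would give 0.35 here, overstating the influence of the 96-sample batch.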

Run as:

python script.py -ph train -b 512 -e 10 -gpus 0

... along with 20+ other useful options; see the project documentation for the full list.
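
The flags in the command above correspond to the keyword arguments used in the script (phase, batch_size, epochs, gpus) and to the split ratio option described later. As a rough illustration of how such flags map onto those names, here is a hypothetical argparse parser covering only the flags that appear in this README; easytorch's real parser defines many more options, and its exact definitions and defaults may differ.

```python
import argparse


def build_parser():
    # Hypothetical: mirrors only the flags shown in this README, not easytorch's real parser.
    parser = argparse.ArgumentParser(description="Sketch of easytorch-style CLI flags")
    parser.add_argument("-ph", dest="phase", choices=["train", "test"])
    parser.add_argument("-b", dest="batch_size", type=int)
    parser.add_argument("-e", dest="epochs", type=int)
    parser.add_argument("-gpus", dest="gpus", type=int, nargs="*")
    parser.add_argument("-spl", "--split-ratio", dest="split_ratio", type=float, nargs="*")
    return parser


args = build_parser().parse_args(["-ph", "train", "-b", "512", "-e", "10", "-gpus", "0"])
print(vars(args))
```

Any flag left out on the command line stays None in this sketch, which is one simple way to let explicitly passed values take precedence over script defaults.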


General use case:

1. Define your trainer

from easytorch import ETRunner, Prf1a, ETMeter, AUCROCMetrics


class MyTrainer(ETRunner):

    def _init_nn_model(self):
        # NeuralNetModel is a placeholder for your own torch.nn.Module.
        self.nn['model'] = NeuralNetModel(out_size=self.conf['num_class'])

    def iteration(self, batch):
        """Handle a single batch. The returned dict must contain 'loss' and 'meter'."""
        meter = self.new_meter()
        ...
        return {'loss': ..., 'meter': ..., 'predictions': ...}

    def new_meter(self):
        return ETMeter(
            num_averages=1,
            prf1a=Prf1a(),
            auc=AUCROCMetrics()
        )

    def init_experiment_cache(self):
        """Will plot Loss in one plot, and Accuracy,F1_score in another."""
        self.cache['log_header'] = 'Loss|Accuracy,F1_score'

        # Model selection uses the validation set, if present.
        self.cache.update(monitor_metric='f1', metric_direction='maximize')
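
The monitor_metric and metric_direction cache entries drive model selection on the validation set. The following is a minimal sketch of that selection rule, assuming one dict of validation metrics per epoch; the function select_best_epoch and the metric values are hypothetical and do not come from easytorch.

```python
def select_best_epoch(history, monitor_metric="f1", metric_direction="maximize"):
    """Hypothetical sketch: return (epoch, score) of the best validation epoch.

    `history` is a list of per-epoch validation metric dicts, e.g. {"f1": 0.91}.
    """
    if metric_direction not in ("maximize", "minimize"):
        raise ValueError("metric_direction must be 'maximize' or 'minimize'")
    best_epoch, best_score = None, None
    for epoch, metrics in enumerate(history, start=1):
        score = metrics[monitor_metric]
        improved = (
            best_score is None
            or (metric_direction == "maximize" and score > best_score)
            or (metric_direction == "minimize" and score < best_score)
        )
        if improved:
            # In a real trainer, this is where a checkpoint of the model would be saved.
            best_epoch, best_score = epoch, score
    return best_epoch, best_score


# Made-up validation metrics for three epochs.
history = [{"f1": 0.81, "loss": 0.40}, {"f1": 0.88, "loss": 0.31}, {"f1": 0.86, "loss": 0.29}]
print(select_best_epoch(history))                      # (2, 0.88)
print(select_best_epoch(history, "loss", "minimize"))  # (3, 0.29)
```

Note that the two criteria pick different epochs on the same history, which is why the monitored metric and its direction are configured explicitly.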

  • The new_meter() method returns an ETMeter, which accepts any implementation of easytorch.meter.ETMetrics. Provided ones:
    • easytorch.metrics.Prf1a() for binary classification; computes accuracy, F1, precision, recall, and overlap/IoU.
    • easytorch.metrics.ConfusionMatrix(num_classes=...) for multiclass classification; also computes global accuracy, F1, precision, and recall.
    • easytorch.metrics.AUCROCMetrics for the binary ROC-AUC score.
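
For reference, the binary scores listed above have standard definitions in terms of true/false positives and negatives. The sketch below computes them from raw counts, together with ROC-AUC via its rank (Mann-Whitney) interpretation. It is a stand-alone illustration of the formulas, not easytorch's Prf1a or AUCROCMetrics code, and returning 0.0 on a zero denominator is an assumed convention.

```python
def prf1a(tp, fp, fn, tn):
    """Binary scores from confusion counts (illustration of the formulas only)."""
    def safe_div(a, b):
        # Assumed convention: report 0.0 instead of failing on an empty denominator.
        return a / b if b else 0.0

    precision = safe_div(tp, tp + fp)
    recall = safe_div(tp, tp + fn)
    return {
        "accuracy": safe_div(tp + tn, tp + fp + fn + tn),
        "precision": precision,
        "recall": recall,
        "f1": safe_div(2 * precision * recall, precision + recall),
        "overlap": safe_div(tp, tp + fp + fn),  # IoU / Jaccard index
    }


def roc_auc(scores, labels):
    """ROC-AUC as the probability that a random positive outscores a random negative."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("ROC-AUC needs at least one positive and one negative label")
    # Each (positive, negative) pair scores 1 if ranked correctly, 0.5 on a tie.
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


print(prf1a(tp=8, fp=2, fn=2, tn=8))
print(roc_auc([0.9, 0.8, 0.3, 0.1], [1, 0, 1, 0]))  # 0.75
```

The pairwise AUC loop is quadratic in the number of samples; it is fine for illustration, while production code usually sorts the scores instead.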

2. Define the specification for your datasets:

  • EasyTorch automatically splits the data in data_source according to split_ratio (-spl or --split-ratio, e.g. 0.7, 0.15, 0.15 for the train, validation, and test portions), OR uses custom splits given in:
    1. Text files:
      • data_source = "/some/path/*.txt", where it looks for 'train.txt', 'validation.txt', and 'test.txt' if phase is train, and only 'test.txt' if phase is test.
    2. JSON files:
      • data_source = "some/path/split.json", where each split key holds a list of files: {'train': [], 'validation': [], 'test': []}.
    3. A plain glob such as data_source = "some/path/**/*.txt"; split_ratio must also be provided if phase is train.

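
As an illustration of what a ratio-based split involves, and of the split.json shape described above, here is a minimal sketch. The function split_files, the seeded shuffle, and the rounding rule are assumptions made for this example; easytorch's own splitter may order and round differently.

```python
import json
import random


def split_files(files, split_ratio=(0.7, 0.15, 0.15), seed=0):
    """Hypothetical sketch: shuffle `files` and cut them into train/validation/test lists."""
    if len(split_ratio) != 3 or abs(sum(split_ratio) - 1.0) > 1e-6:
        raise ValueError("split_ratio needs three fractions that sum to 1")
    files = sorted(files)               # sort first so the result ignores input order
    random.Random(seed).shuffle(files)  # seeded shuffle for reproducible splits
    n_train = round(len(files) * split_ratio[0])
    n_val = round(len(files) * split_ratio[1])
    return {
        "train": files[:n_train],
        "validation": files[n_train:n_train + n_val],
        "test": files[n_train + n_val:],  # whatever remains goes to test
    }


splits = split_files([f"img_{i:03d}.png" for i in range(20)])
print({name: len(items) for name, items in splits.items()})
# Same {"train": [...], "validation": [...], "test": [...]} shape as a split.json file.
split_json = json.dumps(splits, indent=2)
```

Writing split_json to disk once and pointing data_source at it keeps the split fixed across runs, even if the ratio-based logic or the file listing changes later.
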
from easytorch import ETDataset


class MyDataset(ETDataset):
    def load_index(self, file):
        """(Optional) Load/process something and add it to the disk cache as:
                self.diskcache.add(file, value)
        This method runs in multiple processes by default."""
        self.indices.append([file, 'something_extra'])

    def __getitem__(self, index):
        # Each index entry is the [file, extra] list appended in load_index.
        file, extra = self.indices[index]
        # (Optional) Retrieve from the disk cache as self.diskcache.get(file)

        image = ...  # TODO: load the file/image.
        label = ...  # TODO: load the corresponding label.

        # Extra preprocessing, if needed.
        # Apply transforms, if needed.

        return image, label
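
The load_index/__getitem__ pair follows a two-step pattern: first build a flat list of index entries, then resolve one entry per __getitem__ call. The sketch below reproduces that pattern without torch or easytorch, using a plain dict where the original mentions a disk cache; the class name IndexedDataset, the label lookup, and the in-memory cache are illustrative assumptions. A plain dict would not be shared across worker processes, which is presumably why the original suggests a disk-backed cache when load_index runs in multiple processes.

```python
class IndexedDataset:
    """Hypothetical stand-in for the ETDataset pattern above (no torch/easytorch)."""

    def __init__(self, files, labels):
        self.labels = labels  # hypothetical: maps file -> label
        self.cache = {}       # in-memory stand-in for a disk cache
        self.indices = []
        for file in files:
            self.load_index(file)

    def load_index(self, file):
        # Expensive preprocessing would happen once here; a string stands in for it.
        self.cache[file] = f"loaded:{file}"
        self.indices.append([file, "something_extra"])

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, index):
        file, extra = self.indices[index]  # each entry is [file, extra]
        image = self.cache[file]           # placeholder for real image loading
        label = self.labels[file]
        return image, label


ds = IndexedDataset(["a.png", "b.png"], {"a.png": 0, "b.png": 1})
print(len(ds), ds[1])  # 2 ('loaded:b.png', 1)
```

Because each index entry can hold extra fields, one file can also yield several entries (for example, several patches of one image) without changing __getitem__.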

3. Entry point (e.g. main.py)

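from easytorch import EasyTorch

# Assumes MyTrainer and MyDataset are defined above or imported from your own modules.
runner = EasyTorch(phase="train", batch_size=4, epochs=21,
                   num_channel=1, num_class=2,
                   split_ratio=[0.6, 0.2, 0.2])

if __name__ == "__main__":
    # Passing the dataset class as the second argument is assumed here.
    runner.run(MyTrainer, MyDataset)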

All the best! Cheers! 🎉

Cite the following papers if you use this library:

@article{deepdyn_10.3389/fcomp.2020.00035,
    title   = {Dynamic Deep Networks for Retinal Vessel Segmentation},
    author  = {Khanal, Aashis and Estrada, Rolando},
    year    = 2020,
    journal = {Frontiers in Computer Science},
    volume  = 2,
    pages   = 35,
    doi     = {10.3389/fcomp.2020.00035},
    issn    = {2624-9898}
}

@misc{2202.02382,
    author  = {Aashis Khanal and Saeid Motevali and Rolando Estrada},
    title   = {Fully Automated Tree Topology Estimation and Artery-Vein Classification},
    year    = {2022},
    eprint  = {arXiv:2202.02382}
}

Feature Highlights:

  • A data handle (ETDataHandle) that is always available and decoupled from other modules, enabling easy customization.
    • Use custom and complex data-handling mechanisms.
  • Simple, lightweight logger/plotter.
    • Plot: set log_header = 'Loss,F1,Accuracy' to plot everything in the same plot, or set log_header = 'Loss|F1,Accuracy' to plot Loss in one plot and F1,Accuracy in another.
    • Logs: all arguments and generated data are saved to a logs.json file after the experiment finishes.
  • Gradient accumulation, automatic logging/plotting, model checkpointing, and saving of everything.
  • Multiple metric implementations in easytorch.metrics: Precision, Recall, Accuracy, Overlap, F1, ROC-AUC, and Confusion matrix.
  • For advanced training with multiple networks and complex training steps, see the advanced examples in the project documentation.
  • Custom metrics can be implemented as well; see the project documentation for an example.
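
The log_header convention described above uses '|' to separate plots and ',' to separate series within one plot. A minimal parsing sketch of that convention follows; the function parse_log_header is hypothetical and only mirrors the documented format, not easytorch's plotting code.

```python
def parse_log_header(log_header):
    """Hypothetical sketch: '|' separates plots, ',' separates series within a plot."""
    return [
        [name.strip() for name in group.split(",") if name.strip()]
        for group in log_header.split("|")
    ]


print(parse_log_header("Loss,F1,Accuracy"))  # [['Loss', 'F1', 'Accuracy']]
print(parse_log_header("Loss|F1,Accuracy"))  # [['Loss'], ['F1', 'Accuracy']]
```

Each inner list corresponds to one figure, so the MNIST header 'Loss|Accuracy,F1,Precision,Recall' yields one loss plot and one plot with four metric curves.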

Source Distribution

easytorch-3.5.8.tar.gz (33.6 kB)

File details

Details for the file easytorch-3.5.8.tar.gz.

File metadata

  • Download URL: easytorch-3.5.8.tar.gz
  • Size: 33.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.2

File hashes

Hashes for easytorch-3.5.8.tar.gz
Algorithm Hash digest
SHA256 b522d30082d7ff088ca2aa5ff933e507627f7ea2d541c38ee3f8709c58f17788
MD5 257ffcc272c560b15d2cf70b1594badf
BLAKE2b-256 cbf1e08124cb2367d9fe5f849583f4dce88d4e086649d80efc182b6a01f642ae
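
To check a downloaded archive against the SHA256 digest listed above, a streaming hash with Python's standard hashlib module is enough. The helper names sha256_of_stream and verify_download are illustrative; only the expected digest comes from this page.

```python
import hashlib
import io

# SHA256 digest listed above for easytorch-3.5.8.tar.gz.
EXPECTED_SHA256 = "b522d30082d7ff088ca2aa5ff933e507627f7ea2d541c38ee3f8709c58f17788"


def sha256_of_stream(stream, chunk_size=1 << 16):
    """Hash a binary file-like object in chunks so large files need not fit in memory."""
    digest = hashlib.sha256()
    for chunk in iter(lambda: stream.read(chunk_size), b""):
        digest.update(chunk)
    return digest.hexdigest()


def verify_download(path, expected=EXPECTED_SHA256):
    """Return True if the file at `path` matches the expected SHA256 digest."""
    with open(path, "rb") as fh:
        return sha256_of_stream(fh) == expected.lower()


# Self-check against the standard SHA-256 test vector for b"abc".
assert sha256_of_stream(io.BytesIO(b"abc")) == (
    "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
)
```

Usage: verify_download("easytorch-3.5.8.tar.gz"). pip can enforce the same check at install time through its hash-checking mode: put easytorch==3.5.8 --hash=sha256:<digest above> in a requirements file and install with pip install --require-hashes -r requirements.txt (in that mode every dependency must also be pinned with a hash).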

