Skip to main content

An auto-scheduler for PyTorch tasks.

Project description

Fed Flow

Description

An auto-scheduler for parallel tasks.

Install

pip install fedflow==0.2.0

Usage

import os

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision.datasets import mnist
from torchvision.transforms import transforms

from fedflow import Task, TaskGroup, FedFlow
from fedflow.config import Config
from fedflow.utils.trainer.supervised_trainer import SupervisedTrainer


# Example-wide fedflow settings, applied in this order: enable "debug" and set
# "scheduler.interval" to 2 (units are not visible here — presumably seconds).
for _name, _value in (("debug", True), ("scheduler.interval", 2)):
    Config.set_property(_name, _value)
del _name, _value


# Datasets live in a "datasets" folder under the current working directory
# (os.curdir is ".", so this is the absolute path of ./datasets).
datasets_path = os.path.join(os.path.abspath(os.curdir), "datasets")


class Net(nn.Module):
    """Small two-conv / two-linear CNN that maps images to 10 class scores.

    ``forward`` flattens to ``4 * 4 * 50`` features, which is what a
    1x28x28 input yields after two (5x5 conv, 2x2 max-pool) stages
    (28 -> 24 -> 12 -> 8 -> 4). The output is the raw ``fc2`` activation;
    no softmax is applied.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as before, so parameter
        # initialisation draws from the RNG identically.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=5, stride=1)
        self.fc1 = nn.Linear(in_features=4 * 4 * 50, out_features=500)
        self.fc2 = nn.Linear(in_features=500, out_features=10)
        # Defined but never called by forward(); kept for attribute compatibility.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        # Each stage: convolution -> ReLU -> 2x2 max-pool with stride 2.
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), 2, 2)
        flat = x.view(-1, 4 * 4 * 50)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)


class MnistTask(Task):
    """A fedflow task that trains ``Net`` on MNIST and returns the trainer's results.

    Args:
        id: Task identifier, forwarded to ``Task`` as ``task_id``. The name
            shadows the builtin but is kept so existing callers keep working.
        datasets_path: Root directory for the MNIST files. ``download=True``
            is passed, so missing files are fetched there.
    """

    # Mean/std passed to ``transforms.Normalize`` for the single MNIST channel.
    _NORMALIZE_MEAN = (0.13066062,)
    _NORMALIZE_STD = (0.30810776,)

    def __init__(self, id, datasets_path):
        # Memory estimates are declared to the Task base class; presumably the
        # scheduler uses them when placing tasks — confirm against fedflow docs.
        super().__init__(task_id=id, estimate_memory="2.5GB", estimate_cuda_memory="1200MB")
        self.datasets_path = datasets_path

    def _make_dataset(self, train):
        """Return the MNIST train (``train=True``) or test split with the shared preprocessing."""
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(self._NORMALIZE_MEAN, self._NORMALIZE_STD),
        ])
        return mnist.MNIST(root=self.datasets_path,
                           download=True,
                           train=train,
                           transform=transform)

    def load(self):
        """Create the datasets, model, optimizer and loss that ``train`` uses."""
        self.mnist_dataset = self._make_dataset(train=True)
        self.test_dataset = self._make_dataset(train=False)
        self.mnist_model = Net()
        self.mnist_optim = optim.SGD(self.mnist_model.parameters(), lr=0.01)
        # CrossEntropyLoss expects raw scores, which is what Net.forward returns.
        self.criterion = nn.CrossEntropyLoss()

    def train(self, device) -> dict:
        """Train for 50 epochs and return the dict produced by the trainer.

        Args:
            device: Device to train on, as passed in by the caller. If ``None``,
                falls back to ``self.device``.

        Returns:
            The dict returned by ``SupervisedTrainer.train()``. ``print_result``
            reads its ``"train_acc"`` and ``"val_acc"`` entries.
        """
        # Honour the device the caller hands us. Only fall back to the
        # attribute when nothing was passed.
        if device is None:
            device = self.device
        self.mnist_model = self.mnist_model.to(device)
        # NOTE(review): every task passes the same "console.out" path. If the
        # trainer writes to it, parallel tasks may interleave output — confirm.
        trainer = SupervisedTrainer(self.mnist_model, self.mnist_optim, self.criterion, epoch=50, device=device,
                                    console_out="console.out")
        trainer.mount_dataset(self.mnist_dataset, self.test_dataset, batch_size=32)
        return trainer.train()


def print_result(group: TaskGroup, num_tasks: int = 20):
    """Print a table of train and validation accuracy for tasks ``0 .. num_tasks - 1``.

    Args:
        group: Task group whose tasks have been executed. Each task is looked
            up with ``group.get_task(task_id)``.
        num_tasks: How many task ids to report. Defaults to 20, matching the
            number of tasks the example's ``__main__`` block creates.
    """
    print("%2s %9s %9s" % ("ID", "train acc", " val acc "))
    for task_id in range(num_tasks):
        result = group.get_task(task_id).result
        # Assumption: a task that produced no result leaves ``result`` as None
        # (TODO confirm with fedflow). Report it instead of crashing the table.
        if result is None:
            print("%02d  %s" % (task_id, "no result"))
            continue
        print("%02d  %6.2f%%   %6.2f%%" % (task_id, result["train_acc"], result["val_acc"]))


if __name__ == "__main__":
    # Fetch MNIST once in the parent process. The tasks also pass
    # download=True; presumably this is here so they find the files already
    # present instead of each downloading them.
    mnist.MNIST(root=datasets_path, download=True)

    group = TaskGroup("mnist")
    for task_id in range(20):
        task = MnistTask(task_id, datasets_path)
        group.add_task(task)

    # Run the whole group inside the FedFlow context manager.
    with FedFlow() as flow:
        flow.execute(group)

    # Print each task's train/validation accuracy.
    print_result(group)

Features

  • Add a subprocess tracker
  • Add load balancing across GPUs
  • Add methods to kill a specified subprocess/task

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

fedflow-0.2.1.tar.gz (24.2 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

fedflow-0.2.1-py3-none-any.whl (30.1 kB view details)

Uploaded Python 3

File details

Details for the file fedflow-0.2.1.tar.gz.

File metadata

  • Download URL: fedflow-0.2.1.tar.gz
  • Upload date:
  • Size: 24.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.5.0 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.8.10

File hashes

Hashes for fedflow-0.2.1.tar.gz
Algorithm Hash digest
SHA256 ed46abb997cc203c139f63c73ed5bce882118728ef92e4ce55074045b950cc13
MD5 02e3e4200c231b9abea89110a457cc78
BLAKE2b-256 1ef894a489ee7af18237eb2305e340acf6791a0e312b68d6d90f16fd1b49d6fa

See more details on using hashes here.

File details

Details for the file fedflow-0.2.1-py3-none-any.whl.

File metadata

  • Download URL: fedflow-0.2.1-py3-none-any.whl
  • Upload date:
  • Size: 30.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.5.0 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.8.10

File hashes

Hashes for fedflow-0.2.1-py3-none-any.whl
Algorithm Hash digest
SHA256 7eb72889b0e3fefc4cd670911a1baad0241da453d4e0b70906f8cecb04cc3e96
MD5 a7a58f765fb30976d6453135e53f449b
BLAKE2b-256 83de5c4d93ab7519b6759ce38260631389f87cb640519f51153df451da78efba

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page