Auto-scheduler for PyTorch tasks.
Project description
Fed Flow
Description
Auto-scheduler for parallel tasks.
Install
pip install fedflow==0.2.0
Usage
import os
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision.datasets import mnist
from torchvision.transforms import transforms
from fedflow import Task, TaskGroup, FedFlow
from fedflow.config import Config
from fedflow.utils.trainer.supervised_trainer import SupervisedTrainer
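# Global FedFlow settings used by this example.
Config.set_property("debug", True)
# NOTE(review): the unit of "scheduler.interval" is not shown here (seconds?) — confirm in fedflow.
Config.set_property("scheduler.interval", 2)

# Absolute path of ./datasets, resolved against the current working directory.
datasets_path = os.path.abspath("datasets")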
class Net(nn.Module):
    """LeNet-style CNN mapping single-channel 28x28 images to 10 class scores.

    Two 5x5 convolutions (stride 1), each followed by 2x2 max-pooling, shrink a
    28x28 input to 4x4 spatial size with 50 channels, hence the
    ``4 * 4 * 50`` input width of ``fc1``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)
        # Not applied in forward(): this example pairs the model with
        # nn.CrossEntropyLoss, which expects raw logits. Kept so existing
        # attribute access keeps working.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return unnormalized class scores of shape ``(batch, 10)``.

        Args:
            x: batched image tensor; the layer sizes assume shape
                ``(batch, 1, 28, 28)``.
        """
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        # Flatten each sample while keeping the batch dimension intact.
        # The previous ``x.view(-1, 4 * 4 * 50)`` silently regrouped samples
        # when the input was not 28x28 (e.g. 16 images of 32x32 became 25
        # rows); with flatten(1) a wrong input size makes fc1 raise instead.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)
class MnistTask(Task):
    """FedFlow task that trains ``Net`` on MNIST and evaluates on the test split.

    Args:
        id: task identifier, forwarded to ``Task`` as ``task_id``.
        datasets_path: directory MNIST is read from (and downloaded into).
    """

    def __init__(self, id, datasets_path):
        # Resource estimates are handed to the Task base class; presumably the
        # scheduler uses them when placing the task — confirm in fedflow.
        super().__init__(task_id=id, estimate_memory="2.5GB", estimate_cuda_memory="1200MB")
        self.datasets_path = datasets_path

    def load(self):
        """Build the train/test MNIST datasets, the model, its SGD optimizer and the loss."""

        def make_transform():
            # A fresh pipeline per dataset. The constants are presumably the
            # MNIST training-set pixel mean and std — confirm.
            return transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.13066062,), (0.30810776,)),
            ])

        self.mnist_dataset = mnist.MNIST(
            root=self.datasets_path,
            download=True,
            train=True,
            transform=make_transform(),
        )
        self.test_dataset = mnist.MNIST(
            root=self.datasets_path,
            download=True,
            train=False,
            transform=make_transform(),
        )
        self.mnist_model = Net()
        # The optimizer must be created after the model it updates.
        self.mnist_optim = optim.SGD(self.mnist_model.parameters(), lr=0.01)
        self.criterion = nn.CrossEntropyLoss()

    def train(self, device) -> dict:
        """Train for 50 epochs with batch size 32 and return ``trainer.train()``'s dict.

        NOTE(review): the ``device`` argument is not used; ``self.device`` is
        used instead, presumably set by the ``Task`` base class. Confirm which
        one the scheduler intends.
        """
        self.mnist_model = self.mnist_model.to(self.device)
        trainer = SupervisedTrainer(
            self.mnist_model,
            self.mnist_optim,
            self.criterion,
            epoch=50,
            device=self.device,
            console_out="console.out",
        )
        trainer.mount_dataset(self.mnist_dataset, self.test_dataset, batch_size=32)
        return trainer.train()
def print_result(group: "TaskGroup", num_tasks: int = 20):
    """Print a table of per-task training and validation accuracy.

    Args:
        group: task group whose tasks have finished. Task ``i`` is fetched with
            ``group.get_task(i)``; its ``result`` mapping must contain
            ``"train_acc"`` and ``"val_acc"``.
        num_tasks: number of tasks to report (ids ``0 .. num_tasks - 1``).
            Defaults to 20, the number of tasks this example creates.
    """
    # Both accuracy columns are 9 characters wide ("%8.2f" plus a literal "%"),
    # matching the 9-character headers. NOTE(review): assumes the accuracies
    # are already percentages (0-100) — confirm with SupervisedTrainer.
    print("%2s %9s %9s" % ("ID", "train acc", " val acc "))
    for i in range(num_tasks):
        result = group.get_task(i).result
        print("%02d %8.2f%% %8.2f%%" % (i, result["train_acc"], result["val_acc"]))
if __name__ == "__main__":
    # Fetch MNIST once before scheduling, presumably so the parallel tasks
    # (which also pass download=True) find the files already on disk.
    mnist.MNIST(root=datasets_path, download=True)

    group = TaskGroup("mnist")
    for task_id in range(20):
        group.add_task(MnistTask(task_id, datasets_path))

    with FedFlow() as flow:
        flow.execute(group)

    # NOTE(review): the original indentation was lost; this assumes task
    # results are still readable after the FedFlow context exits — confirm.
    print_result(group)
Features
- Add a subprocess tracker
- Add GPU load balancing
- Add methods to kill a specified subprocess or task
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
fedflow-0.2.1.tar.gz (24.2 kB, view details)
Built Distribution
fedflow-0.2.1-py3-none-any.whl (30.1 kB, view details)
File details
Details for the file fedflow-0.2.1.tar.gz.
File metadata
- Download URL: fedflow-0.2.1.tar.gz
- Upload date:
- Size: 24.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.5.0 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.8.10
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | ed46abb997cc203c139f63c73ed5bce882118728ef92e4ce55074045b950cc13 |
| MD5 | 02e3e4200c231b9abea89110a457cc78 |
| BLAKE2b-256 | 1ef894a489ee7af18237eb2305e340acf6791a0e312b68d6d90f16fd1b49d6fa |
File details
Details for the file fedflow-0.2.1-py3-none-any.whl.
File metadata
- Download URL: fedflow-0.2.1-py3-none-any.whl
- Upload date:
- Size: 30.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.5.0 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.8.10
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 7eb72889b0e3fefc4cd670911a1baad0241da453d4e0b70906f8cecb04cc3e96 |
| MD5 | a7a58f765fb30976d6453135e53f449b |
| BLAKE2b-256 | 83de5c4d93ab7519b6759ce38260631389f87cb640519f51153df451da78efba |