PyTorch with Cloud GPUs
Project description
SkyTorch
Run PyTorch locally with the power of cloud GPUs.
SkyTorch registers a "sky" device backend in PyTorch that transparently streams tensor operations to cloud GPUs managed by Kubernetes.
MNIST Training
import asyncio
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from skytorch.client import Compute, compute
class MNISTNet(nn.Module):
    """Small convolutional network producing logits for the 10 MNIST classes."""

    def __init__(self):
        super().__init__()
        # Two 3x3 convolutions: 1 -> 32 -> 64 channels, stride 1, no padding.
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.pool = nn.MaxPool2d(2)
        # A 28x28 input shrinks to 26x26, then 24x24 after the convs and
        # 12x12 after pooling: 64 * 12 * 12 = 9216 features feed the head.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return raw class logits for a batch of (N, 1, 28, 28) images."""
        relu = nn.functional.relu
        x = relu(self.conv1(x))
        x = relu(self.conv2(x))
        x = self.pool(x)
        x = torch.flatten(x, 1)
        x = relu(self.fc1(x))
        return self.fc2(x)
@compute(
    name="mnist",
    image="ghcr.io/astefanutti/skytorch-server:latest",
    resources={"nvidia.com/gpu": "1"},
)
async def train(node: Compute, epochs: int = 10):
    """Train MNISTNet on the remote compute node's CUDA device.

    The @compute decorator provisions the node; *node* exposes the remote
    device that tensors and the model are moved to.
    """
    device = node.device("cuda")

    # Standard MNIST mean/std normalization.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_loader = DataLoader(
        datasets.MNIST("data", train=True, download=True, transform=transform),
        batch_size=5000, shuffle=True,
    )

    model = MNISTNet().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.002)

    for _epoch in range(epochs):
        model.train()
        for data, target in train_loader:
            # Move each batch onto the remote device before the step.
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            loss = nn.functional.cross_entropy(model(data), target)
            loss.backward()
            optimizer.step()


asyncio.run(train())
Multi-GPU GRPO Training
import asyncio
import copy
from transformers import AutoModelForCausalLM
from skytorch.client import Compute, Cluster
async def main():
    """Sketch of a GRPO loop spanning two sky compute nodes."""
    # Cluster provisions both nodes and tears them down on exit.
    async with Cluster(
        Compute(name="trainer"),
        Compute(name="vllm"),
    ) as (trainer, vllm):
        trainer_device = trainer.device("cuda")
        vllm_device = vllm.device("cuda")

        # Load the policy model on the trainer and copy it to vLLM
        model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B")
        model.to(trainer_device)
        ref_model = copy.deepcopy(model).to(vllm_device)

        for step in range(10):
            # GRPO training step on the trainer device
            # ...

            # Sync weights from trainer to vLLM
            for p, ref_p in zip(model.parameters(), ref_model.parameters()):
                ref_p.data.copy_(p.data)


asyncio.run(main())
Note: Cross-compute tensor copy is not supported yet. This example illustrates a future capability.
Getting Started
pip install torch
pip install --no-build-isolation skytorch
--no-build-isolation is required because the C++ extension needs PyTorch headers at build time, so PyTorch must be installed first.
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
File details
Details for the file skytorch-0.1.0.tar.gz.
File metadata
- Download URL: skytorch-0.1.0.tar.gz
- Upload date:
- Size: 292.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.6.10
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `888145797c1fc9049c6953c6db15796852e4d53508e17b8f6f920de1f580f890` |
| MD5 | `16bd51ad3b175f584adfc3008a5b7e8f` |
| BLAKE2b-256 | `811903b14a6a72fc117c830487f0d6d680b8e1b22fc69f7c872f3226992d0580` |
File details
Details for the file skytorch-0.1.0-cp312-cp312-macosx_10_13_universal2.whl.
File metadata
- Download URL: skytorch-0.1.0-cp312-cp312-macosx_10_13_universal2.whl
- Upload date:
- Size: 996.5 kB
- Tags: CPython 3.12, macOS 10.13+ universal2 (ARM64, x86-64)
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.6.10
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `c39672939b6fa85d76cb31694f1646668e633081a7f28b110e648aef4e4008bc` |
| MD5 | `3bb89625c1dd9908fb751c9d15a8fe44` |
| BLAKE2b-256 | `39d45afe500e426c0926746400ee52d77f349409a34012895502d82b0f76346f` |