PyTorch with Cloud GPUs

Project description

SkyTorch

Run PyTorch locally with the power of cloud GPUs.

SkyTorch registers a "sky" device backend in PyTorch that transparently streams tensor operations to cloud GPUs managed by Kubernetes.
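SkyTorch's registration code isn't shown here, but out-of-tree device backends in PyTorch typically hook into the reserved PrivateUse1 dispatch key. A minimal sketch using standard PyTorch API (illustrative only, not necessarily how SkyTorch implements it):

```python
import torch

# PyTorch reserves the PrivateUse1 dispatch key for out-of-tree
# device backends; renaming it lets torch.device("sky") parse as
# a first-class device type.
torch.utils.rename_privateuse1_backend("sky")

dev = torch.device("sky", 0)
print(dev.type, dev.index)
```

After the rename, tensor operations on the new device type dispatch to whatever kernels the backend registers for PrivateUse1, which is where a library like SkyTorch can intercept and stream them remotely.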

MNIST Training

import asyncio
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from skytorch.client import Compute, compute

class MNISTNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.pool = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.relu(self.conv2(x))
        x = self.pool(x)
        x = torch.flatten(x, 1)
        x = nn.functional.relu(self.fc1(x))
        return self.fc2(x)

@compute(
    name="mnist",
    image="ghcr.io/astefanutti/skytorch-server:latest",
    resources={"nvidia.com/gpu": "1"},
)
async def train(node: Compute, epochs: int = 10):
    device = node.device("cuda")

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_loader = DataLoader(
        datasets.MNIST("data", train=True, download=True, transform=transform),
        batch_size=5000, shuffle=True,
    )

    model = MNISTNet().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.002)

    for epoch in range(epochs):
        model.train()
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            loss = nn.functional.cross_entropy(model(data), target)
            loss.backward()
            optimizer.step()

asyncio.run(train())
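The training loop above never reports metrics. Evaluation on the remote device works the same way as training; the core accuracy computation, shown here on toy CPU logits so it runs standalone, is a plain argmax over the class dimension:

```python
import torch

# Toy logits for a batch of 4 samples over the 10 MNIST classes;
# the index set to 1.0 is the class the "model" predicts.
logits = torch.zeros(4, 10)
for i, cls in enumerate([3, 7, 7, 1]):
    logits[i, cls] = 1.0
target = torch.tensor([3, 7, 2, 1])

# argmax over the class dimension gives the predicted digit.
pred = logits.argmax(dim=1)
accuracy = (pred == target).float().mean().item()
print(accuracy)  # 3 of 4 predictions match -> 0.75
```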

Multi-GPU GRPO Training

import asyncio
import copy
from transformers import AutoModelForCausalLM
from skytorch.client import Compute, Cluster

async def main():
    async with Cluster(
        Compute(name="trainer"),
        Compute(name="vllm"),
    ) as (trainer, vllm):
        trainer_device = trainer.device("cuda")
        vllm_device = vllm.device("cuda")

        # Load the policy model on the trainer and copy it to vLLM
        model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B")
        model.to(trainer_device)
        ref_model = copy.deepcopy(model).to(vllm_device)

        for step in range(10):
            # GRPO training step on the trainer device
            # ...

            # Sync weights from trainer to vLLM
            for p, ref_p in zip(model.parameters(), ref_model.parameters()):
                ref_p.data.copy_(p.data)

asyncio.run(main())

Note: Cross-compute tensor copy is not supported yet. This example illustrates a future capability.
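The weight-sync loop in the example relies only on standard PyTorch parameter copies. On a single machine the same pattern can be exercised with plain CPU tensors, which is a useful way to sanity-check the logic before the cross-compute copy lands:

```python
import copy
import torch
import torch.nn as nn

# Two copies of the same network, standing in for the trainer
# and vLLM replicas in the example above.
model = nn.Linear(4, 2)
ref_model = copy.deepcopy(model)

# Simulate a training update on the trainer copy.
with torch.no_grad():
    for p in model.parameters():
        p.add_(1.0)

# Sync weights, parameter by parameter, as in the GRPO loop.
for p, ref_p in zip(model.parameters(), ref_model.parameters()):
    ref_p.data.copy_(p.data)

# After the sync, both models hold identical weights.
synced = all(
    torch.equal(p, rp)
    for p, rp in zip(model.parameters(), ref_model.parameters())
)
print(synced)  # True
```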

Getting Started

pip install torch
pip install --no-build-isolation skytorch

--no-build-isolation is required because SkyTorch's C++ extension needs PyTorch headers at build time, so PyTorch must be installed first.

Download files

Download the file for your platform.

Source Distribution

skytorch-0.1.0.tar.gz (292.9 kB view details)

Uploaded: Source

Built Distribution

skytorch-0.1.0-cp312-cp312-macosx_10_13_universal2.whl (996.5 kB view details)

Uploaded: CPython 3.12, macOS 10.13+ universal2 (ARM64, x86-64)

File details

Details for the file skytorch-0.1.0.tar.gz.

File metadata

  • Download URL: skytorch-0.1.0.tar.gz
  • Size: 292.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.6.10

File hashes

Hashes for skytorch-0.1.0.tar.gz
Algorithm Hash digest
SHA256 888145797c1fc9049c6953c6db15796852e4d53508e17b8f6f920de1f580f890
MD5 16bd51ad3b175f584adfc3008a5b7e8f
BLAKE2b-256 811903b14a6a72fc117c830487f0d6d680b8e1b22fc69f7c872f3226992d0580

File details

Details for the file skytorch-0.1.0-cp312-cp312-macosx_10_13_universal2.whl.

File hashes

Hashes for skytorch-0.1.0-cp312-cp312-macosx_10_13_universal2.whl
Algorithm Hash digest
SHA256 c39672939b6fa85d76cb31694f1646668e633081a7f28b110e648aef4e4008bc
MD5 3bb89625c1dd9908fb751c9d15a8fe44
BLAKE2b-256 39d45afe500e426c0926746400ee52d77f349409a34012895502d82b0f76346f
