Skip to main content

PyTorch with Cloud GPUs

Project description

SkyTorch

Run PyTorch with remote GPUs.

An example of prompting GPT OSS 120B on a remote NVIDIA A100 80GB GPU using demo/llm_term.py:

SkyTorch LLM chat demo

SkyTorch provides:

  • A sky device backend that virtualizes remote GPUs and transparently streams tensor operations
  • A standalone gRPC server that runs on remote GPU hosts for PyTorch to connect to
  • A Kubernetes operator that provisions the SkyTorch server onto GPU pods / nodes on-demand

Examples

MNIST Training

import asyncio
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from skytorch.client import compute

class MNISTNet(nn.Module):
    """Small CNN for MNIST digit classification: two conv layers, max-pool, two FC layers."""

    def __init__(self):
        super().__init__()
        # Feature extractor: 1x28x28 -> 32x26x26 -> 64x24x24 -> pooled 64x12x12.
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.pool = nn.MaxPool2d(2)
        # Classifier head: 64 * 12 * 12 = 9216 flattened features -> 10 class logits.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return logits of shape (batch, 10) for input of shape (batch, 1, 28, 28)."""
        features = nn.functional.relu(self.conv1(x))
        features = nn.functional.relu(self.conv2(features))
        features = self.pool(features)
        flat = torch.flatten(features, 1)
        hidden = nn.functional.relu(self.fc1(flat))
        return self.fc2(hidden)

@compute(
    name="mnist",
    image="ghcr.io/astefanutti/skytorch-server",
    resources={"cpu": "1", "memory": "8Gi", "nvidia.com/gpu": "1"},
)
async def train(node, epochs: int = 10):
    """Train MNISTNet on MNIST for `epochs` epochs on the compute node's GPU."""
    device = node.device("cuda")

    # Standard MNIST mean/std normalization.
    normalize = transforms.Normalize((0.1307,), (0.3081,))
    dataset = datasets.MNIST(
        "data",
        train=True,
        download=True,
        transform=transforms.Compose([transforms.ToTensor(), normalize]),
    )
    loader = DataLoader(dataset, batch_size=5000, shuffle=True)

    net = MNISTNet().to(device)
    opt = torch.optim.AdamW(net.parameters(), lr=0.002)

    for _ in range(epochs):
        net.train()
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            opt.zero_grad()
            nn.functional.cross_entropy(net(images), labels).backward()
            opt.step()

asyncio.run(train())

See demo/mnist.py for the full example.

GPT OSS 120B Inferencing

import asyncio
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextStreamer
from skytorch.client import compute

@compute(
    name="gpt",
    image="ghcr.io/astefanutti/skytorch-server",
    resources={"cpu": "4", "memory": "32Gi", "nvidia.com/gpu": "1"},
    volumes=[{"name": "cache", "storage": "80Gi", "path": "/cache"}],
    env={"HF_HOME": "/cache", "TRITON_HOME": "/cache"},
)
async def chat(node):
    """Interactive multi-turn chat REPL against GPT OSS 120B on a remote GPU node."""
    device = node.device("cuda")
    model_name = "openai/gpt-oss-120b"

    def load_model(model):
        # Runs server-side via node.execute(); loads the weights onto the remote GPU.
        return AutoModelForCausalLM.from_pretrained(model, device_map="cuda")

    # Load the model weights server-side (stays on GPU, only metadata returned)
    # and the tokenizer locally, in parallel.
    state_dict, tokenizer = await asyncio.gather(
        node.execute(load_model, model_name, retain_model=True),
        asyncio.to_thread(AutoTokenizer.from_pretrained, model_name),
    )

    # Instantiate the architecture on the meta device (no local memory for weights),
    # then sync the model with the remote state dict (no weights downloaded).
    with torch.device("meta"):
        config = AutoConfig.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_config(config)
    state_dict.load_into(model, triton_modules=["model.layers.*.mlp"])
    model.eval()

    # Stream generated tokens to stdout as they arrive, hiding the prompt echo.
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    history = [{"role": "system", "content": "You are a helpful assistant."}]

    with torch.no_grad():
        while True:
            user_input = input("You: ")
            if user_input.strip().lower() in ("quit", "exit"):
                break

            # Accumulate turns so the model sees the full conversation each time.
            history.append({"role": "user", "content": user_input})
            inputs = tokenizer.apply_chat_template(
                history,
                add_generation_prompt=True,
                return_tensors="pt",
                return_dict=True,
            )
            inputs = {k: v.to(device) for k, v in inputs.items()}

            print("Assistant: ", end="", flush=True)
            # Greedy decoding (do_sample=False); tokens are printed by the streamer.
            generated = model.generate(
                **inputs, max_new_tokens=512, do_sample=False, streamer=streamer,
            )

            # Decode only the newly generated tokens (slice off the prompt prefix)
            # and record the reply so the next turn has the full context.
            response = tokenizer.decode(
                generated[0, inputs["input_ids"].shape[1] :],
                skip_special_tokens=True,
            )
            history.append({"role": "assistant", "content": response})

try:
    asyncio.run(chat())
except KeyboardInterrupt:
    # Allow Ctrl-C to exit the chat loop without a traceback.
    pass

See demo/llm_gpt.py for the full example.

Note: GPT OSS 120B requires a GPU with at least 80GB of memory.

GRPO Training

import copy
from transformers import AutoModelForCausalLM
from skytorch.client import Compute, Cluster

import asyncio


async def grpo_example():
    """Illustrative GRPO setup: a trainer node and a vLLM node with weight sync.

    Wrapped in a coroutine because ``async with`` is only valid inside an
    ``async def`` — at module level it is a SyntaxError.
    Note: cross-compute tensor copy is not supported yet; the weight-sync
    loop below illustrates a future capability.
    """
    async with Cluster(
        Compute(
            name="trainer",
            resources={"nvidia.com/gpu": "1"},
        ),
        Compute(
            name="vllm",
            resources={"nvidia.com/gpu": "1"},
        ),
    ) as (trainer, vllm):
        trainer_device = trainer.device("cuda")
        vllm_device = vllm.device("cuda")

        # Load the policy model on the trainer and copy it to vLLM
        model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B")
        model.to(trainer_device)
        ref_model = copy.deepcopy(model).to(vllm_device)

        for step in range(10):
            # GRPO training step on the trainer device
            # ...

            # Sync weights from trainer to vLLM
            for p, ref_p in zip(model.parameters(), ref_model.parameters()):
                ref_p.data.copy_(p.data)


asyncio.run(grpo_example())

Note: Cross-compute tensor copy is not supported yet. This example illustrates a future capability.

Getting Started

pip install torch
pip install --no-build-isolation skytorch

--no-build-isolation is required because the C++ extension needs PyTorch headers at build time, so PyTorch must be installed first.

Remote Host

Start the SkyTorch server on a machine with a GPU, then connect from your local machine.

On the GPU machine:

python -m skytorch.torch.server --port 50051

On your local machine:

import asyncio
import torch
from skytorch.torch.server import compute

@compute("gpu-host:50051")
async def main(node):
    """Run a small matmul on the remote GPU and print the result locally."""
    device = node.device("cuda")

    a = torch.randn(4, 4, device=device)
    product = a @ a.T
    print(product.cpu())

asyncio.run(main())

Kubernetes

SkyTorch can be deployed as a Kubernetes operator. This requires a cluster with Gateway API support.

Install the operator using Kustomize, choosing the overlay that matches your cluster:

# Vanilla Kubernetes / KinD (includes Contour as the Gateway API controller)
kubectl apply --server-side -k config/e2e

# OpenShift (uses the built-in gateway controller)
kubectl apply --server-side -k config/openshift

Then on your local machine, you can run:

import asyncio
import torch
from skytorch.client import compute

@compute(
    name="demo",
    image="ghcr.io/astefanutti/skytorch-server",
    resources={"cpu": "1", "memory": "8Gi", "nvidia.com/gpu": "1"},
)
async def main(node):
    """Sanity check: random 4x4 matmul on the remote GPU, printed locally."""
    device = node.device("cuda")

    m = torch.randn(4, 4, device=device)
    result = torch.matmul(m, m.T)
    print(result.cpu())

asyncio.run(main())

Configuration

SkyTorch can be configured via environment variables.

Client

Variable Default Description
SKYTORCH_GRPC_COMPRESSION gzip gRPC compression (none, deflate, gzip)
SKYTORCH_BATCH_COALESCE_MS 2 Delay (ms) to coalesce partial batches
SKYTORCH_BATCH_THRESHOLD 64 Ops buffered before forced flush
SKYTORCH_DELETE_COALESCE_MS 10 Delay (ms) to coalesce deferred tensor deletes
SKYTORCH_DELETE_THRESHOLD 256 Deferred delete IDs buffered before forced flush
SKYTORCH_STREAMING 1 Enable bidirectional gRPC streaming; must be set to 0 for IDE debugging
SKYTORCH_CPP_REQUEST_BUILDER 1 Use C++ fast path for request serialization
SKYTORCH_SPECULATIVE_SCALAR 1 Predict .item() results to avoid sync
SKYTORCH_ASYNC_COPY 0 Non-blocking copy to CPU for single-element tensors (e.g. for LLM inferencing)
SKYTORCH_PROFILE 0 Enable lightweight profiling (~200ns/op overhead)

Server

Variable Default Description
SKYTORCH_PORT 50051 gRPC server port
SKYTORCH_HOST [::] Server bind address
SKYTORCH_LOG_LEVEL INFO Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
SKYTORCH_GRPC_COMPRESSION none gRPC compression (none, deflate, gzip)
SKYTORCH_CHUNK_SIZE 1048576 Chunk size in bytes for streaming tensors
SKYTORCH_METRICS_SOURCES (empty) Comma-separated metrics sources (e.g., nvidia-gpu)

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

skytorch-0.2.0.tar.gz (399.1 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

skytorch-0.2.0-cp312-cp312-macosx_10_13_universal2.whl (1.6 MB view details)

Uploaded CPython 3.12macOS 10.13+ universal2 (ARM64, x86-64)

File details

Details for the file skytorch-0.2.0.tar.gz.

File metadata

  • Download URL: skytorch-0.2.0.tar.gz
  • Upload date:
  • Size: 399.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.10.9 {"installer":{"name":"uv","version":"0.10.9","subcommand":["publish"]},"python":null,"implementation":{"name":null,"version":null},"distro":{"name":"macOS","version":null,"id":null,"libc":null},"system":{"name":null,"release":null},"cpu":null,"openssl_version":null,"setuptools_version":null,"rustc_version":null,"ci":null}

File hashes

Hashes for skytorch-0.2.0.tar.gz
Algorithm Hash digest
SHA256 bc7542a6812e85bf37b825a4973e398b969c0ccd221d99649fe399ceb1e206fe
MD5 dcf860c4b05f2e0ee9e89af1c705f9c7
BLAKE2b-256 7550f54a605890c214b236e48b5b26cbeda253b7be944e57ac842400a04a1e16

See more details on using hashes here.

File details

Details for the file skytorch-0.2.0-cp312-cp312-macosx_10_13_universal2.whl.

File metadata

  • Download URL: skytorch-0.2.0-cp312-cp312-macosx_10_13_universal2.whl
  • Upload date:
  • Size: 1.6 MB
  • Tags: CPython 3.12, macOS 10.13+ universal2 (ARM64, x86-64)
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.10.9 {"installer":{"name":"uv","version":"0.10.9","subcommand":["publish"]},"python":null,"implementation":{"name":null,"version":null},"distro":{"name":"macOS","version":null,"id":null,"libc":null},"system":{"name":null,"release":null},"cpu":null,"openssl_version":null,"setuptools_version":null,"rustc_version":null,"ci":null}

File hashes

Hashes for skytorch-0.2.0-cp312-cp312-macosx_10_13_universal2.whl
Algorithm Hash digest
SHA256 e4041a3217c0f2547d2490cb855979b28ebade061e8ce75e6e564d345f917fc1
MD5 36298e8b0fe53d68244fd2c2e2a05f18
BLAKE2b-256 26aa13e6b0671462666fb21598e87cbfe989646c905597e09b72c1673467d242

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page