
Auto-detects the best PyTorch compute device for AMD GPUs, with gfx1010 ROCm override support (RX 5700 XT, RX 5600 XT, Navi 10)

Project description

torch-amd-setup

Auto-detects the best PyTorch compute device for AMD GPUs — with first-class support for cards that are not in ROCm's default allow-list (RX 5700 XT, RX 5600 XT, RX 5500 XT / gfx1010–gfx1012).

One import. No manual env var hunting. Works on Windows, Linux, WSL2, and macOS.

from torch_amd_setup import get_best_device, get_torch_device, get_dtype

device_type = get_best_device()   # "rocm" | "dml" | "cuda" | "mps" | "cpu"
device      = get_torch_device()  # torch.device ready for model.to()
dtype       = get_dtype()         # torch.float16 or torch.float32

The problem this solves

AMD GPUs that use the gfx1010 architecture (Navi 10 — RX 5700 XT, RX 5700, RX 5600 XT) are not in ROCm's default supported GPU list. PyTorch on ROCm will silently fall back to CPU unless you set:

export HSA_OVERRIDE_GFX_VERSION=10.3.0

...but it has to be set before Python imports torch, which means you either:

  • Remember to set it in every shell session, or
  • Bake it into a shell script wrapper, or
  • Set it in your Python script before the first import torch
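Here is a minimal sketch of the third option: setting the variable from Python before torch is loaded. It is illustrative only and is not torch-amd-setup's internal code. The helper name `apply_gfx_override` is hypothetical, and `10.3.0` is the gfx1010 value from the compatibility table. An existing value is never overwritten, so a shell `export` still takes precedence.

```python
import os
import sys
import warnings


def apply_gfx_override(version: str = "10.3.0") -> bool:
    """Set HSA_OVERRIDE_GFX_VERSION for this process (hypothetical helper).

    A value already present in the environment is left untouched.
    Returns False (and warns) if torch is already imported, because the
    ROCm runtime may have read its environment by then.
    """
    if "torch" in sys.modules:
        warnings.warn(
            "torch was imported before HSA_OVERRIDE_GFX_VERSION was set; "
            "the override may have no effect",
            RuntimeWarning,
        )
        return False
    os.environ.setdefault("HSA_OVERRIDE_GFX_VERSION", version)
    return True


# Usage, at the very top of your entry-point script:
#     apply_gfx_override()
#     import torch
```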

torch-amd-setup handles all of that automatically. It also detects DirectML on Windows (no ROCm required), Apple MPS on macOS, NVIDIA CUDA, and falls back to CPU — so you can ship one codebase that works everywhere.


Detection priority

Priority | Backend | Platform | Requirement
-------- | ------- | -------- | -----------
1 | NVIDIA CUDA | Any | Standard pip install torch
2 | AMD ROCm | Linux / WSL2 | ROCm PyTorch + AMD driver ≥ 22.20
3 | AMD DirectML | Windows | pip install torch-directml, Python ≤ 3.11
4 | Apple MPS | macOS (Apple Silicon) | Standard pip install torch
5 | CPU | Any | Always available, always slow
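The priority order can be written as a small pure function. This is a hedged sketch, not the package's detect.py. `pick_backend` and `probe` are hypothetical names, and treating an importable `torch_directml` package as "DirectML available" is a simplification. The sketch relies on `torch.version.hip`, which is `None` on non-ROCm builds.

```python
import importlib.util


def pick_backend(*, cuda_available: bool, hip_build: bool,
                 dml_available: bool, mps_available: bool) -> str:
    """Apply the detection priority order (hypothetical helper)."""
    # ROCm builds of PyTorch expose AMD GPUs through the torch.cuda API, so a
    # working "CUDA" device is reported as "rocm" when the build is HIP-based.
    if cuda_available:
        return "rocm" if hip_build else "cuda"
    if dml_available:
        return "dml"
    if mps_available:
        return "mps"
    return "cpu"


def probe() -> dict:
    """Collect the flags pick_backend() needs; all stay False without torch."""
    flags = {"cuda_available": False, "hip_build": False,
             "dml_available": False, "mps_available": False}
    try:
        import torch
    except ImportError:
        return flags
    flags["cuda_available"] = torch.cuda.is_available()
    flags["hip_build"] = torch.version.hip is not None
    flags["mps_available"] = torch.backends.mps.is_available()
    # Simplification: an importable torch_directml counts as "DirectML present".
    flags["dml_available"] = importlib.util.find_spec("torch_directml") is not None
    return flags


print(pick_backend(**probe()))
```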

Install

pip install torch-amd-setup

torch is not a hard dependency — install the appropriate torch variant for your hardware first (see Tutorial).


Quick start

from torch_amd_setup import get_best_device, get_torch_device, get_dtype
import torch

device_type = get_best_device()
device      = get_torch_device(device_type)
dtype       = get_dtype(device_type)

print(f"Using: {device_type} → {device} @ {dtype}")

# Load your model (MyModel is a placeholder for your own torch.nn.Module subclass)
model = MyModel().to(device).to(dtype)

Diagnostics CLI

python -m torch_amd_setup

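Output (example from an RX 5700 XT under WSL2; cuda_available reads True because ROCm builds of PyTorch expose AMD GPUs through the torch.cuda API):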

── torch-amd-setup diagnostics ──────────────────────────────
  python_version            3.10.12
  platform                  Linux-6.6.x-WSL2-x86_64
  best_device               rocm
  cuda_available            True
  cuda_device_name          AMD Radeon RX 5700 XT
  cuda_vram_mb              8176
  rocm_available            True
  torch_version             2.6.0+rocm6.1
  ...

API Reference

get_best_device() → str

Returns the best available device type as a string: "cuda", "rocm", "dml", "mps", or "cpu".

get_torch_device(device_type=None) → torch.device

Returns a torch.device object (or a DirectML device object for "dml") ready for model.to(). If device_type is None, calls get_best_device() automatically.
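One detail is worth knowing when reading a "rocm" result. ROCm builds of PyTorch use the ordinary "cuda" device type, so there is no `torch.device("rocm")`. The following is a minimal sketch of how a backend label could map onto a device. `device_string` and `make_device` are hypothetical helpers. `torch_directml.device()` is the torch-directml package's own device factory. Imports are deferred so the mapping itself works without torch installed.

```python
def device_string(device_type: str) -> str:
    """Map a backend label to a PyTorch device string (hypothetical helper)."""
    # ROCm builds reuse the "cuda" device type, so "rocm" maps to "cuda".
    mapping = {"cuda": "cuda", "rocm": "cuda", "mps": "mps", "cpu": "cpu"}
    if device_type not in mapping:
        # "dml" has no device string; it needs torch_directml.device() instead.
        raise ValueError(f"no device string for {device_type!r}")
    return mapping[device_type]


def make_device(device_type: str):
    """Build the device object; torch and torch_directml are imported lazily."""
    if device_type == "dml":
        import torch_directml
        return torch_directml.device()
    import torch
    return torch.device(device_string(device_type))
```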

get_dtype(device_type=None) → torch.dtype

Returns torch.float16 for CUDA/ROCm/MPS and torch.float32 for DirectML/CPU. DirectML's float16 support is unreliable, so float32 is the safer default there.
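Written out, the dtype rule is a small lookup. The following sketch uses hypothetical names (`dtype_name`, `resolve_dtype`). It keeps the string decision separate from torch so the policy can be checked without a GPU. `getattr(torch, "float16")` is simply `torch.float16`.

```python
HALF_PRECISION_BACKENDS = frozenset({"cuda", "rocm", "mps"})


def dtype_name(device_type: str) -> str:
    """Return the dtype name for a backend label (hypothetical helper)."""
    # float16 only on backends listed above; DirectML and CPU stay on float32.
    return "float16" if device_type in HALF_PRECISION_BACKENDS else "float32"


def resolve_dtype(device_type: str):
    """Turn the name into a real torch.dtype; torch is imported lazily."""
    import torch
    return getattr(torch, dtype_name(device_type))
```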

device_info() → dict

Returns a diagnostic dictionary with all detected hardware info. Useful for logging and bug reports.
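When filing a bug report, a diagnostics dict can be pasted either in the CLI's aligned "key value" style or as JSON. The following is a hedged sketch. `format_report` and `to_json` are hypothetical helpers. The `sample` dict is a standard-library stand-in that covers only two of the keys shown in the CLI output; it is not the real device_info() result.

```python
import json
import platform


def format_report(info: dict) -> str:
    """Render a diagnostics dict as aligned "key value" lines."""
    width = max((len(key) for key in info), default=0)
    return "\n".join(f"  {key:<{width}}  {value}" for key, value in info.items())


def to_json(info: dict) -> str:
    """Serialize for a bug report; default=str covers values such as torch.dtype."""
    return json.dumps(info, indent=2, sort_keys=True, default=str)


# Stand-in for device_info(), built from the standard library only.
sample = {"python_version": platform.python_version(), "platform": platform.platform()}
print(format_report(sample))
print(to_json(sample))
```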

get_install_guide() → str

Returns platform-appropriate install instructions as a formatted string.

get_wsl2_install_guide() → str

Returns the full WSL2 + ROCm setup walkthrough for AMD GPUs on Windows.

AMD_ROCM_ENV: dict

The environment variable overrides applied for gfx1010 support. You can inspect or override these before calling get_best_device().


AMD GPU compatibility

GPU | Architecture | HSA override | Tested
--- | ------------ | ------------ | ------
RX 5700 XT | gfx1010 | 10.3.0 |
RX 5700 | gfx1010 | 10.3.0 |
RX 5600 XT | gfx1010 | 10.3.0 |
RX 5500 XT | gfx1012 | 10.3.0 | ⚠️ reported
RX 6000 series | RDNA2 (gfx1030+) | Not needed | ✅ native ROCm
RX 7000 series | RDNA3 (gfx1100+) | Not needed | ✅ native ROCm

If your card isn't listed, check GFX_OVERRIDE_MAP in detect.py and open a PR.
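At its core, the table is an architecture-to-override lookup. The following sketch assumes the real GFX_OVERRIDE_MAP in detect.py is keyed by gfx name; its actual structure is not documented here. `OVERRIDES` and `override_for` are hypothetical. The value 10.3.0 makes the ROCm runtime treat the RDNA1 die as gfx1030. The architecture has to be known before torch loads, so in practice it would come from outside PyTorch, for example from rocminfo output.

```python
import re
from typing import Optional

# Hypothetical lookup mirroring the compatibility table: RDNA1 dies
# (gfx1010 to gfx1012) are presented to ROCm as gfx1030 via "10.3.0".
OVERRIDES = {"gfx1010": "10.3.0", "gfx1011": "10.3.0", "gfx1012": "10.3.0"}


def override_for(arch: str) -> Optional[str]:
    """Return the HSA_OVERRIDE_GFX_VERSION value for an arch name, or None.

    Accepts bare names ("gfx1010") as well as names carrying target-feature
    suffixes ("gfx1010:xnack-"); cards not in the map need no override here.
    """
    match = re.match(r"gfx[0-9a-f]+", arch.strip().lower())
    if match is None:
        return None
    return OVERRIDES.get(match.group(0))
```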


Windows users: DirectML vs WSL2

Feature | DirectML | WSL2 + ROCm
------- | -------- | -----------
Setup difficulty | Easy | Medium
float16 support | ❌ (float32 only) | ✅
Python version limit | 3.11 max | Any
GPU memory usage | ~1.5× higher | Native
Best for | Quick experiments | Production workloads
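Which column applies depends on whether Python runs on native Windows or inside WSL2. The following is a minimal standard-library sketch for telling them apart. `classify_platform` is hypothetical. Checking for "microsoft" in the kernel release string is a common heuristic (WSL kernels include it), not an official API.

```python
import platform


def classify_platform(system: str, release: str) -> str:
    """Label the host OS from platform.system()/platform.release() values."""
    if system == "Windows":
        return "windows"   # native Windows: the DirectML column
    if system == "Darwin":
        return "macos"
    if system == "Linux" and "microsoft" in release.lower():
        return "wsl"       # WSL kernels carry "microsoft" in the release string
    if system == "Linux":
        return "linux"
    return "other"


print(classify_platform(platform.system(), platform.release()))
```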

Contributing

PRs welcome. Especially interested in:

  • Verified gfx override values for additional GPU models
  • ROCm 6.2+ compatibility reports
  • Windows DirectML on NVIDIA/Intel test results

Please open an issue before large PRs.


License

MIT — see LICENSE.


Background

This package was extracted from a private AI music pipeline project. The gfx1010 ROCm workaround was discovered the hard way — through several hours of cascading PyTorch installs, ROCm SDK conflicts, and dependency hell. The goal is that nobody else has to spend that time.

See docs/lessons-learned.md for the full story.



Download files


Source Distribution

torch_amd_setup-0.1.0.tar.gz (18.6 kB)

Uploaded Source

Built Distribution


torch_amd_setup-0.1.0-py3-none-any.whl (10.6 kB)

Uploaded Python 3

File details

Details for the file torch_amd_setup-0.1.0.tar.gz.

File metadata

  • Download URL: torch_amd_setup-0.1.0.tar.gz
  • Upload date:
  • Size: 18.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.3

File hashes

Hashes for torch_amd_setup-0.1.0.tar.gz
Algorithm Hash digest
SHA256 0607229fcc24780f770a3b4775508f99ae60c394bbbe4a2afbe235739f8c33c5
MD5 c4cd1751951a3c71a9903726ccfce2d6
BLAKE2b-256 bfc3b88ad1f152ae26760ede4e38a617d5c98c62d1b055fe7f5c4cac3e18648a


File details

Details for the file torch_amd_setup-0.1.0-py3-none-any.whl.

File metadata

File hashes

Hashes for torch_amd_setup-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 0d6dc4a7869b45d3b0df8a6a817c6f9f5e199e2b09753987c53cc6bc2fcb36d3
MD5 9a098f6be2d88b143a7632a9fe3b5047
BLAKE2b-256 57da67cd11692f690e3aefd7ba0cd1d9ae8e9b6dc9d81b93df47dd48242a87d2

