Auto-detects the best PyTorch compute device for AMD GPUs, with gfx1010 ROCm override support (RX 5700 XT, RX 5600 XT, Navi 10)
torch-amd-setup
Auto-detects the best PyTorch compute device for AMD GPUs — with first-class support for cards that are not in ROCm's default allow-list (RX 5700 XT, RX 5600 XT, RX 5500 XT / gfx1010–gfx1012).
One import. No manual env var hunting. Works on Windows, Linux, WSL2, and macOS.
```python
from torch_amd_setup import get_best_device, get_torch_device, get_dtype

device_type = get_best_device()  # "rocm" | "dml" | "cuda" | "mps" | "cpu"
device = get_torch_device()      # torch.device ready for model.to()
dtype = get_dtype()              # torch.float16 or torch.float32
```
The problem this solves
AMD GPUs that use the gfx1010 architecture (Navi 10 — RX 5700 XT, RX 5700, RX 5600 XT) are not in ROCm's default supported GPU list. PyTorch on ROCm will silently fall back to CPU unless you set:
```bash
export HSA_OVERRIDE_GFX_VERSION=10.3.0
```
...but it has to be set before Python imports torch, which means you either:
- Remember to set it in every shell session, or
- Bake it into a shell script wrapper, or
- Set it in your Python script (via os.environ) before the first `import torch`.
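A minimal sketch of the in-script option, assuming you only need the single variable shown above. The helper name `apply_gfx_override` is hypothetical and not part of torch-amd-setup. It uses `os.environ.setdefault` so a value you already exported in the shell wins, and it does nothing if torch has already been imported, since at that point the ROCm runtime may already have read the environment.

```python
import os
import sys


def apply_gfx_override(version: str = "10.3.0") -> bool:
    """Hypothetical helper: set HSA_OVERRIDE_GFX_VERSION before torch loads.

    Returns False (and changes nothing) if torch is already imported,
    because the ROCm runtime may already have read the environment.
    """
    if "torch" in sys.modules:
        return False
    # setdefault keeps any value the user already exported in the shell.
    os.environ.setdefault("HSA_OVERRIDE_GFX_VERSION", version)
    return True


apply_gfx_override()
# import torch  # only import torch after the override has been applied
```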
torch-amd-setup handles all of that automatically. It also detects DirectML on Windows (no ROCm required), Apple MPS on macOS, NVIDIA CUDA, and falls back to CPU — so you can ship one codebase that works everywhere.
Detection priority
| Priority | Backend | Platform | Requirement |
|---|---|---|---|
| 1 | NVIDIA CUDA | Any | Standard pip install torch |
| 2 | AMD ROCm | Linux / WSL2 | ROCm PyTorch + AMD driver ≥22.20 |
| 3 | AMD DirectML | Windows | pip install torch-directml, Py≤3.11 |
| 4 | Apple MPS | macOS Apple Silicon | Standard pip install torch |
| 5 | CPU | Any | Always available, always slow |
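The priority order can be sketched as a first-match walk over the table. This is an illustration under assumptions, not the package's actual detection code: `PRIORITY`, `probe_backends`, and `pick_backend` are hypothetical names, and treating an importable `torch_directml` module as "DirectML available" is a simplification. It relies on two real PyTorch behaviors: ROCm builds answer through the `torch.cuda` API, and `torch.version.hip` is a version string on ROCm builds and `None` otherwise.

```python
import importlib.util

# Highest priority first, mirroring the table above.
PRIORITY = ("cuda", "rocm", "dml", "mps", "cpu")


def probe_backends() -> set:
    """Return the backend names that look usable in this environment."""
    found = {"cpu"}  # always available
    if importlib.util.find_spec("torch") is not None:
        import torch

        if torch.cuda.is_available():
            # ROCm builds reuse torch.cuda; torch.version.hip tells them apart.
            found.add("rocm" if torch.version.hip else "cuda")
        mps = getattr(torch.backends, "mps", None)
        if mps is not None and mps.is_available():
            found.add("mps")
    # Simplification: treat an installed torch-directml package as usable.
    if importlib.util.find_spec("torch_directml") is not None:
        found.add("dml")
    return found


def pick_backend(available: set) -> str:
    """Return the highest-priority backend present in `available`."""
    for backend in PRIORITY:
        if backend in available:
            return backend
    return "cpu"


print(pick_backend(probe_backends()))
```

Keeping the probe and the choice separate makes the priority rule easy to test with hand-built sets, independent of the hardware in the machine.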
Install
```bash
pip install torch-amd-setup
```
`torch` is not a hard dependency: install the appropriate torch variant for your hardware first (see Tutorial).
Quick start
```python
from torch_amd_setup import get_best_device, get_torch_device, get_dtype
import torch

device_type = get_best_device()
device = get_torch_device(device_type)
dtype = get_dtype(device_type)
print(f"Using: {device_type} → {device} @ {dtype}")

# Load your model (MyModel stands in for your own torch.nn.Module subclass)
model = MyModel().to(device).to(dtype)
```
Diagnostics CLI
```bash
python -m torch_amd_setup
```
Output:
```
── torch-amd-setup diagnostics ──────────────────────────────
python_version    3.10.12
platform          Linux-6.6.x-WSL2-x86_64
best_device       rocm
cuda_available    True
cuda_device_name  AMD Radeon RX 5700 XT
cuda_vram_mb      8176
rocm_available    True
torch_version     2.6.0+rocm6.1
...
```
API Reference
get_best_device() → str
Returns the best available device type as a string: "cuda", "rocm", "dml", "mps", or "cpu".
get_torch_device(device_type=None) → torch.device
Returns a torch.device object (or a DirectML device object for "dml") ready for model.to(). If device_type is None, calls get_best_device() automatically.
get_dtype(device_type=None) → torch.dtype
Returns torch.float16 for CUDA/ROCm/MPS and torch.float32 for DirectML/CPU. DirectML float16 support is unreliable, so float32 is the safer default there.
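For illustration, here is a pure-Python mirror of the mappings described for get_torch_device() and get_dtype(). The names `device_string` and `dtype_name` are hypothetical and return plain strings so the sketch runs without torch installed; with torch available, `torch.device(device_string(t))` and `getattr(torch, dtype_name(t))` give the real objects. The one non-obvious entry is ROCm: ROCm builds of PyTorch expose AMD GPUs under the `"cuda"` device type.

```python
# Hypothetical mirror of the documented get_torch_device()/get_dtype() rules.
# Returns plain strings so it runs without torch installed.

HALF_PRECISION = {"cuda", "rocm", "mps"}  # documented float16 backends
FULL_PRECISION = {"dml", "cpu"}           # documented float32 backends


def device_string(device_type: str) -> str:
    """Map a backend name to the string you would pass to torch.device()."""
    if device_type == "rocm":
        # ROCm builds of PyTorch expose AMD GPUs through the "cuda" device type.
        return "cuda"
    if device_type in {"cuda", "mps", "cpu"}:
        return device_type
    if device_type == "dml":
        # DirectML devices come from torch_directml.device(), not a plain string.
        raise ValueError("use torch_directml.device() for the 'dml' backend")
    raise ValueError(f"unknown device type: {device_type!r}")


def dtype_name(device_type: str) -> str:
    """Name of the torch dtype used for a backend, e.g. getattr(torch, name)."""
    if device_type in HALF_PRECISION:
        return "float16"
    if device_type in FULL_PRECISION:
        return "float32"
    raise ValueError(f"unknown device type: {device_type!r}")
```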
device_info() → dict
Returns a diagnostic dictionary with all detected hardware info. Useful for logging and bug reports.
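One way to attach these diagnostics to logs or bug reports is to serialize the dictionary as a single JSON line. This sketch assumes only that device_info() returns a plain dict, as described above; `log_device_info` is a hypothetical helper, and `default=str` is a defensive choice in case some values (for example a torch.dtype) are not JSON-serializable. The sample values are taken from the diagnostics output shown earlier.

```python
import json
import logging


def log_device_info(info: dict, logger: logging.Logger) -> str:
    """Log a device_info()-style dict as one sorted JSON line and return it."""
    payload = json.dumps(info, sort_keys=True, default=str)
    logger.info("torch device info: %s", payload)
    return payload


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    # In real use: from torch_amd_setup import device_info; info = device_info()
    sample = {"best_device": "rocm", "cuda_vram_mb": 8176, "rocm_available": True}
    log_device_info(sample, logging.getLogger("torch-amd-setup"))
```

Sorting the keys keeps the line stable across runs, which makes reports from different machines easier to compare.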
get_install_guide() → str
Returns platform-appropriate install instructions as a formatted string.
get_wsl2_install_guide() → str
Returns the full WSL2 + ROCm setup walkthrough for AMD GPUs on Windows.
AMD_ROCM_ENV: dict
The environment variable overrides applied for gfx1010 support. You can inspect or override these before calling get_best_device().
AMD GPU compatibility
| GPU | Architecture | HSA Override | Tested |
|---|---|---|---|
| RX 5700 XT | gfx1010 | 10.3.0 | ✅ |
| RX 5700 | gfx1010 | 10.3.0 | ✅ |
| RX 5600 XT | gfx1010 | 10.3.0 | ✅ |
| RX 5500 XT | gfx1012 | 10.3.0 | ⚠️ reported |
| RX 6000 series | gfx1030+ (RDNA2) | Not needed | ✅ native ROCm |
| RX 7000 series | gfx1100+ (RDNA3) | Not needed | ✅ native ROCm |
If your card isn't listed, check GFX_OVERRIDE_MAP in detect.py and open a PR.
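The table's rule of thumb can be expressed as a small lookup. This is a hypothetical sketch, not the contents of GFX_OVERRIDE_MAP (whose structure isn't shown here): it applies the 10.3.0 override to the RDNA1 family gfx1010 to gfx1012 and returns None otherwise, meaning "no known override", and it strips feature suffixes such as `:xnack-` that ROCm tools append to architecture names.

```python
import re
from typing import Optional

# Hypothetical: RDNA1 (gfx1010-gfx1012) runs gfx1030 kernels via the override.
_RDNA1 = re.compile(r"gfx101[0-2]")


def hsa_override_for(gfx_arch: str) -> Optional[str]:
    """Return an HSA_OVERRIDE_GFX_VERSION value, or None if none is known.

    Accepts names like "gfx1010" or "gfx1010:xnack-"; anything after the
    first colon is a feature suffix and is ignored.
    """
    base = gfx_arch.strip().lower().split(":", 1)[0]
    if _RDNA1.fullmatch(base):
        return "10.3.0"
    return None
```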
Windows users: DirectML vs WSL2
| Feature | DirectML | WSL2 + ROCm |
|---|---|---|
| Setup difficulty | Easy | Medium |
| float16 support | ❌ (float32 only) | ✅ |
| Python version limit | 3.11 max | Any |
| GPU memory usage | ~1.5× higher | Native |
| Best for | Quick experiments | Production workloads |
Contributing
PRs welcome. Especially interested in:
- Verified gfx override values for additional GPU models
- ROCm 6.2+ compatibility reports
- Windows DirectML on NVIDIA/Intel test results
Please open an issue before large PRs.
License
MIT — see LICENSE.
Background
This package was extracted from a private AI music pipeline project. The gfx1010 ROCm workaround was discovered the hard way — through several hours of cascading PyTorch installs, ROCm SDK conflicts, and dependency hell. The goal is that nobody else has to spend that time.
See docs/lessons-learned.md for the full story.
Download files
File details
Details for the file torch_amd_setup-0.1.0.tar.gz.
File metadata
- Download URL: torch_amd_setup-0.1.0.tar.gz
- Upload date:
- Size: 18.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 0607229fcc24780f770a3b4775508f99ae60c394bbbe4a2afbe235739f8c33c5 |
| MD5 | c4cd1751951a3c71a9903726ccfce2d6 |
| BLAKE2b-256 | bfc3b88ad1f152ae26760ede4e38a617d5c98c62d1b055fe7f5c4cac3e18648a |
File details
Details for the file torch_amd_setup-0.1.0-py3-none-any.whl.
File metadata
- Download URL: torch_amd_setup-0.1.0-py3-none-any.whl
- Upload date:
- Size: 10.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 0d6dc4a7869b45d3b0df8a6a817c6f9f5e199e2b09753987c53cc6bc2fcb36d3 |
| MD5 | 9a098f6be2d88b143a7632a9fe3b5047 |
| BLAKE2b-256 | 57da67cd11692f690e3aefd7ba0cd1d9ae8e9b6dc9d81b93df47dd48242a87d2 |