
Universal GPU setup and diagnostics for PyTorch and JAX — DirectML, ROCm, CUDA, MPS, CPU. Auto-detects your hardware and sets HSA_OVERRIDE_GFX_VERSION for unsupported AMD GPUs.


gpu-doctor

Universal GPU setup and diagnostics for PyTorch and JAX.

One tool. Every backend. Every OS. Zero required dependencies.

pip install gpu-doctor
python -m gpu_doctor --check      # see exactly what's installed and detected
python -m gpu_doctor --install    # install the right torch for your machine

Why this exists

Setting up GPU acceleration for PyTorch and JAX is a maze of platform-specific steps:

  • Windows AMD: DirectML requires Python ≤ 3.11, and packages must be installed in a specific order
  • Linux AMD: ROCm requires HSA_OVERRIDE_GFX_VERSION=10.3.0 for older cards (RX 5700 XT, RX 5700, RX 5600 XT) that aren't in ROCm's default allow-list
  • NVIDIA: standard CUDA, but you still need the wheel index URL that matches your CUDA version
  • macOS: MPS just works, but only on Apple Silicon

gpu-doctor handles all of it. It detects your hardware, applies workarounds automatically, and gives you back one device object that works everywhere.

Note for PyTorch users: PyTorch's own get_best_device() (Issue #149719) is still an open proposal. gpu-doctor ships it today, plus AMD-specific fixes PyTorch doesn't cover.


Platforms

  Platform          Backend    Python      Notes
  Windows           DirectML   3.11 only   AMD / Intel / NVIDIA via torch-directml
  Linux / WSL2      ROCm       Any         AMD GPU. gfx1010 override auto-applied.
  Linux / Windows   CUDA       Any         NVIDIA GPU. Standard torch install.
  macOS             MPS        Any         Apple Silicon (M1/M2/M3).
  Any               CPU        3.8+        Always works. No GPU needed.
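The platform table boils down to a lookup on OS, CPU architecture, GPU vendor and Python version. Below is a minimal sketch of that decision, assuming the GPU vendor has already been detected by some other means. The function name pick_backend and its parameters are hypothetical, not gpu-doctor's API, and preferring CUDA over DirectML for NVIDIA cards on Windows is an assumption read off the table.

```python
import sys
from typing import Optional, Tuple


def pick_backend(system: str, machine: str, gpu_vendor: Optional[str],
                 python_version: Tuple[int, int] = sys.version_info[:2]) -> str:
    """Hypothetical backend choice mirroring the Platforms table.

    system:     platform.system() value, e.g. 'Windows', 'Linux', 'Darwin'
    machine:    platform.machine() value, e.g. 'AMD64', 'x86_64', 'arm64'
    gpu_vendor: 'nvidia', 'amd', 'intel' or None (detection is out of scope here)
    """
    if system == "Darwin":
        # MPS is only available on Apple Silicon.
        return "mps" if machine == "arm64" else "cpu"
    if gpu_vendor == "nvidia" and system in ("Linux", "Windows"):
        # Assumption: NVIDIA cards use CUDA even on Windows.
        return "cuda"
    if system == "Linux" and gpu_vendor == "amd":
        # WSL2 also reports 'Linux' here.
        return "rocm"
    if system == "Windows" and gpu_vendor in ("amd", "intel"):
        # torch-directml has no wheels for Python 3.12+.
        return "directml" if python_version <= (3, 11) else "cpu"
    return "cpu"
```

Usage: pick_backend(platform.system(), platform.machine(), gpu_vendor="amd") returns 'rocm' on Linux and 'directml' on Windows with Python 3.11.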

Quick Start

Step 1 — Check what you have

python -m gpu_doctor --check

Output:

==================================================================
  gpu-doctor — Environment Report
==================================================================
  gpu-doctor : v1.0.0
  Python     : 3.11.9
  OS         : Windows  AMD64

  torch      : 2.4.1+cpu
  DirectML   : 0.2.5  (AMD Radeon RX 5700 XT)
  JAX        : not installed

  Best device: directml  ← use this
==================================================================

Step 2 — Install the right torch

python -m gpu_doctor --install

This detects your hardware and runs the correct pip install command automatically.
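A rough idea of what "the correct pip install command" means per backend: the sketch below only builds the argument list and does not run it. The function name is hypothetical, and the CUDA/ROCm index tags (cu121, rocm6.1) are example assumptions; check pytorch.org for the tags matching your driver. This is not necessarily what --install runs.

```python
import sys
from typing import List

# Example PyTorch wheel indexes. The cu121 / rocm6.1 suffixes are assumptions;
# pick the tags that match your driver and torch version.
TORCH_INDEX = {
    "cuda": "https://download.pytorch.org/whl/cu121",
    "rocm": "https://download.pytorch.org/whl/rocm6.1",
    "cpu": "https://download.pytorch.org/whl/cpu",
}


def install_command(backend: str) -> List[str]:
    """Build (but do not run) a pip command for a backend name (hypothetical helper)."""
    pip = [sys.executable, "-m", "pip", "install"]
    if backend == "directml":
        # Let torch-directml pull in its own pinned torch; don't pre-install torch.
        return pip + ["torch-directml"]
    if backend == "mps":
        # The regular macOS wheels on PyPI already include MPS support.
        return pip + ["torch"]
    if backend in TORCH_INDEX:
        return pip + ["torch", "--index-url", TORCH_INDEX[backend]]
    raise ValueError(f"unknown backend: {backend!r}")
```

To execute it, pass the list to subprocess.run(install_command("cpu"), check=True). Using sys.executable -m pip guarantees the install lands in the interpreter you are running, which matters with several venvs side by side.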

Step 3 — Use in your code

from gpu_doctor import get_best_device, get_torch_device, get_dtype

# Call BEFORE importing torch: this sets env vars (HSA_OVERRIDE_GFX_VERSION, etc.) first
device_type = get_best_device()   # 'directml' | 'rocm' | 'cuda' | 'mps' | 'cpu'
device      = get_torch_device()  # torch.device or DML device, ready for .to()
dtype       = get_dtype()         # torch.float16 or torch.float32 (safe for your backend)

model = MyModel().to(device).to(dtype)   # MyModel: your own torch.nn.Module
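Note that device_type is a backend name, not always a PyTorch device string. The stdlib-only sketch below shows how the two relate. It relies on two facts: ROCm builds of PyTorch reuse the "cuda" device type (which is why torch.cuda.is_available() is the check on AMD), and torch-directml shows up as "privateuseone:0". The mapping and function name are illustrative, not gpu-doctor's code.

```python
# Hypothetical mapping: backend name (as returned by get_best_device())
# -> device string PyTorch reports for that backend.
# ROCm builds of PyTorch expose AMD GPUs through the "cuda" device type;
# torch-directml registers its device as "privateuseone".
DEVICE_STRING = {
    "cuda": "cuda:0",
    "rocm": "cuda:0",
    "directml": "privateuseone:0",
    "mps": "mps",
    "cpu": "cpu",
}


def torch_device_string(device_type: str) -> str:
    """Return the device string for a backend name, defaulting to 'cpu'."""
    return DEVICE_STRING.get(device_type, "cpu")
```

For DirectML, still pass the object from get_torch_device() to .to() rather than building a device from this string.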

Manual Install (Platform-Specific)

Windows — DirectML (AMD / Intel / NVIDIA GPU)

Requires Python 3.11. torch-directml is compiled against the 3.11 ABI. It will not work on 3.12+.

:: Create a 3.11 venv
py -3.11 -m venv .venv311
.venv311\Scripts\activate
pip install gpu-doctor

:: Install torch (let DirectML pull torch 2.4.1 — do NOT pre-install torch)
python -m gpu_doctor --install

:: Verify
python -m gpu_doctor --check

Linux / WSL2 — AMD GPU (ROCm)

pip install gpu-doctor
python -m gpu_doctor --install    # auto-detects ROCm, sets HSA_OVERRIDE if needed

# For RX 5700 XT (gfx1010) — if not auto-set:
export HSA_OVERRIDE_GFX_VERSION=10.3.0   # add to ~/.bashrc

python -m gpu_doctor --check

Or use the Python API to set env vars before importing torch:

from gpu_doctor import get_best_device     # this sets HSA_OVERRIDE_GFX_VERSION
import torch                               # now sees gfx1010 correctly
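The rule generalizes: environment variables that a library reads at startup must be set before that library is imported. The helper below is hypothetical (not part of gpu-doctor). It enforces that order and never overrides a value you exported yourself.

```python
import os
import sys
from typing import Iterable, Mapping


def set_env_before_import(env: Mapping[str, str],
                          modules: Iterable[str] = ("torch",)) -> None:
    """Set env vars, but refuse if a module that reads them is already imported.

    os.environ.setdefault never overwrites, so values you exported win.
    """
    already = [m for m in modules if m in sys.modules]
    if already:
        raise RuntimeError(
            f"{', '.join(already)} already imported; set env vars before importing it"
        )
    for key, value in env.items():
        os.environ.setdefault(key, value)
```

Usage: call set_env_before_import({"HSA_OVERRIDE_GFX_VERSION": "10.3.0"}) at the very top of your entry point, then import torch. Raising instead of silently continuing turns the "CPU fallback with no error" failure mode into a visible one.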

Linux / WSL2 — NVIDIA GPU (CUDA)

pip install gpu-doctor
python -m gpu_doctor --install
python -m gpu_doctor --check

macOS — Apple Silicon (MPS)

pip install gpu-doctor
python -m gpu_doctor --install
python -m gpu_doctor --check

JAX Support

from gpu_doctor import configure_jax_amd, get_jax_backend

# MUST call before import jax
configure_jax_amd()    # sets XLA_FLAGS, MIOPEN_USER_DB_PATH, HSA_OVERRIDE, etc.

import jax
print(get_jax_backend())   # 'gpu' or 'cpu'

Install JAX:

python -m gpu_doctor --install-jax   # auto-detects ROCm / CUDA / CPU
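One simple way such auto-detection can work is to look for vendor CLIs on PATH: nvidia-smi ships with the NVIDIA driver and rocminfo ships with ROCm. The sketch below assumes that this is a good-enough signal. It is a heuristic and an assumption, not necessarily what --install-jax does, and the function name is hypothetical.

```python
import shutil
from typing import Callable, Optional


def detect_jax_target(which: Callable[[str], Optional[str]] = shutil.which) -> str:
    """Guess which JAX build to install from the vendor tools on PATH.

    Heuristic only. `which` is injectable so the logic can be tested
    without real GPUs.
    """
    if which("nvidia-smi"):
        return "cuda"
    if which("rocminfo"):
        return "rocm"
    return "cpu"
```

Checking NVIDIA first is arbitrary. On a machine with both toolchains installed, you would want a smarter tie-break.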

AMD RX 5700 XT (gfx1010) — Special Notes

The RX 5700 XT uses the gfx1010 architecture (Navi 10), which is not in ROCm's default hardware allow-list. Without an override, PyTorch silently falls back to CPU: torch.cuda.is_available() returns False and no error is raised.

gpu-doctor detects this automatically and sets HSA_OVERRIDE_GFX_VERSION=10.3.0 before torch imports. You don't have to know this exists.

from gpu_doctor import get_best_device   # sets HSA_OVERRIDE_GFX_VERSION=10.3.0
import torch
print(torch.cuda.is_available())   # True — RX 5700 XT now visible

Affected GPUs (auto-handled):

  GPU                      Architecture        Override
  RX 5700 XT, RX 5700      gfx1010             10.3.0
  RX 5600 XT, RX 5500 XT   gfx1010/1011/1012   10.3.0
  Radeon VII               gfx906              9.0.6
  RX Vega 56/64            gfx900              9.0.0
  RX 6000+ (gfx1030+)      RDNA2/3             None needed
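The override table can be expressed as a lookup plus a parser for rocminfo output, where each GPU agent lists its target as, for example, "Name: gfx1010". The sketch below works under that assumption. The helper names are hypothetical, and gpu-doctor's own detection may work differently.

```python
import re
from typing import Dict, Optional

# Override values taken from the "Affected GPUs" table; gfx1030+ needs none.
HSA_OVERRIDES: Dict[str, str] = {
    "gfx1010": "10.3.0",
    "gfx1011": "10.3.0",
    "gfx1012": "10.3.0",
    "gfx906": "9.0.6",
    "gfx900": "9.0.0",
}


def parse_gfx_arch(rocminfo_output: str) -> Optional[str]:
    """Return the first gfx target name (e.g. 'gfx1010') found in rocminfo output."""
    match = re.search(r"\bgfx[0-9a-f]+\b", rocminfo_output)
    return match.group(0) if match else None


def hsa_override_for(gfx_arch: Optional[str]) -> Optional[str]:
    """HSA_OVERRIDE_GFX_VERSION value for an architecture, or None if not needed."""
    return HSA_OVERRIDES.get(gfx_arch) if gfx_arch else None
```

Usage: feed it subprocess.run(["rocminfo"], capture_output=True, text=True).stdout, and if hsa_override_for(...) returns a value, set HSA_OVERRIDE_GFX_VERSION before importing torch.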

CLI Reference

python -m gpu_doctor [options]

  (no args)       Quick summary — best device and torch version
  --check         Full environment report (Python, torch, JAX, GPU tools)
  --install       Install correct torch for this machine
  --install-jax   Install correct JAX for this machine
  --json          Machine-readable JSON output (for scripts/CI)

# Use JSON output in scripts
DEVICE=$(python -m gpu_doctor --json | python -c "import json,sys; print(json.load(sys.stdin)['best_device'])")
echo "Training on: $DEVICE"
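The same lookup from Python, for example in a CI script. Only the best_device key comes from the shell example; the rest of the JSON schema isn't documented here. For that reason the sketch falls back to 'cpu' on anything unexpected. Both helper names are hypothetical.

```python
import json
import subprocess
import sys


def best_device_from_json(report_text: str, default: str = "cpu") -> str:
    """Extract 'best_device' from `gpu_doctor --json` output, else return default."""
    try:
        report = json.loads(report_text)
    except json.JSONDecodeError:
        return default
    if not isinstance(report, dict):
        return default
    return report.get("best_device", default)


def query_best_device() -> str:
    """Run the CLI in a subprocess (requires gpu-doctor to be installed)."""
    result = subprocess.run(
        [sys.executable, "-m", "gpu_doctor", "--json"],
        capture_output=True, text=True, check=True,
    )
    return best_device_from_json(result.stdout)
```

Running the CLI in a subprocess keeps any torch import out of the calling process, so that process can still set environment variables before importing torch itself.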

Comparison

Features compared across gpu-doctor, devicetorch, torchruntime, GPUtil and pyamdgpuinfo:

  • DirectML (Windows AMD)
  • ROCm (Linux AMD)
  • CUDA (NVIDIA)
  • MPS (Apple Silicon)
  • CPU fallback
  • HSA_OVERRIDE auto-set
  • JAX support
  • --check diagnostic mode
  • JSON output
  • Zero required deps
  • PyPI installable

Troubleshooting

torch.cuda.is_available() returns False on AMD (Linux)

Your GPU likely needs an HSA override. Check:

python -m gpu_doctor --check   # shows 'rocm_gfx_arch' and 'hsa_override_applied'

Then in Python, call get_best_device() BEFORE importing torch.

DirectML not detected on Windows

  • Confirm Python version: python --version — must be 3.11 or lower.
  • Install: pip install torch-directml without pre-installing torch.

privateuseone:0 device string

This is normal: privateuseone:0 is PyTorch's internal name for the DirectML device. Use get_torch_device(), which returns the correct device object rather than the string; this is critical for diffusers .to() calls.

JAX sees CPU even with ROCm installed

Call configure_jax_amd() before import jax. The env vars must be set before JAX initializes its backend.

More: torch-amd-setup troubleshooting





License

MIT


Download files

Source Distribution

gpu_doctor-1.0.0.tar.gz (17.4 kB)

Built Distribution

gpu_doctor-1.0.0-py3-none-any.whl (13.0 kB)

File details

Details for the file gpu_doctor-1.0.0.tar.gz.

File metadata

  • Download URL: gpu_doctor-1.0.0.tar.gz
  • Size: 17.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.15

File hashes

Hashes for gpu_doctor-1.0.0.tar.gz
Algorithm Hash digest
SHA256 a88f0248d5377e9bb00b12c6862b6d1768a616b43dd887a456fe747f165116e8
MD5 31f599f11a6dd301266ae76c8d189508
BLAKE2b-256 99ea95d816a478b522ccc56d78bfe1136a72c2fe708dfd361aff1bf3bce131a6

File details

Details for the file gpu_doctor-1.0.0-py3-none-any.whl.

File metadata

  • Download URL: gpu_doctor-1.0.0-py3-none-any.whl
  • Size: 13.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.15

File hashes

Hashes for gpu_doctor-1.0.0-py3-none-any.whl
Algorithm Hash digest
SHA256 748864f411beefe9683150f4dac84cdbf8d722c9051bb31230f3e0417c4cd55c
MD5 60fd7e21ad2602daf0a45b2b8108ed1b
BLAKE2b-256 60897bdb665e141d4240239297d3401931f21a4b3d6d1833ea061be6bc1c68a7
