Created
September 12, 2025 21:20
-
-
Save kundeng/7ae987bc1a6dfdf75175f9c0f0af9711 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Turn every cmdlet error into a terminating error so the script halts at the
# first failing cmdlet. By default this does NOT apply to native executables
# (curl.exe, winget, uv, python): their exit codes have to be checked by hand.
$ErrorActionPreference = [System.Management.Automation.ActionPreference]::Stop

Write-Host '=== Installing ROCm PyTorch Environment on Windows (gfx1151) ==='
# ---------------------------
# Helper: robust download
# ---------------------------
<#
.SYNOPSIS
    Download a URL to a file with curl.exe, retrying on failure.
.DESCRIPTION
    Each attempt writes to "<Dest>.part". The file is renamed to Dest only when
    curl exits with code 0 and the file is at least MinBytes long. A failed or
    interrupted attempt therefore never leaves a truncated file at Dest, which
    callers treat as "already downloaded". Throws after MaxRetries failed
    attempts.
.PARAMETER Url
    Source URL. Redirects are followed.
.PARAMETER Dest
    Destination file path. It is overwritten on success.
.PARAMETER MaxRetries
    Total number of attempts (default 3).
.PARAMETER DelaySec
    Seconds to wait between attempts (default 5).
.PARAMETER MinBytes
    Smallest acceptable file size in bytes (default 1000000). This guards
    against an error or HTML page being saved in place of the real binary.
#>
function Download-WithRetries {
    param(
        [Parameter(Mandatory=$true)][string]$Url,
        [Parameter(Mandatory=$true)][string]$Dest,
        [int]$MaxRetries = 3,
        [int]$DelaySec = 5,
        [long]$MinBytes = 1000000
    )
    $partial = "$Dest.part"
    $reason = 'no download attempt was made'
    for ($i = 1; $i -le $MaxRetries; $i++) {
        try {
            if (Test-Path -LiteralPath $partial) { Remove-Item -LiteralPath $partial -Force }
            # curl.exe is a native command. A failure only sets $LASTEXITCODE and
            # does not throw, so the exit code is checked explicitly. --fail
            # turns HTTP 4xx/5xx into a non-zero exit instead of saving the
            # error body.
            curl.exe --fail --location --output $partial $Url
            if ($LASTEXITCODE -ne 0) {
                $reason = "curl exited with code $LASTEXITCODE"
            } else {
                $size = (Get-Item -LiteralPath $partial).Length
                if ($size -ge $MinBytes) {
                    Move-Item -LiteralPath $partial -Destination $Dest -Force
                    return
                }
                $reason = "downloaded file is only $size bytes (expected at least $MinBytes)"
            }
        } catch {
            $reason = $_.Exception.Message
        }
        # Never leave a partial file behind after a failed attempt.
        if (Test-Path -LiteralPath $partial) {
            Remove-Item -LiteralPath $partial -Force -ErrorAction SilentlyContinue
        }
        if ($i -lt $MaxRetries) {
            Write-Warning "Download failed ($i/$MaxRetries): $reason. Retrying in $DelaySec s ..."
            Start-Sleep -Seconds $DelaySec
        }
    }
    throw "Failed to download $Url after $MaxRetries attempts. Last error: $reason"
}
# ---------------------------
# Ensure uv is installed
# ---------------------------
# Prefer a uv that is already reachable on PATH, however it was installed.
# winget is consulted only when none is found. Select-Object -First 1 keeps
# $uvExe a single path even when several uv.exe files match.
$uvExe = Get-Command uv -CommandType Application -ErrorAction SilentlyContinue |
    Select-Object -First 1 -ExpandProperty Source
if ($uvExe) {
    Write-Host "uv already installed."
} else {
    if (-not (Get-Command winget -ErrorAction SilentlyContinue)) {
        throw "uv not found and winget is unavailable. Install uv manually (https://docs.astral.sh/uv/) and re-run."
    }
    $uvPkg = winget list --id "astral-sh.uv" | Out-String
    # -notmatch takes a regex; escape it so "." matches only a literal dot.
    if ($uvPkg -notmatch [regex]::Escape("astral-sh.uv")) {
        Write-Host "Installing uv..."
        winget install -e --id astral-sh.uv --source winget --accept-package-agreements --accept-source-agreements
        if ($LASTEXITCODE -ne 0) {
            Write-Warning "winget install exited with code $LASTEXITCODE; still trying to locate uv."
        }
    } else {
        Write-Host "uv already installed."
    }

    # PATH changes made by a fresh install are not visible to this already
    # running session, so probe the usual install locations directly.
    $possible = @(
        "$env:LocalAppData\Microsoft\WinGet\Links\uv.exe",
        "$env:LocalAppData\Programs\uv\bin\uv.exe",
        "$env:LocalAppData\Programs\uv\uv.exe",
        "$env:LocalAppData\Microsoft\WinGet\Packages\astral-sh.uv*\uv.exe",
        "$HOME\.local\bin\uv.exe"
    )
    foreach ($p in $possible) {
        # The WinGet\Packages entry is a wildcard and may match several
        # package versions; take only the first hit.
        $found = Get-Item -Path $p -ErrorAction SilentlyContinue | Select-Object -First 1
        if ($found) { $uvExe = $found.FullName; break }
    }
}
if (-not $uvExe) { throw "uv not found. Try restarting PowerShell or reinstall: winget install astral-sh.uv" }
Write-Host "Using uv at: $uvExe"
# Run the uv executable stored in $uvExe with the given arguments, and throw
# if uv exits with a non-zero code. PowerShell does not treat a native
# command's exit code as an error by default, so a failed
# `uv pip install` would otherwise go unnoticed.
#
# This is deliberately a simple function (no param block). Every argument,
# including ones that look like PowerShell parameters (-e, -v, ...), is then
# forwarded verbatim through the automatic $args. An advanced function would
# bind such tokens to common parameters like -ErrorAction or -Verbose.
function Run-UV {
    $uvArgs = $args
    & $uvExe @uvArgs
    if ($LASTEXITCODE -ne 0) {
        # Flatten one level so array arguments (e.g. a list of wheel paths)
        # show up as individual items in the message.
        $shown = ($uvArgs | ForEach-Object { $_ }) -join ' '
        throw "uv $shown failed with exit code $LASTEXITCODE"
    }
}
# ---------------------------
# Project folder & Python
# ---------------------------
# All artifacts (venv, downloaded wheels/installers, test script) live here.
# The working directory is switched to it and left there.
$projDir = "$HOME\rocm-pytorch"
if (!(Test-Path -Path $projDir)) {
    $null = New-Item -ItemType Directory -Force -Path $projDir
}
Set-Location -Path $projDir

# Make sure uv has a 3.12 interpreter before creating the environment.
Write-Host "Ensuring uv has Python 3.12..."
Run-UV python install 3.12

# Reuse an existing .venv; create it only on the first run.
$venvExists = Test-Path -Path ".venv"
if ($venvExists) {
    Write-Host "Virtual environment already exists."
} else {
    Write-Host "Creating uv-managed Python 3.12 virtual environment..."
    Run-UV venv --python=3.12
}

# Activate the venv for the rest of this session. The GPU test further down
# invokes a bare `python`.
& .\.venv\Scripts\Activate.ps1
# ---------------------------
# Resolve wheels from GitHub
# ---------------------------
# NOTE(review): the banner says gfx1151 but this release tag is named
# "...-gfx110x"; confirm the release ships wheels for the intended GPU.
$tag = "v6.5.0rc-pytorch-gfx110x"
$api = "https://api.github.com/repos/scottt/rocm-TheRock/releases/tags/$tag"
$headers = @{ "User-Agent" = "pwsh-rocm-setup"; "Accept" = "application/vnd.github+json" }
# Unauthenticated GitHub API requests are rate limited; use a token when one
# is provided in the environment.
if ($env:GITHUB_TOKEN) { $headers["Authorization"] = "Bearer $env:GITHUB_TOKEN" }
try {
    $release = Invoke-RestMethod -Uri $api -Headers $headers -ErrorAction Stop
} catch {
    throw "Failed to query GitHub release API for tag $tag. $_"
}
$assets = $release.assets
if (-not $assets) { throw "No assets found for tag $tag" }

# For each package, pick the Windows x64 CPython 3.12 wheel that sorts last
# by name. The "^<pkg>-" anchor keeps "torch" from also matching the
# torchvision/torchaudio wheels. $wheelList keeps the order torch,
# torchvision, torchaudio.
$pyTag = "cp312"
$wheelList = @()
$missing = @()
foreach ($pkg in @("torch", "torchvision", "torchaudio")) {
    $pattern = "^$([regex]::Escape($pkg))-.*$pyTag-.*win_amd64\.whl$"
    $pick = $assets |
        Where-Object { $_.name -match $pattern } |
        Sort-Object name -Descending |
        Select-Object -First 1
    if ($pick) { $wheelList += $pick } else { $missing += $pkg }
}
if ($missing) { throw "Could not resolve all wheels for $pyTag. Missing: $($missing -join ', ')" }
Write-Host "Resolved wheels:"
$wheelList | ForEach-Object { Write-Host " - $($_.name)" }
# ---------------------------
# Download wheels & install
# ---------------------------
# A previously downloaded wheel is reused only when its length matches the
# size GitHub reports for the asset. A truncated file from an interrupted
# earlier run is fetched again rather than handed to uv.
$downloaded = @()
foreach ($w in $wheelList) {
    $dest = Join-Path $projDir $w.name
    $expected = $w.size
    $existing = Get-Item -LiteralPath $dest -ErrorAction SilentlyContinue
    if ($existing -and (($null -eq $expected) -or ($existing.Length -eq $expected))) {
        Write-Host "$($w.name) already downloaded."
    } else {
        Write-Host "Downloading $($w.name)..."
        Download-WithRetries -Url $w.browser_download_url -Dest $dest
        $actual = (Get-Item -LiteralPath $dest).Length
        if (($null -ne $expected) -and ($actual -ne $expected)) {
            throw "Size mismatch for $($w.name): expected $expected bytes, got $actual."
        }
    }
    $downloaded += $dest
}
Write-Host "Installing wheels with uv pip..."
Run-UV pip install $downloaded
# ---------------------------
# Ensure NumPy <2 for ABI
# ---------------------------
Write-Host "Enforcing NumPy <2 to match ROCm wheels..."
Run-UV pip install "numpy<2"
# ---------------------------
# HIP runtime / driver install
# ---------------------------
# hipinfo.exe found on PATH is taken to mean HIP is already installed. As a
# fallback, $env:HIP_PATH\bin is also checked.
# NOTE(review): assumes the HIP SDK sets HIP_PATH and may not put its bin
# folder on PATH — confirm.
$hipCheck = Get-Command hipinfo.exe -ErrorAction SilentlyContinue
if (-not $hipCheck -and $env:HIP_PATH) {
    $hipCheck = Get-Item -LiteralPath (Join-Path $env:HIP_PATH 'bin\hipinfo.exe') -ErrorAction SilentlyContinue
}
if (-not $hipCheck) {
    $hipInstaller = "AMD-Software-PRO-Edition-25.Q3-Win10-Win11-For-HIP.exe"
    # NOTE(review): this is under AMD's /eula/ path; confirm it downloads
    # without an interactive EULA acceptance.
    $hipUrl = "https://download.amd.com/developer/eula/rocm-hub/$hipInstaller"
    $hipExe = Join-Path $projDir $hipInstaller
    if (-not (Test-Path $hipExe)) {
        Write-Host "Downloading AMD Software PRO Edition (HIP runtime + drivers)..."
        Download-WithRetries -Url $hipUrl -Dest $hipExe
    } else {
        Write-Host "AMD Software PRO Edition installer already present."
    }
    if (Test-Path $hipExe) {
        Write-Host "Installing AMD Software PRO Edition silently..."
        # NOTE(review): "/S" is assumed to be this installer's silent switch.
        # Start-Process does not fail on a non-zero exit code, so -PassThru
        # is used to inspect it. 3010 is the conventional Windows
        # "success, reboot required" code.
        $proc = Start-Process -FilePath $hipExe -ArgumentList "/S" -Wait -PassThru
        if ($proc.ExitCode -eq 0 -or $proc.ExitCode -eq 3010) {
            Write-Host "Installation complete. A reboot may be required for drivers to take effect."
        } else {
            Write-Warning "AMD Software PRO Edition installer exited with code $($proc.ExitCode); the HIP runtime/driver may not be installed."
        }
    }
} else {
    Write-Host "HIP SDK already installed (hipinfo.exe found)."
}
# ---------------------------
# PyTorch GPU test
# ---------------------------
Write-Host "Running PyTorch GPU test..."
# Single-quoted here-string: the Python source is written out verbatim, with
# no PowerShell `$`/backtick expansion. Python indentation is significant.
$testScript = @'
import torch
print("Torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device count:", torch.cuda.device_count())
    for i in range(torch.cuda.device_count()):
        print(f"Device {i}:", torch.cuda.get_device_name(i))
    x = torch.rand((3,3), device="cuda")
    y = torch.mm(x, x)
    print("Matrix multiply result on GPU:\n", y)
else:
    print("No ROCm-compatible GPU detected.")
'@
$testFile = Join-Path $projDir "test_torch_gpu.py"
$testScript | Out-File -Encoding UTF8 $testFile
# Run the venv interpreter directly instead of relying on PATH set up by
# activation. Fall back to `python` if the venv interpreter is missing.
$venvPython = Join-Path $projDir ".venv\Scripts\python.exe"
if (-not (Test-Path $venvPython)) { $venvPython = "python" }
& $venvPython $testFile
if ($LASTEXITCODE -ne 0) {
    Write-Warning "PyTorch GPU test failed with exit code $LASTEXITCODE."
}
Write-Host "`n=== Installation and test complete ==="
# Print an absolute path so the hint works from any directory.
Write-Host "To activate later: $(Join-Path $projDir '.venv\Scripts\Activate.ps1')"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment