Triton build script for ARM/AArch64

This script verifies that Triton can be built and minimally tested on ARM64 systems, specifically on AWS Graviton2 instances. Instructions:

  1. Create a g5g.xlarge instance using the following AWS CLI command, making sure to request at least 30 GB of storage (the command below provisions a 40 GB gp3 volume):
aws ec2 run-instances \
    --instance-type g5g.xlarge \
    --image-id ami-02dcfe5d1d39baa4e \
    --key-name <key-pair-name> \
    --count 1 \
    --block-device-mappings '[{"DeviceName":"/dev/xvda","Ebs":{"VolumeSize":40,"VolumeType":"gp3"}}]' \
    --region us-east-1
  2. Connect to the host (ssh -i /path/to/<key-pair-name>.pem ec2-user@<hostname>)

  3. Execute the script with source build_triton.sh (it must be sourced so that conda environment activation persists).

  4. The Triton wheel is written to /home/$USER/triton/python/dist/triton-*-linux_aarch64.whl, for example: triton-3.2.0+git781ae0b2-cp39-cp39-linux_aarch64.whl

  5. The resulting system and Python environment has these versions:

system:
  os: Amazon Linux 2023.6.20241212
  instance: g5g.xlarge
  cpu: Graviton2 - Neoverse-N1 (ARM/aarch64)
  gpu: NVIDIA T4G
  gpu-gencode: sm_75
  cuda-version: 12.6
  cuda-driver-version: 560.35.05
python:
  python: 3.9.21
  cuda-cudart: 12.6.77
  cuda-nvrtc: 12.6.85
  cuda-nvtx: 12.6.77
  cuda-version: 12.6
  libtorch: 2.5.1
  pytorch: 2.5.1
  torchvision: 0.20.1
  triton: 3.2.0+git781ae0b2
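
To confirm that a built environment matches the versions above, a few quick checks can be run (a minimal sketch, assuming the triton conda environment is active):

nvidia-smi --query-gpu=name,driver_version --format=csv,noheader
python -c "import torch; print(torch.__version__, torch.version.cuda, torch.cuda.is_available())"
python -c "import triton; print(triton.__version__)"

The script itself (build_triton.sh) follows.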
#!/bin/bash
# Build Script for Triton Language
# ==========================================
#
# Note: This script includes several reboot points. After each reboot:
# 1. Reconnect to your instance using SSH
# 2. Navigate back to the script directory
# 3. Run the script again
set -euxo pipefail
echo "=== Checking system updates ==="
if dnf check-release-update; then
  echo "=== No release update needed, continuing with installation ==="
elif [ $? -eq 100 ]; then
  # dnf uses exit code 100 to signal that updates are available
  echo "=== Release update needed, performing update ==="
  sudo dnf update -y
  sudo dnf upgrade --releasever=latest -y
  echo "=== Updates complete, rebooting system ==="
  echo "After reboot:"
  echo "1. Reconnect to the instance using SSH"
  echo "2. Navigate back to: $(pwd)"
  echo "3. Run this script again to continue the installation"
  sudo reboot
  exit 0
else
  echo "=== Error checking for updates, exiting ==="
  exit 1
fi
# Install base system dependencies
echo "=== Installing base system dependencies ==="
sudo dnf install -y git tmux
echo "=== Installing CUDA system dependencies ==="
sudo dnf install -y dkms kernel-devel kernel-modules-extra \
  vulkan-devel libglvnd-devel \
  elfutils-libelf-devel xorg-x11-server-Xorg
sudo systemctl enable --now dkms
# Check if CUDA is already installed
if [ -f "/usr/local/cuda/bin/nvcc" ] && [ -f "/usr/bin/nvidia-smi" ]; then
echo "=== CUDA already installed, skipping installation ==="
else
echo "=== Installing CUDA ==="
mkdir -p cuda
pushd cuda
CUDA_INSTALLER="cuda_12.6.3_560.35.05_linux_sbsa.run"
if [ ! -f "$CUDA_INSTALLER" ]; then
wget "https://developer.download.nvidia.com/compute/cuda/12.6.3/local_installers/$CUDA_INSTALLER"
fi
chmod +x ./*.run
sudo ./*.run --driver --toolkit --silent --tmpdir=`pwd`
popd
echo "=== CUDA installation complete, rebooting system ==="
echo "After reboot:"
echo "1. Reconnect to the instance using SSH"
echo "2. Navigate back to: $(pwd)"
echo "3. Run this script again to continue the installation"
sudo reboot
exit 0
fi
# Verify CUDA installation
echo "=== Verifying CUDA installation ==="
nvidia-smi
export CUDA_HOME=/usr/local/cuda
export PATH=$CUDA_HOME/bin:$PATH
$CUDA_HOME/bin/nvcc -V
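# Optional sanity check: confirm the visible GPU and its compute capability
# (the T4G should report 7.5, matching sm_75). Assumes the driver supports
# the compute_cap query field, which recent drivers do; made non-fatal in
# case it does not.
nvidia-smi --query-gpu=name,compute_cap --format=csv,noheader || true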
echo "=== Installing Miniforge ==="
MINIFORGE_URL="https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-aarch64.sh"
if [ ! -f "Miniforge3-Linux-aarch64.sh" ]; then
wget $MINIFORGE_URL
else
echo "=== Miniforge3 installer already downloaded, skipping download ==="
fi
if [ ! -d "$HOME/miniforge3" ]; then
bash Miniforge3-Linux-aarch64.sh -b
~/miniforge3/bin/conda init bash
source ~/.bashrc
else
echo "=== Miniforge3 already installed in $HOME/miniforge3, skipping installation ==="
fi
echo "=== Creating Triton development environment ==="
if ! ~/miniforge3/bin/conda env list | grep -q "triton "; then
  ~/miniforge3/bin/conda create -n triton python=3.9 -y
  source ~/miniforge3/bin/activate triton
  mamba install -y \
    pip pybind11 \
    pytest lit scipy numpy pandas matplotlib
  # Install PyTorch with CUDA support
  mamba install -y pytorch torchvision -c pytorch
  pip install ninja cmake build twine wheel setuptools
else
  echo "=== Triton environment already exists, activating ==="
  source ~/miniforge3/bin/activate triton
fi
echo "=== Cloning Triton repository ==="
TRITON_REPO_URL="https://github.com/triton-lang/triton.git"
# Skip the clone if it already exists so the script can be re-run after reboots
if [ ! -d "triton" ]; then
  git clone "$TRITON_REPO_URL"
fi
cd triton
echo "=== Installing Triton system dependencies ==="
sudo dnf install -y cmake ninja-build
echo "=== Setting build environment variables ==="
export TRITON_BUILD_WITH_CLANG_LLD=true
export TRITON_BUILD_WITH_CCACHE=true
# See https://github.com/triton-lang/triton/blob/781ae0b2f6b26c496054624d807fc64c0ed6594c/.github/workflows/wheels.yml#L56-L58
# for commentary on needing to lower MAX_JOBS from default 2 * n_cpus
export MAX_JOBS=3
echo "=== Building Triton ==="
python -m build python/ --wheel --no-isolation --verbose 2>&1 | tee triton_build.log
echo "=== Installing Triton ==="
pip install python/dist/triton*.whl 2>&1 | tee triton_install.log
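# Quick import smoke test of the freshly installed wheel before the longer
# test suites run below.
python -c "import triton; print('triton version:', triton.__version__)"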
echo "=== Running Triton tests ==="
(
  cd python/build/cmake.linux-aarch64-cpython-3.9/ && \
  echo "=== Running ctest ===" && \
  ctest -j32 && \
  echo "=== Running lit tests ===" && \
  lit test
)
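# If a suite fails, individual ctest cases can be re-run from the same build
# directory with a name filter, e.g. (placeholder regex):
#   (cd python/build/cmake.linux-aarch64-cpython-3.9/ && ctest -R <test-name-regex> --output-on-failure)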
echo "=== Creating GPU test file ==="
cat > triton_test.py << 'EOL'
import torch
import triton
import triton.language as tl
import numpy as np
@triton.jit
def add_vectors(
    x_ptr,  # Pointer to first input vector
    y_ptr,  # Pointer to second input vector
    output_ptr,  # Pointer to output vector
    n_elements,  # Size of vectors
    BLOCK_SIZE: tl.constexpr,  # Number of elements per block
):
    # We can add print statements in interpreter mode
    print(f"Processing block of size {BLOCK_SIZE}")
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # Load vectors
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    # Add them
    output = x + y
    # Store result
    tl.store(output_ptr + offsets, output, mask=mask)


def main():
    # Vector size
    size = 128
    # Create input vectors on GPU
    x = torch.randn(size, device='cuda')
    y = torch.randn(size, device='cuda')
    output = torch.empty_like(x)
    # Launch kernel
    grid = (triton.cdiv(size, 128),)
    add_vectors[grid](x, y, output, size, BLOCK_SIZE=128)
    # Verify result (keeping tensors on GPU)
    torch_output = x + y
    max_diff = torch.max(torch.abs(output - torch_output)).item()
    print(f"Max difference between Triton and Torch: {max_diff}")
    # Assert the results match within a small tolerance
    tolerance = 1e-6
    assert max_diff <= tolerance, f"Results don't match! Max difference ({max_diff}) exceeds tolerance ({tolerance})"
    print("Test passed successfully!")


if __name__ == "__main__":
    main()
EOL
echo "=== Running GPU test ==="
python triton_test.py
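# The same test can also be run without compiling for the GPU via Triton's
# interpreter mode (where the kernel's print statement is most useful),
# assuming this build supports it:
#   TRITON_INTERPRET=1 python triton_test.py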
echo "=== SUCCESS: Installation complete! All tests passed ==="