Created
December 29, 2024 15:44
-
-
Save cloneofsimo/e00eca9eeb46f136a71652d0283917bb to your computer and use it in GitHub Desktop.
Is MFU correlated with watt usage in practice?
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import torch | |
import time | |
import random | |
import numpy as np | |
import multiprocessing | |
from multiprocessing import Process, Manager, Event | |
import plotly.express as px | |
import plotly.io as pio | |
import math | |
# Allow TF32 tensor-core paths for matmul/cudnn (speeds up fp32-adjacent math).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# pynvml is optional: power monitoring is disabled when it is not installed.
try:
    import pynvml
    nvml_available = True
except ImportError:
    nvml_available = False
def power_monitor(device_index, power_list, stop_event, sample_delay=0.02):
    """Sample GPU power draw in a loop until *stop_event* is set.

    Intended to run in a separate process. Each NVML sample (milliwatts)
    is appended to the shared *power_list*, roughly every *sample_delay*
    seconds.

    Args:
        device_index: NVML index of the GPU to monitor.
        power_list: shared (Manager) list receiving power samples in mW.
        stop_event: multiprocessing Event that terminates the loop.
        sample_delay: seconds to sleep between samples.
    """
    pynvml.nvmlInit()
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
        while not stop_event.is_set():
            # nvmlDeviceGetPowerUsage reports milliwatts.
            power_list.append(pynvml.nvmlDeviceGetPowerUsage(handle))
            time.sleep(sample_delay)
    finally:
        # Fix: the original never released the NVML context; shut it down
        # even if sampling raises.
        pynvml.nvmlShutdown()
def measure_time_and_power(op, warmup=3, iters=1, device_index=0):
    """Time a GPU op while sampling power draw in a background process.

    Args:
        op: zero-argument callable launching the GPU work to measure.
        warmup: number of un-timed warmup calls before measurement.
        iters: number of timed iterations to average over.
        device_index: NVML/CUDA index of the GPU being monitored.

    Returns:
        (elapsed, avg_power): mean wall-clock seconds per iteration, and
        mean sampled power in mW (NaN if no samples were collected).

    Raises:
        ValueError: if pynvml could not be imported at module load time.
    """
    if not nvml_available:
        raise ValueError("Power monitoring is not available. Please install pynvml.")
    for _ in range(warmup):
        op()
    torch.cuda.synchronize()
    manager = Manager()
    power_list = manager.list()
    stop_event = Event()
    p = Process(target=power_monitor, args=(device_index, power_list, stop_event))
    p.daemon = True  # never outlive the parent if something goes wrong
    p.start()
    time.sleep(0.5)  # let the sampler settle before the timed region starts
    try:
        start_time = time.time()
        for _ in range(iters):
            op()
        # Include all queued asynchronous GPU work in the timing window.
        torch.cuda.synchronize()
        elapsed = (time.time() - start_time) / iters
    finally:
        # Fix: always stop and reap the monitor process, even if op() raises
        # mid-measurement — the original leaked the child process on error.
        stop_event.set()
        p.join()
    power_array = np.array(power_list, dtype=np.float32)
    avg_power = float(power_array.mean()) if len(power_array) else float('nan')
    return elapsed, avg_power
@torch.no_grad()
def function1(A, B):
    """Matmul followed by a ReLU; returns a fresh copy of the result."""
    product = A @ B
    activated = torch.relu(product)
    return activated.clone()
@torch.no_grad()
def function2(A, B):
    """Matmul followed by cheap elementwise ops ((C+1)*2)^2; returns a copy."""
    result = A @ B
    result = torch.square((result + 1.0) * 2.0)
    return result.clone()
@torch.no_grad()
def function3(A, B):
    """Two chained matmuls, (A @ B) @ B.T; returns a copy."""
    first = torch.matmul(A, B)
    second = torch.matmul(first, B.T)
    return second.clone()
@torch.no_grad()
def function4(A, B):
    """Matmul followed by a last-dim softmax; returns a copy."""
    logits = torch.matmul(A, B)
    return logits.softmax(dim=-1).clone()
@torch.no_grad()
def function5(A, B):
    """Three chained matmuls, ((A @ B) @ B.T) @ B; returns a copy."""
    acc = torch.matmul(A, B)
    acc = torch.matmul(acc, B.T)
    return torch.matmul(acc, B).clone()
def flops_func1(M, K, N):
    """FLOPs for function1: one (M,K)x(K,N) matmul plus a ReLU over M*N elements."""
    matmul_flops = 2.0 * M * K * N
    relu_flops = M * N
    return matmul_flops + relu_flops
def flops_func2(M, K, N):
    """FLOPs for function2: one matmul plus three elementwise passes over M*N."""
    matmul_flops = 2.0 * M * K * N
    elementwise_flops = 3.0 * M * N
    return matmul_flops + elementwise_flops
def flops_func3(M, K, N):
    """FLOPs for function3: (M,K)x(K,N) matmul then (M,N)x(N,K) matmul."""
    first = 2.0 * M * K * N
    second = 2.0 * M * N * K
    return first + second
def flops_func4(M, K, N):
    """FLOPs for function4: one matmul plus a softmax costed at 2 ops per element."""
    matmul_flops = 2.0 * M * K * N
    softmax_flops = 2.0 * M * N
    return matmul_flops + softmax_flops
def flops_func5(M, K, N):
    """FLOPs for function5: three chained matmuls (MKN, MNK, MKN shapes)."""
    first = 2.0 * M * K * N
    second = 2.0 * M * N * K
    third = 2.0 * M * K * N
    return first + second + third
def exact_flops(func, M, K, N):
    """Look up the analytic FLOP count for a benchmark function.

    Returns NaN when *func* is not one of the five known benchmark functions.
    """
    counters = {
        function1: flops_func1,
        function2: flops_func2,
        function3: flops_func3,
        function4: flops_func4,
        function5: flops_func5,
    }
    counter = counters.get(func)
    if counter is None:
        return float('nan')
    return counter(M, K, N)
def mem_func1(M, K, N):
    """Element accesses for function1: operand reads plus three passes over C."""
    operand_reads = M * K + K * N
    output_traffic = 3 * (M * N)
    return operand_reads + output_traffic
def mem_func2(M, K, N):
    """Element accesses for function2: operand reads plus four passes over C."""
    operand_reads = M * K + K * N
    output_traffic = 4 * (M * N)
    return operand_reads + output_traffic
def mem_func3(M, K, N):
    """Element accesses for function3: two matmuls plus the final clone.

    NOTE(review): write accounting mirrors the original formula verbatim
    (clone traffic counted as 2*M*K of the "writes" term).
    """
    first_matmul = (M * K + K * N) + (M * N)    # read A, B; write C1
    second_matmul = (M * N + N * K) + (M * K)   # read C1, B.T; write C2
    clone_traffic = 2 * (M * K)                 # read + write of the copy
    return first_matmul + second_matmul + clone_traffic
def mem_func4(M, K, N): | |
reads = (M*K + K*N) + (M*N) + (M*N) | |
writes = (M*N) + (M*N) | |
return (M*K + K*N) + (M*N) + 2*(M*N) + (M*N) | |
def mem_func5(M, K, N):
    """Element accesses for function5: three chained matmuls plus the clone.

    NOTE(review): totals reproduce the original formula verbatim (the last
    2*M*N term covers the clone's read + write).
    """
    read_elems = 2 * (M * K + K * N) + (M * N + N * K)
    write_elems = (M * N) + (M * K) + (M * N) + 2 * (M * N)
    return read_elems + write_elems
def exact_mem_access(func, M, K, N):
    """Look up the analytic element-access count for a benchmark function.

    Returns NaN when *func* is not one of the five known benchmark functions.
    """
    counters = {
        function1: mem_func1,
        function2: mem_func2,
        function3: mem_func3,
        function4: mem_func4,
        function5: mem_func5,
    }
    counter = counters.get(func)
    if counter is None:
        return float('nan')
    return counter(M, K, N)
def main():
    """Benchmark five matmul-based kernels and correlate "MFU" with GPU power.

    For each benchmark function, samples 8 random (M, K, N) shapes, times
    the op while sampling NVML power in a background process, prints a
    per-run summary, computes the MFU/power correlation, and writes an
    interactive Plotly scatter to mfu_vs_power_mem.html.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    sample_in_uniform = False  # False -> sample each dimension log-uniformly
    functions = [function1, function2, function3, function4, function5]
    func_labels = ["Func1","Func2","Func3","Func4","Func5"]
    # Parallel per-run records, later assembled into a DataFrame.
    mfu_data = []
    power_data = []
    mem_data = []
    func_id_data = []
    shape_labels = []
    min_size, max_size = 2000, 108000
    random.seed(42)  # reproducible shape sampling
    for i, func in enumerate(functions, start=1):
        for shape_id in range(8):
            if sample_in_uniform:
                M = random.randint(min_size, max_size)
                K = random.randint(min_size, max_size)
                N = random.randint(min_size, max_size)
            else:
                # Log-uniform sampling spreads shapes evenly across scales.
                log2max = math.log2(max_size)
                log2min = math.log2(min_size)
                M = 2 ** (random.uniform(log2min, log2max))
                K = 2 ** (random.uniform(log2min, log2max))
                N = 2 ** (random.uniform(log2min, log2max))
            # Round each dimension down to a multiple of 256
            # (tensor-core-friendly shapes).
            M, K, N = (int(M) // 256) * 256, (int(K) // 256) * 256, (int(N) // 256) * 256
            A = torch.randn(M, K, dtype=torch.bfloat16, device=device)
            B = torch.randn(K, N, dtype=torch.bfloat16, device=device)
            flops = exact_flops(func, M, K, N)
            mem_count = exact_mem_access(func, M, K, N)
            # Zero-arg closure so measure_time_and_power can re-run the op.
            def op():
                return func(A, B)
            avgelapsed_time, avg_power_mW = measure_time_and_power(
                op, warmup=3, iters=10, device_index=0
            )
            if avgelapsed_time > 0:
                # NOTE(review): "MFU" here is raw throughput in PFLOP/s
                # (flops/s divided by 1e15), not utilization relative to a
                # hardware peak — naming kept from the original.
                mfu = (flops / avgelapsed_time) / 1e15
            else:
                mfu = float('nan')
            mfu_data.append(mfu)
            power_data.append(avg_power_mW)
            mem_data.append(mem_count)
            func_id_data.append(func_labels[i-1])
            shape_labels.append(f"{M}x{K}x{N}")
            print(f"Func#{i}, shape#{shape_id+1}, [M,K,N]={[M,K,N]}, "
                  f"FLOPs={flops/1e9:.2f} G, time={avgelapsed_time:.2f}s, "
                  f"MFU={mfu:.2f}, power={avg_power_mW/1000:.2f} W, "
                  f"memAccess={mem_count:e} elems")
            # Free the large operands before sampling the next shape.
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
    mfu_array = np.array(mfu_data)
    power_array = np.array(power_data)
    mem_array = np.array(mem_data, dtype=np.float64)
    # Pearson correlation over runs where both MFU and power are valid.
    valid_mask = ~np.isnan(mfu_array) & ~np.isnan(power_array)
    if valid_mask.sum() > 1:
        corr = np.corrcoef(mfu_array[valid_mask], power_array[valid_mask])[0, 1]
        print(f"\nCorrelation coefficient (MFU vs Power): {corr:.4f}")
    else:
        corr = float('nan')
        print("Not enough valid data to compute correlation.")
    log_mem = np.log10(mem_array)  # color scale: order of magnitude of traffic
    import pandas as pd
    df = pd.DataFrame({
        "MFU": mfu_array,
        "Power_mW": power_array,
        "MemoryAccess": mem_array,
        "logMem": log_mem,
        "Function": func_id_data,
        "Shape": shape_labels
    })
    fig = px.scatter(
        df,
        x="MFU",
        y="Power_mW",
        color="logMem",
        symbol="Function",
        hover_data=["Shape", "MemoryAccess"],
        title=f"GPU MFU vs Power (colored by log10(MemAccess)), Corr={corr:.4f}"
    )
    fig.update_traces(marker=dict(size=8, line=dict(width=1, color='DarkSlateGrey')))
    fig.update_layout(
        xaxis_title="MFU [FLOPs/s]",
        yaxis_title="Avg GPU Power [mW]",
        legend_title="Function ID",
        legend=dict(orientation="v", y=1, x=0.1),
        coloraxis_colorbar=dict(title="log10(MemAccess)")
    )
    pio.write_html(fig, file="mfu_vs_power_mem.html", auto_open=False)
    print("Plotly figure saved to mfu_vs_power_mem.html")
    fig.show()
# Script entry point.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Modified to perform additional runs with a sleep between them, just in case; the conclusion remains the same.