Skip to content

Instantly share code, notes, and snippets.

@shunting314
Created February 23, 2024 16:55
Show Gist options
  • Save shunting314/0667d69299a5d9435b2931da2d2df476 to your computer and use it in GitHub Desktop.
Save shunting314/0667d69299a5d9435b2931da2d2df476 to your computer and use it in GitHub Desktop.
"""
Augment kernel metadata generated by kernel_metadata metric table in inductor.
For each row in input, use NCU to profile the kernel. The corresponding output row
contains more metadata gathered by NCU.
Running NCU can be very slow: e.g. for the 10K kernels gathered from HuggingFace,
it took almost a whole day to run NCU once for each unique kernel. The script therefore
caches the NCU output in the file system; if the NCU output is already cached, NCU is not run again.
Example input: https://gist.github.com/shunting314/22995da0da8b66d4cf989cb7f0508399
Example output: https://gist.github.com/shunting314/cb36615e8b6e4143de2fba246db9244e
"""
import argparse
import csv
import dataclasses
from typing import List
import subprocess
import os
import itertools
import re
from typing import Optional
@dataclasses.dataclass
class LogLine:
    """One row of the inductor kernel_metadata table, optionally augmented with NCU metrics.

    The non-optional fields correspond to the columns of the input CSV. NOTE: when
    built from ``csv.reader`` rows they hold the raw ``str`` values; the annotations
    below are not enforced or converted.

    The ``Optional`` fields default to ``None`` and are filled in by
    :meth:`parse_ncu_output` from the (cached) NCU report of the kernel.
    """

    model_name: str
    kernel_name: str
    # Path of the generated kernel module; the NCU report is cached next to it.
    kernel_path: str
    kernel_category: str
    size_hints: List[int]
    reduction_hint: str
    line_of_code: int
    num_load: int
    num_store: int
    num_for_loop: int
    num_atomic_add: int
    num_args: int
    xnumel: int
    ynumel: int
    rnumel: int
    kernel_args_num_gb: float

    # augmented fields, populated by parse_ncu_output()
    latency_us_under_profiling: Optional[float] = None
    ncu_duration_us: Optional[float] = None
    ncu_mem_bw_gbps: Optional[float] = None
    ncu_mem_accessed_gb: Optional[float] = None

    # Unannotated, so this is a plain class attribute rather than a dataclass field.
    # It is shared by all instances and numbers the NCU runs in progress messages.
    ncu_run_count = itertools.count()

    @property
    def ncu_log_path(self):
        """Path of the cached NCU report: ``kernel_path`` with its extension replaced by ``.ncu``."""
        # splitext instead of slicing off 3 chars, so paths without a ".py"
        # suffix still map to a sensible sibling file.
        return os.path.splitext(self.kernel_path)[0] + ".ncu"

    def valid_ncu_log_found(self):
        """Return True if the cached NCU report exists and passes ``is_valid_ncu_output``."""
        ncu_log_path = self.ncu_log_path
        if not os.path.exists(ncu_log_path):
            return False
        with open(ncu_log_path, "r") as f:
            log_content = f.read()
        return self.is_valid_ncu_output(log_content)

    @staticmethod
    def is_valid_ncu_output(ncu_output):
        """Return True if ``ncu_output`` contains the Memory Workload Analysis section."""
        # this section reports memory bandwidth usage of the kernel
        return "Section: Memory Workload Analysis" in ncu_output

    @staticmethod
    def _parse_latency_us_under_profiling(log_content):
        """
        Return the latency (in microseconds) from the benchmark line of the log.

        Match line like: 0.027ms 0.004GB 147.69GB/s
        (presumably printed by the kernel module's benchmark harness while under NCU).

        Raises AssertionError if no such line is present.
        """
        m = re.search(r"^([\d.]+)ms\s+[\d.]+GB\s*[\d.]+GB/s$", log_content, re.M)
        assert m, "benchmark output not found"
        return float(m.group(1)) * 1000  # ms -> us

    @staticmethod
    def _parse_ncu_duration_us(log_content):
        """
        Return the NCU-reported kernel duration in microseconds.

        Match line like: Duration usecond 6.21
        Short kernels are reported in nsecond, large values may contain a comma
        (e.g. "Duration usecond 1,234.56"), and abbreviated unit names
        (ns/us/ms/s) are accepted as well.

        Raises AssertionError if no Duration line is found and RuntimeError for
        an unknown unit.
        """
        m = re.search(r"^\s+Duration\s+([a-z]+)\s+([0-9.,]+)\s*$", log_content, re.M)
        assert m, "ncu duration not found"
        unit = m.group(1)
        quantity = float(m.group(2).replace(",", ""))
        if unit in ("nsecond", "ns"):
            return quantity / 1000
        elif unit in ("usecond", "us"):
            return quantity
        elif unit in ("msecond", "ms"):
            return quantity * 1000
        elif unit in ("second", "s"):
            return quantity * 1000000
        else:
            raise RuntimeError(f"Un-recognized unit {unit}")

    @staticmethod
    def _parse_ncu_mem_bw_gbps(log_content):
        """
        Return the NCU-reported memory throughput in GB/s (decimal prefixes).

        Match line like: Memory Throughput Gbyte/second 4.39
        Sometimes the number contains comma. E.g.
        Memory Throughput Gbyte/second 1,000.00
        The "/s" spelling and an un-prefixed "byte/..." unit are accepted too.
        The percentage-valued "Memory Throughput %" line is deliberately not
        matched because the unit must end in byte/second (or byte/s).

        Raises AssertionError if no matching line is found and RuntimeError for
        an unknown unit prefix.
        """
        m = re.search(
            r"^\s+Memory Throughput\s+([KMGT]?)byte/s(?:econd)?\s+([0-9.,]+)\s*$",
            log_content,
            re.M,
        )
        assert m, "ncu memory bw not found"
        unit = m.group(1)
        quantity = float(m.group(2).replace(",", ""))
        if unit == "":
            return quantity / 1e9
        elif unit == "K":
            return quantity / 1000000.0
        elif unit == "M":
            return quantity / 1000.0
        elif unit == "G":
            return quantity
        elif unit == "T":
            return quantity * 1000  # 1000 or 1024?
        else:
            raise RuntimeError(f"Un-recognized unit for mem bw: {unit}")

    def parse_ncu_output(self):
        """Read the cached NCU report and fill in all augmented fields.

        ``ncu_mem_accessed_gb`` is derived as bandwidth (GB/s) * duration (us) / 1e6.
        Raises AssertionError (from the parsers) if an expected line is missing.
        """
        with open(self.ncu_log_path, "r") as f:
            log_content = f.read()
        self.latency_us_under_profiling = self._parse_latency_us_under_profiling(log_content)
        self.ncu_duration_us = self._parse_ncu_duration_us(log_content)
        self.ncu_mem_bw_gbps = self._parse_ncu_mem_bw_gbps(log_content)
        self.ncu_mem_accessed_gb = self.ncu_mem_bw_gbps * self.ncu_duration_us / 1e6

    def run_ncu(self):
        """
        Profile the kernel module with NCU and cache the report at ``ncu_log_path``.

        Does nothing if a valid cached report already exists. Raises
        ``subprocess.CalledProcessError`` if ncu fails and RuntimeError if its
        output lacks the Memory Workload Analysis section (nothing is cached then).

        NOTE: the argv is passed as a list without a shell, so regex:triton must
        not be wrapped in quotes: the quotes would be passed to ncu literally and
        no kernel would match (a shell would strip them).
        """
        if self.valid_ncu_log_found():
            return
        print(f"Run ncu for kernel {self.kernel_name} at {self.kernel_path}")
        # Build argv explicitly so a kernel path containing whitespace stays one argument.
        ncu_args = [
            "ncu",
            "-k", "regex:triton",
            "-c", "1",
            "--set", "full",
            "python", self.kernel_path,
        ]
        ncu_out = subprocess.check_output(ncu_args).decode()
        if not self.is_valid_ncu_output(ncu_out):
            raise RuntimeError(f"Invalid ncu output generated for kernel {self.kernel_path}. Output: {ncu_out[:8192]}")
        with open(self.ncu_log_path, "w") as f:
            f.write(ncu_out)
        print(f"{next(self.ncu_run_count)}: ncu output generated at {self.ncu_log_path}")
def write_line_obj_list_to_csv(line_obj_list, output_csv):
    """Write a list of dataclass instances to ``output_csv`` as CSV.

    The header row is the field names of the first object's dataclass. Each
    following row holds one object's values for those fields, in the same order.

    Args:
        line_obj_list: non-empty list of dataclass instances; all are expected
            to have the fields of the first element.
        output_csv: path of the CSV file to create (overwritten if it exists).

    Raises:
        AssertionError: if ``line_obj_list`` is empty.
    """
    assert len(line_obj_list) > 0, "no rows to write"
    field_names = [field.name for field in dataclasses.fields(line_obj_list[0])]
    # newline="" lets the csv writer control line endings; without it the
    # writer's "\r\n" terminator turns into "\r\r\n" on Windows.
    with open(output_csv, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(field_names)
        writer.writerows(
            [getattr(line, name) for name in field_names] for line in line_obj_list
        )
    print(f"Output is written to {output_csv}")
# Parsed command-line options; assigned by main().
args = None


def main():
    """Profile every kernel listed in the input CSV with NCU and write the augmented CSV.

    Reads ``--input-csv`` (a header row followed by one row per kernel) and builds
    a ``LogLine`` per row. Rows whose ``kernel_category`` is "foreach" are skipped.
    For every other row, NCU is run (or its cached report reused) and the NCU
    metrics are parsed. All kept rows are then written to ``--output-csv``.
    """
    parser = argparse.ArgumentParser(
        description="Parse the CSV file with the metadata for inductor generated triton kernels"
    )
    parser.add_argument(
        "-i", "--input-csv", help="The input CSV file to parse", required=True
    )
    parser.add_argument(
        "-o",
        "--output-csv",
        help="The output CSV file. Each row may contain augmented fields compared to the row in input",
        required=True,
    )
    global args
    args = parser.parse_args()

    line_obj_list = []
    # newline="" is what the csv module expects, so that quoted fields
    # containing newlines are parsed correctly.
    with open(args.input_csv, newline="") as f:
        csv_reader = csv.reader(f)
        header = next(csv_reader, None)
        if header is None:
            # Report an empty file clearly instead of leaking a bare StopIteration.
            parser.error(f"input CSV {args.input_csv} is empty")
        for line in csv_reader:
            try:
                # csv yields str values; they are passed to LogLine as-is.
                line_obj = LogLine(**dict(zip(header, line)))
            except TypeError:
                # Header/field mismatch (unknown or missing column) surfaces as a
                # TypeError from the dataclass __init__; show the offending row.
                print(f"Invalid log line {line}")
                raise
            if line_obj.kernel_category == "foreach":
                # We don't have benchmark harness generated for foreach kernels. So skip.
                continue
            line_obj.run_ncu()
            line_obj.parse_ncu_output()
            line_obj_list.append(line_obj)
    write_line_obj_list_to_csv(line_obj_list, args.output_csv)


if __name__ == "__main__":
    main()
    print("bye")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment