Created
February 23, 2024 16:55
-
-
Save shunting314/0667d69299a5d9435b2931da2d2df476 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Augment kernel metadata generated by kernel_metadata metric table in inductor. | |
For each row in input, use NCU to profile the kernel. The corresponding output row | |
contains more metadata gathered by NCU. | |
It can be super slow to run NCU. E.g., for the 10K kernels gathered from Huggingface,
it took almost a whole day to run NCU for each unique kernel. The script thus caches
the NCU output in the file system. If the NCU output is cached, we don't run NCU again.
Example input: https://gist.github.com/shunting314/22995da0da8b66d4cf989cb7f0508399 | |
Example output: https://gist.github.com/shunting314/cb36615e8b6e4143de2fba246db9244e | |
""" | |
import argparse | |
import csv | |
import dataclasses | |
from typing import List | |
import subprocess | |
import os | |
import itertools | |
import re | |
from typing import Optional | |
@dataclasses.dataclass
class LogLine:
    """
    One row of the kernel-metadata CSV emitted by inductor's kernel_metadata
    metric table, plus extra fields augmented by profiling the kernel with NCU.
    """

    model_name: str
    kernel_name: str
    # Path to the generated .py benchmark file for this kernel.
    kernel_path: str
    kernel_category: str
    size_hints: List[int]
    reduction_hint: str
    line_of_code: int
    num_load: int
    num_store: int
    num_for_loop: int
    num_atomic_add: int
    num_args: int
    xnumel: int
    ynumel: int
    rnumel: int
    kernel_args_num_gb: float
    # Augmented fields, populated by parse_ncu_output().
    latency_us_under_profiling: Optional[float] = None
    ncu_duration_us: Optional[float] = None
    ncu_mem_bw_gbps: Optional[float] = None
    ncu_mem_accessed_gb: Optional[float] = None

    # Class-level counter shared by all instances; numbers completed NCU runs
    # in the progress output. (No annotation, so it is not a dataclass field.)
    ncu_run_count = itertools.count()

    @property
    def ncu_log_path(self):
        """Path of the cached NCU log: the kernel file path with .py -> .ncu."""
        # splitext is safer than slicing off the last 3 characters: a
        # kernel_path with an unexpected extension is handled sanely instead
        # of being silently corrupted.
        root, _ = os.path.splitext(self.kernel_path)
        return root + ".ncu"

    def valid_ncu_log_found(self):
        """Return True iff a cached NCU log exists and looks complete."""
        ncu_log_path = self.ncu_log_path
        if not os.path.exists(ncu_log_path):
            return False
        with open(ncu_log_path, "r") as f:
            log_content = f.read()
        return self.is_valid_ncu_output(log_content)

    @staticmethod
    def is_valid_ncu_output(ncu_output):
        # This section reports memory bandwidth usage of the kernel; its
        # presence implies NCU ran to completion.
        return "Section: Memory Workload Analysis" in ncu_output

    @staticmethod
    def _parse_latency_us_under_profiling(log_content):
        """
        Parse the latency reported by the benchmark harness, in microseconds.

        Match line like: 0.027ms 0.004GB 147.69GB/s
        """
        m = re.search(r"^([\d.]+)ms\s+[\d.]+GB\s*[\d.]+GB/s$", log_content, re.M)
        assert m, "benchmark output not found"
        return float(m.group(1)) * 1000  # ms -> us

    @staticmethod
    def _parse_ncu_duration_us(log_content):
        """
        Parse the kernel duration reported by NCU, normalized to microseconds.

        Match line like: Duration usecond 6.21
        Like the memory-throughput line, the number may contain commas
        (e.g. 1,000.00), so strip them before converting.
        """
        m = re.search(r"^\s+Duration\s+([a-z]+)\s+([0-9.,]+)\s*$", log_content, re.M)
        assert m, "ncu duration not found"
        unit = m.group(1)
        quantity = float(m.group(2).replace(",", ""))
        # NCU picks the unit based on magnitude; normalize everything to us.
        if unit == "nsecond":
            return quantity / 1000
        elif unit == "usecond":
            return quantity
        elif unit == "msecond":
            return quantity * 1000
        elif unit == "second":
            return quantity * 1000000
        else:
            raise RuntimeError(f"Un-recognized unit {unit}")

    @staticmethod
    def _parse_ncu_mem_bw_gbps(log_content):
        """
        Parse NCU's reported memory throughput, normalized to Gbyte/second.

        Match line like: Memory Throughput Gbyte/second 4.39
        Sometimes the number contains comma. E.g.
            Memory Throughput Gbyte/second 1,000.00
        """
        m = re.search(r"^\s+Memory Throughput\s+([KMGT])byte/second\s+([0-9.,]+)\s*$", log_content, re.M)
        assert m, "ncu memory bw not found"
        unit = m.group(1)
        quantity = float(m.group(2).replace(",", ""))
        if unit == "K":
            return quantity / 1000000.0
        elif unit == "M":
            return quantity / 1000.0
        elif unit == "G":
            return quantity
        elif unit == "T":
            # NOTE(review): assumes decimal (SI) prefixes, i.e. 1 T = 1000 G —
            # matches the original author's open question (1000 or 1024?).
            return quantity * 1000
        else:
            raise RuntimeError(f"Un-recognized unit for mem bw: {unit}")

    def parse_ncu_output(self):
        """Fill the augmented fields from the cached NCU log for this kernel."""
        with open(self.ncu_log_path, "r") as f:
            log_content = f.read()
        self.latency_us_under_profiling = self._parse_latency_us_under_profiling(log_content)
        self.ncu_duration_us = self._parse_ncu_duration_us(log_content)
        self.ncu_mem_bw_gbps = self._parse_ncu_mem_bw_gbps(log_content)
        # GB/s * us / 1e6 = GB actually moved during the kernel run.
        self.ncu_mem_accessed_gb = self.ncu_mem_bw_gbps * self.ncu_duration_us / 1e6

    def run_ncu(self):
        """
        Profile this kernel with NCU and cache the output next to the kernel
        file. Skipped when a valid cached log already exists, since running
        NCU is very slow.

        NOTE: unlike in a shell, surrounding regex:triton with a pair of quotes
        (i.e. become "regex:triton") will cause no kernel being found since the quotes
        will be passed to ncu. While when running in a shell, the shell will remove the quotes.
        """
        if self.valid_ncu_log_found():
            return
        print(f"Run ncu for kernel {self.kernel_name} at {self.kernel_path}")
        # NOTE(review): .split() assumes kernel_path contains no whitespace.
        ncu_args = f"""
        ncu -k regex:triton -c 1 --set full python {self.kernel_path}
        """.strip().split()
        ncu_out = subprocess.check_output(ncu_args).decode()
        if not self.is_valid_ncu_output(ncu_out):
            raise RuntimeError(f"Invalid ncu output generated for kernel {self.kernel_path}. Output: {ncu_out[:8192]}")
        with open(self.ncu_log_path, "w") as f:
            f.write(ncu_out)
        print(f"{next(self.ncu_run_count)}: ncu output generated at {self.ncu_log_path}")
def write_line_obj_list_to_csv(line_obj_list, output_csv):
    """
    Write a non-empty list of dataclass instances to output_csv.

    The first row is a header with the dataclass field names; each following
    row holds one object's field values in the same order.
    """
    assert len(line_obj_list) > 0
    field_names = [f.name for f in dataclasses.fields(line_obj_list[0])]
    # newline="" is required by the csv module; without it Windows gets an
    # extra blank line between rows.
    with open(output_csv, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(field_names)
        for line in line_obj_list:
            writer.writerow(getattr(line, name) for name in field_names)
    print(f"Output is written to {output_csv}")
# Parsed command line arguments; populated in main().
args = None


def main():
    """
    Read the input CSV of inductor kernel metadata, profile each kernel with
    NCU (reusing cached NCU logs when present), and write the augmented rows
    to the output CSV. Rows in the "foreach" category are skipped.
    """
    parser = argparse.ArgumentParser(description="Parse the CSV file with the metadata for inductor generated triton kernels")
    parser.add_argument(
        "-i", "--input-csv", help="The input CSV file to parse", required=True
    )
    parser.add_argument(
        "-o", "--output-csv", help="The output CSV file. Each row may contain augmented fields compared to the row in input", required=True
    )
    global args
    args = parser.parse_args()

    line_obj_list = []
    # newline="" lets the csv module handle embedded newlines and platform
    # line endings correctly, per the csv docs.
    with open(args.input_csv, newline="") as f:
        csv_reader = csv.reader(f)
        header = next(csv_reader)
        for line in csv_reader:
            try:
                line_obj = LogLine(**dict(zip(header, line)))
            except TypeError:
                print(f"Invalid log line {line}")
                raise
            if line_obj.kernel_category == "foreach":
                # We don't have benchmark harness generated for foreach kernels. So skip.
                continue
            line_obj.run_ncu()
            line_obj.parse_ncu_output()
            line_obj_list.append(line_obj)
    write_line_obj_list_to_csv(line_obj_list, args.output_csv)
# Script entry point: run only when executed directly, not on import.
if __name__ == "__main__":
    main()
    print("bye")  # completion marker
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment