Created
April 26, 2026 16:03
-
-
Save nh2/1f61a60779dfa5fc303720dbb017dc81 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """Minimal repro: torch.profiler CUDA events on worker threads get wrong thread IDs. | |
| When `torch.profiler.profile()` is active and CUDA operations happen on | |
| worker threads (e.g. from `ThreadPoolExecutor`), the exported Chrome trace | |
| has CUDA runtime events (like `cudaLaunchKernel`) with `tid` values based | |
| on `pthread_self()` instead of `gettid()` (Linux kernel TID). | |
| This happens because `torch.profiler.profile()` registers `RecordFunction` | |
| callbacks via `at::addThreadLocalCallback()`, so only the thread that | |
| called `profile().__enter__()` gets callbacks. Kineto's | |
| `CuptiActivityProfiler::recordThreadInfo()` -- which maps `pthread_self()` | |
| to `gettid()` -- is only called from `ThreadLocalSubqueue`'s constructor, | |
| which only runs when `RecordFunction` fires on a thread. Worker threads | |
| never fire `RecordFunction`, so they are never registered, and CUPTI's | |
| raw `pthread_self()`-based thread IDs pass through to the trace output. | |
| This causes worker thread CUDA events to appear on extra unnamed rows in | |
| trace viewers like Perfetto, instead of on the Python thread that issued | |
| them. | |
| The underlying issue is in PyTorch's C++ code (`at::addThreadLocalCallback()` | |
| vs `at::addGlobalCallback()`), so it should be independent of the Python | |
| version. This script confirms that empirically. | |
| """ | |
| import ctypes | |
| import ctypes.util | |
| import json | |
| import os | |
| import sys | |
| import tempfile | |
| import threading | |
| from concurrent.futures import ThreadPoolExecutor | |
| def _get_libc() -> ctypes.CDLL: | |
| path = ctypes.util.find_library("c") | |
| if path is None: | |
| print("ERROR: Cannot find libc.", file=sys.stderr) | |
| sys.exit(1) | |
| return ctypes.CDLL(path, use_errno=True) | |
| def _get_libpthread() -> ctypes.CDLL: | |
| path = ctypes.util.find_library("pthread") | |
| if path is None: | |
| print("ERROR: Cannot find libpthread.", file=sys.stderr) | |
| sys.exit(1) | |
| return ctypes.CDLL(path, use_errno=True) | |
# Loaded once at import time; both handles are used by the thread-ID helpers below.
_libc = _get_libc()
_libpthread = _get_libpthread()
# pthread_self() returns pthread_t (an unsigned long on Linux); without an
# explicit restype, ctypes would truncate the return value to a C int.
_libpthread.pthread_self.restype = ctypes.c_ulong
# NOTE(review): syscall number 186 is gettid on x86_64 Linux ONLY
# (e.g. aarch64 uses 178) — this script is not portable to other architectures.
SYS_GETTID = 186  # x86_64 Linux
| def get_gettid() -> int: | |
| """Get the Linux kernel thread ID (`gettid()` / `SYS_gettid`).""" | |
| return _libc.syscall(SYS_GETTID) | |
def get_pthread_self_i32() -> int:
    """Get `pthread_self()` truncated to signed `int32_t`.

    This matches Kineto's `threadId()` in `ThreadUtil.cpp`:
    `int32_t* ptr = reinterpret_cast<int32_t*>(&pth); return *ptr;`
    """
    raw_handle = _libpthread.pthread_self()
    low_32_bits = raw_handle & 0xFFFFFFFF
    # Reinterpret the low 32 bits as a signed int32, exactly like Kineto does.
    return ctypes.c_int32(low_32_bits).value
def cuda_work_on_thread() -> dict[str, int]:
    """Do some CUDA work and return thread ID info."""
    # Capture both thread-ID representations for the calling thread.
    info = {
        "gettid": get_gettid(),
        "pthread_self_i32": get_pthread_self_i32(),
        "thread_name": threading.current_thread().name,
    }
    # Launch CUDA work so CUPTI records runtime events on this thread.
    mat = torch.randn(100, 100, device="cuda")
    _ = torch.matmul(mat, mat)
    torch.cuda.synchronize()
    return info
def main() -> None:
    """Profile CUDA work on the main thread and a worker thread, then inspect
    the exported Chrome trace to see whether CUDA runtime events carry
    gettid-based or pthread_self-based thread IDs."""
    print(f"Python {sys.version}")
    print(f"PyTorch {torch.__version__}")
    print()
    # Record both thread-ID representations for the main (profiler-starting) thread.
    main_info = {
        "gettid": get_gettid(),
        "pthread_self_i32": get_pthread_self_i32(),
        "thread_name": "main",
    }
    print(f"Main thread: gettid={main_info['gettid']}, pthread_self_i32={main_info['pthread_self_i32']}")
    with profile(activities=supported_activities()) as prof:
        # CUDA work on main thread (profiler-starting thread)
        x = torch.randn(100, 100, device="cuda")
        _ = torch.matmul(x, x)
        torch.cuda.synchronize()
        # CUDA work on a worker thread
        with ThreadPoolExecutor(max_workers=1, thread_name_prefix="worker") as executor:
            future = executor.submit(cuda_work_on_thread)
            worker_info = future.result()
    print(f"Worker thread: gettid={worker_info['gettid']}, pthread_self_i32={worker_info['pthread_self_i32']}")
    print()
    # Export and analyze the Chrome trace
    # delete=False: we only want the path here; the profiler writes the file below.
    with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as f:
        tmppath = f.name
    try:
        prof.export_chrome_trace(tmppath)
        with open(tmppath) as f:
            trace = json.load(f)
        # Collect thread info from trace:
        #   tid_to_cats: tid -> {event category -> event count}
        #   tid_to_name: tid -> thread name, from "ph" == "M" metadata events
        tid_to_cats: dict[int, dict[str, int]] = {}
        tid_to_name: dict[int, str] = {}
        for ev in trace.get("traceEvents", []):
            tid = ev.get("tid")
            if tid is None:
                continue
            if ev.get("ph") == "M" and ev.get("name") == "thread_name":
                tid_to_name[tid] = ev.get("args", {}).get("name", "?")
            cat = ev.get("cat", "")
            if cat:
                tid_to_cats.setdefault(tid, {})
                tid_to_cats[tid][cat] = tid_to_cats[tid].get(cat, 0) + 1
        print("Trace thread IDs:")
        # Trace tids may mix ints and strings; key on (is-string, str) so
        # sorted() never compares int to str directly.
        all_tids = sorted(
            set(list(tid_to_cats.keys()) + list(tid_to_name.keys())),
            key=lambda t: (isinstance(t, str), str(t)),
        )
        for tid in all_tids:
            name = tid_to_name.get(tid, "(no name)")
            cats = tid_to_cats.get(tid, {})
            print(f"  tid={tid} ({name}): {cats}")
        print()
        # Check: do CUDA runtime events have gettid-based or pthread_self-based tids?
        known_gettids = {main_info["gettid"], worker_info["gettid"]}
        known_pthread_i32s = {main_info["pthread_self_i32"], worker_info["pthread_self_i32"]}
        # Kineto commit eb1713f ("Prevent Negative TIDs in Trace", Oct 2024,
        # in PyTorch >= ~2.5) applies `abs()` to all tids before writing JSON.
        known_pthread_abs = {abs(v) for v in known_pthread_i32s}
        # All non-gettid representations of pthread_self that Kineto might write:
        known_pthread_all = known_pthread_i32s | known_pthread_abs
        cuda_tids: set[int] = set()
        for ev in trace.get("traceEvents", []):
            if ev.get("cat") == "cuda_runtime":
                cuda_tids.add(ev.get("tid"))
        print("Analysis:")
        print(f"  Known gettid values: {known_gettids}")
        print(f"  Known pthread_self_i32 values: {known_pthread_i32s}")
        print(f"  Known abs(pthread_self_i32): {known_pthread_abs}")
        print(f"  CUDA runtime event tids: {cuda_tids}")
        print()
        # Any CUDA tid that is not a known gettid was NOT remapped by Kineto.
        unmapped_tids = cuda_tids - known_gettids
        pthread_matched = unmapped_tids & known_pthread_all
        if not unmapped_tids:
            print("RESULT: All CUDA runtime events use gettid-based tids.")
            print("  Kineto correctly remapped all threads.")
            print("  The pthread_self -> gettid remapping workaround is NOT needed.")
        elif pthread_matched:
            # `&` and `-` share precedence and associate left-to-right, so these
            # read as (pthread_matched & X) - Y: tids matching one encoding only.
            signed_only = pthread_matched & known_pthread_i32s - known_pthread_abs
            abs_only = pthread_matched & known_pthread_abs - known_pthread_i32s
            both = pthread_matched & known_pthread_i32s & known_pthread_abs
            if both:
                print(f"RESULT: {len(both)} CUDA tid(s) match pthread_self_i32 (positive): {both}")
            if signed_only:
                print(f"RESULT: {len(signed_only)} CUDA tid(s) match signed pthread_self_i32: {signed_only}")
                print("  (Kineto without sanitizeTid, PyTorch <= 2.3)")
            if abs_only:
                print(f"RESULT: {len(abs_only)} CUDA tid(s) match abs(pthread_self_i32): {abs_only}")
                print("  (Kineto with sanitizeTid, PyTorch >= 2.5)")
            print("  Kineto did NOT remap worker thread CUDA events to gettid.")
            print("  The pthread_self -> gettid remapping workaround IS needed.")
        else:
            print(f"RESULT: {len(unmapped_tids)} CUDA tid(s) are unrecognized: {unmapped_tids}")
            print("  These may be from internal CUDA/driver threads.")
    finally:
        os.unlink(tmppath)
if __name__ == "__main__":
    import torch

    # Without a CUDA device there are no CUPTI events to inspect, so bail out.
    if not torch.cuda.is_available():
        print("ERROR: CUDA not available, cannot test CUPTI thread ID mapping.", file=sys.stderr)
        sys.exit(1)

    # These imports become module globals that main() and cuda_work_on_thread()
    # reference at call time.
    from torch.profiler import profile, supported_activities

    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment