@thesues
Created February 8, 2025 01:01
CUDA events do not work inside a captured CUDA graph.
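"""
Demonstrates that CUDA events do not work as expected inside a captured CUDA
graph: events recorded during capture cannot be reused as cross-thread
completion signals when the graph is replayed (see the NOTE before the
replay below).
"""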
import torch
import torch.nn as nn
import torch.nn.functional as F
import queue
import threading


# A simple two-layer fully connected model
class SimpleModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)  # first fully connected layer
        self.fc2 = nn.Linear(hidden_dim, output_dim)  # second fully connected layer

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x


if __name__ == "__main__":
    # Hyperparameters
    input_dim = 784  # e.g. MNIST flattened: 28x28 = 784
    hidden_dim = 128
    output_dim = 10  # number of classes, 0-9
    batch_size = 32
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SimpleModel(input_dim, hidden_dim, output_dim).to(device)
    model.eval()

    # Fixed input buffer (CUDA graphs require stable memory addresses)
    d_input = torch.randn(batch_size, input_dim, device=device)
    # Fixed output buffer
    d_output = torch.empty(batch_size, output_dim, device=device)
    # Create events with timing disabled (such an event does no internal
    # timing bookkeeping and is more likely to be capturable)
    events = [torch.cuda.Event(enable_timing=False) for _ in range(2)]
    # Warm up a few times so all lazy initialization finishes before capture
    for _ in range(3):
        with torch.no_grad():
            _ = model(d_input)
    torch.cuda.synchronize()
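    # Baseline sketch (an addition for contrast, not part of the original
    # demo): outside graph capture, recording an event and waiting on it
    # behaves as expected.
    baseline_event = torch.cuda.Event()
    with torch.no_grad():
        _ = model(d_input)
    baseline_event.record()
    baseline_event.synchronize()  # returns once the forward pass above finishes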
    upload_queue = queue.Queue()

    def upload_worker():
        # Block on each event handed over from the main thread
        while True:
            event = upload_queue.get()
            event.synchronize()
            print("event done")

    upload_thread = threading.Thread(target=upload_worker, daemon=True)
    upload_thread.start()

    # The events have not been recorded yet, so synchronize() in the worker
    # returns immediately
    for event in events:
        upload_queue.put(event)
    # Capture the CUDA graph
    stream = torch.cuda.Stream()
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.stream(stream):
        with torch.no_grad():
            with torch.cuda.graph(graph):
                # The model has two submodules (fc1 and fc2); run and capture
                # them layer by layer
                x = d_input
                for i, layer in enumerate(model.children()):
                    x = layer(x)
                    if i == 0:
                        # forward() applies ReLU after fc1; keep the captured
                        # computation consistent with it
                        x = F.relu(x)
                    # Record an event after each layer (note: timing-disabled
                    # events are used here)
                    events[i].record()
                # Write the final result into d_output
                d_output.copy_(x)
    torch.cuda.synchronize(stream.device_index)
    print("Captured graph output shape:", d_output.shape)
    # Update the input data (shape and memory address stay unchanged)
    new_data = torch.ones_like(d_input)
    d_input.copy_(new_data)
    print("Start replay")

    # NOTE: these events CANNOT be reused in replay
    for event in events:
        upload_queue.put(event)
    # Before the replay you can still hand events to another thread via the
    # queue, but note that their record() already happened inside the captured
    # graph. If you want to re-record an event after the replay, call
    # event.record() after the replay (outside capture).
    # Here we simply replay the graph:
    with torch.cuda.stream(stream):
        graph.replay()
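        # Workaround sketch (an addition illustrating the comment above, not
        # part of the original demo): record a fresh event on the stream after
        # the replay, outside capture; this event can be waited on normally.
        post_replay_event = torch.cuda.Event()
        post_replay_event.record()
        upload_queue.put(post_replay_event)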
    torch.cuda.synchronize()
    print("After replay, output shape:", d_output.shape)
    print("Output sample:", d_output[0])