Created February 8, 2025 01:01
CUDA events do not work inside a captured CUDA graph.
import torch
import torch.nn as nn
import torch.nn.functional as F
import queue
import threading

# A simple two-layer fully connected model
class SimpleModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)   # first fully connected layer
        self.fc2 = nn.Linear(hidden_dim, output_dim)  # second fully connected layer

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x

if __name__ == "__main__":
    # Parameters
    input_dim = 784   # e.g. MNIST flattened: 28x28 = 784
    hidden_dim = 128
    output_dim = 10   # number of classes, 0-9
    batch_size = 32

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SimpleModel(input_dim, hidden_dim, output_dim).to(device)
    model.eval()

    # Fixed input buffer
    d_input = torch.randn(batch_size, input_dim, device=device)
    # Fixed output buffer
    d_output = torch.empty(batch_size, output_dim, device=device)

    # Create CUDA events. (Events created with enable_timing=False do no internal
    # timing bookkeeping and are more likely to be capturable; timing is enabled
    # here, which is part of the failure this gist demonstrates.)
    events = [torch.cuda.Event(enable_timing=True) for _ in range(2)]

    # Warm up a few times so all lazy initialization is finished before capture
    for _ in range(3):
        with torch.no_grad():
            _ = model(d_input)
    torch.cuda.synchronize()

    upload_queue = queue.Queue()

    def upload_worker():
        while True:
            event = upload_queue.get()
            event.synchronize()
            print("event done")

    upload_thread = threading.Thread(target=upload_worker, daemon=True)
    upload_thread.start()

    for event in events:
        upload_queue.put(event)

    # Capture the CUDA graph
    stream = torch.cuda.Stream()
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.stream(stream):
        with torch.no_grad():
            with torch.cuda.graph(graph):
                # The model has two submodules (fc1 and fc2); run them layer by layer
                # (note: this skips the F.relu used in forward())
                x = d_input
                for i, layer in enumerate(model.children()):
                    x = layer(x)
                    # Record an event after each layer (recorded inside the capture)
                    events[i].record()
                # Write the final result into the fixed output buffer
                d_output.copy_(x)

    torch.cuda.synchronize(stream.device_index)
    print("Captured graph output shape:", d_output.shape)

    # Update the input data (size and memory address stay the same)
    new_data = torch.ones_like(d_input)
    d_input.copy_(new_data)

    print("Start replay")
    # NOTE: these events CANNOT be reused during replay
    for event in events:
        upload_queue.put(event)

    # If the events need to be handed to another thread before replay, they can be
    # queued here, but note that their record() calls were captured inside the graph.
    # To re-record an event after replay, call event.record() outside the capture
    # (see the sketch after this listing).
    # Replay the graph:
    with torch.cuda.stream(stream):
        graph.replay()
    torch.cuda.synchronize()
    print("After replay, output shape:", d_output.shape)
    print("Output sample:", d_output[0])