Skip to content

Instantly share code, notes, and snippets.

@snarkmaster
Created August 10, 2025 07:53
Show Gist options
  • Select an option

  • Save snarkmaster/f85ad34dcd0b0f3a810685fc5379948e to your computer and use it in GitHub Desktop.

Select an option

Save snarkmaster/f85ad34dcd0b0f3a810685fc5379948e to your computer and use it in GitHub Desktop.
Quick-and-dirty PoC of `await_suspend_destroy` that returns a handle
diff --git a/llvm-project/clang/lib/CodeGen/CGCoroutine.cpp b/third-party/tp2/llvm-fb/19/llvm-project/clang/lib/CodeGen/CGCoroutine.cpp
--- a/llvm-project/clang/lib/CodeGen/CGCoroutine.cpp
+++ b/llvm-project/clang/lib/CodeGen/CGCoroutine.cpp
@@ -175,15 +175,16 @@
// Check if this suspend should be calling `await_suspend_destroy`
static bool useCoroAwaitSuspendDestroy(const CoroutineSuspendExpr &S) {
- // This can only be an `await_suspend_destroy` suspend expression if it
- // returns void -- `buildCoawaitCalls` in `SemaCoroutine.cpp` asserts this.
- // Moreover, when `await_suspend` returns a handle, the outermost method call
- // is `.address()` -- making it harder to get the actual class or method.
+/* XXX -- don't need this any more, if the new logic is right
+ // This can be an `await_suspend_destroy` suspend expression if it
+ // returns void or handle. Bool returns are not supported.
if (S.getSuspendReturnType() !=
- CoroutineSuspendExpr::SuspendReturnType::SuspendVoid) {
+ CoroutineSuspendExpr::SuspendReturnType::SuspendVoid &&
+ S.getSuspendReturnType() !=
+ CoroutineSuspendExpr::SuspendReturnType::SuspendHandle) {
return false;
}
-
+*/
// `CGCoroutine.cpp` & `SemaCoroutine.cpp` must agree on whether this suspend
// expression uses `[[clang::coro_await_suspend_destroy]]`.
//
@@ -198,7 +199,26 @@
StringRef SuspendMethodName; // Primary
CXXRecordDecl *AwaiterClass = nullptr; // Debug-only, best-effort
if (auto *SuspendCall = dyn_cast<CallExpr>(S.getSuspendExpr()->IgnoreImplicit())) {
- if (auto *SuspendMember = dyn_cast<MemberExpr>(SuspendCall->getCallee())) {
+ auto *SuspendMember = dyn_cast<MemberExpr>(SuspendCall->getCallee());
+
+ // For handle return types, we need to peel back the .address() call
+ // to get to the actual await_suspend method call
+ if (SuspendMember &&
+ S.getSuspendReturnType() == CoroutineSuspendExpr::SuspendReturnType::SuspendHandle) {
+ if (auto *SuspendMethod = dyn_cast<CXXMethodDecl>(SuspendMember->getMemberDecl())) {
+ if (SuspendMethod->getName() == "address") {
+ // This is the .address() call, peel back to the actual suspend call
+ if (auto *InnerCall = dyn_cast<CallExpr>(SuspendMember->getBase()->IgnoreImplicit())) {
+ if (auto *InnerMember = dyn_cast<MemberExpr>(InnerCall->getCallee())) {
+ SuspendMember = InnerMember;
+ SuspendCall = InnerCall;
+ }
+ }
+ }
+ }
+ }
+
+ if (SuspendMember) {
if (auto *BaseExpr = SuspendMember->getBase()) {
// `IgnoreImplicitAsWritten` is critical since `await_suspend...` can be
// invoked on the base of the actual awaiter, and the base need not have
@@ -276,22 +296,60 @@
}
// The simplified `await_suspend_destroy` path avoids suspend intrinsics.
+//
+// If a coro has only `await_suspend_destroy` and trivial (`suspend_never`)
+// awaiters, then subsequent passes are able to allocate its frame on-stack.
+//
+// As of 2025, there is still an optimization gap between a realistic
+// short-circuiting coro, and the equivalent plain function. For a
+// guesstimate, expect 4-5ns per call on x86. One idea for improvement is to
+// also elide trivial suspends like `std::suspend_never`, in order to hit the
+// `HasCoroSuspend` path in `CoroEarly.cpp`.
+//
+// `SuspendReturnType` selects what happens after the wrapper call:
+//  - `SuspendVoid`: branch through the active cleanups to `Coro.CleanupJD`,
+//    destroying this coroutine.
+//  - `SuspendHandle`: the wrapper returns the frame address of a coroutine
+//    to transfer to; it is resumed directly, and then this coroutine takes
+//    the same cleanup branch.
static void emitAwaitSuspendDestroy(CodeGenFunction &CGF, CGCoroData &Coro,
llvm::Function *SuspendWrapper,
llvm::Value *Awaiter,
llvm::Value *Frame,
- bool AwaitSuspendCanThrow) {
+ bool AwaitSuspendCanThrow,
+ CoroutineSuspendExpr::SuspendReturnType SuspendReturnType) {
SmallVector<llvm::Value *, 2> DirectCallArgs;
DirectCallArgs.push_back(Awaiter);
DirectCallArgs.push_back(Frame);
+ // Kept so the `SuspendHandle` path can use the wrapper's return value.
+ llvm::CallBase *SuspendRet = nullptr;
if (AwaitSuspendCanThrow) {
- CGF.EmitCallOrInvoke(SuspendWrapper, DirectCallArgs);
+ SuspendRet = CGF.EmitCallOrInvoke(SuspendWrapper, DirectCallArgs);
} else {
- CGF.EmitNounwindRuntimeCall(SuspendWrapper, DirectCallArgs);
+ SuspendRet = CGF.EmitNounwindRuntimeCall(SuspendWrapper, DirectCallArgs);
}
- CGF.EmitBranchThroughCleanup(Coro.CleanupJD);
+ if (SuspendReturnType ==
+ CoroutineSuspendExpr::SuspendReturnType::SuspendHandle) {
+ // `SuspendRet` is the `.address()` of the handle to transfer to. Look up
+ // its resume function and call it directly.
+ auto &Builder = CGF.Builder;
+ auto &Context = Builder.getContext();
+
+ // `llvm.coro.subfn.addr(frame, 0)`: index 0 selects the resume function.
+ auto *ResumeIndex = llvm::ConstantInt::get(llvm::Type::getInt8Ty(Context),
+ 0);
+ auto *SubFnIntrinsic = llvm::Intrinsic::getDeclaration(
+ &CGF.CGM.getModule(), llvm::Intrinsic::coro_subfn_addr);
+ auto *ResumeAddr =
+ Builder.CreateCall(SubFnIntrinsic, {SuspendRet, ResumeIndex});
+
+ // Call the resume function as `void(ptr)` using `fastcc`.
+ llvm::FunctionType *ResumeTy =
+ llvm::FunctionType::get(llvm::Type::getVoidTy(Context),
+ llvm::PointerType::getUnqual(Context), false);
+ auto *ResumeCall = Builder.CreateCall(ResumeTy, ResumeAddr, {SuspendRet});
+ ResumeCall->setCallingConv(llvm::CallingConv::Fast);
+
+ // This is an ordinary call, NOT a tail call: it returns once the resumed
+ // coroutine suspends or finishes (immediately for `std::noop_coroutine()`).
+ // Control must therefore continue into this coroutine's cleanup below.
+ // Emitting `unreachable` here would be UB, and it would also leak this
+ // frame because the cleanups would never run.
+ //
+ // NOTE(review): the transferred-to coroutine now runs *before* this frame
+ // is destroyed, nested on the current stack. Destroy-then-resume ordering
+ // (like symmetric transfer) would need the handle stashed and resumed
+ // after the cleanups. TODO: decide which order the attribute promises.
+ }
+ // Destroy this coroutine: run the active cleanups and jump to
+ // `Coro.CleanupJD`.
+ CGF.EmitBranchThroughCleanup(Coro.CleanupJD);
}
static LValueOrRValue emitSuspendExpression(CodeGenFunction &CGF, CGCoroData &Coro,
@@ -336,7 +394,7 @@
if (useCoroAwaitSuspendDestroy(S)) { // Call `await_suspend_destroy` & cleanup
emitAwaitSuspendDestroy(CGF, Coro, SuspendWrapper, Awaiter, Frame,
- AwaitSuspendCanThrow);
+ AwaitSuspendCanThrow, SuspendReturnType);
} else { // Normal suspend path -- can actually suspend, uses intrinsics
CGF.CurCoro.InSuspendBlock = true;
diff --git a/llvm-project/clang/lib/Sema/SemaCoroutine.cpp b/third-party/tp2/llvm-fb/19/llvm-project/clang/lib/Sema/SemaCoroutine.cpp
--- a/llvm-project/clang/lib/Sema/SemaCoroutine.cpp
+++ b/llvm-project/clang/lib/Sema/SemaCoroutine.cpp
@@ -484,26 +484,8 @@
// type Z.
QualType RetType = AwaitSuspend->getCallReturnType(S.Context);
- auto EmitAwaitSuspendDiag = [&](unsigned int DiagCode) {
- S.Diag(AwaitSuspend->getCalleeDecl()->getLocation(), DiagCode)
- << RetType;
- S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required)
- << AwaitSuspend->getDirectCallee();
- Calls.IsInvalid = true;
- };
-
- // `await_suspend_destroy` must return `void` -- and `CGCoroutine.cpp`
- // critically depends on this in `hasCoroAwaitSuspendDestroyAttr`.
- if (UseAwaitSuspendDestroy) {
- if (RetType->isVoidType()) {
- Calls.Results[ACT::ACT_Suspend] =
- S.MaybeCreateExprWithCleanups(AwaitSuspend);
- } else {
- EmitAwaitSuspendDiag(diag::err_await_suspend_destroy_invalid_return_type);
- }
- // Support for coroutine_handle returning await_suspend.
- } else if (Expr *TailCallSuspend =
- maybeTailCall(S, RetType, AwaitSuspend, Loc)) {
+ // Support for coroutine_handle-returning await_suspend and await_suspend_destroy.
+ if (Expr *TailCallSuspend = maybeTailCall(S, RetType, AwaitSuspend, Loc)) {
// Note that we don't wrap the expression with ExprWithCleanups here
// because that might interfere with tailcall contract (e.g. inserting
// clean up instructions in-between tailcall and return). Instead
@@ -511,13 +493,22 @@
// call.
Calls.Results[ACT::ACT_Suspend] = TailCallSuspend;
} else {
- // non-class prvalues always have cv-unqualified types
- if (RetType->isReferenceType() ||
- (!RetType->isBooleanType() && !RetType->isVoidType())) {
- EmitAwaitSuspendDiag(diag::err_await_suspend_invalid_return_type);
- } else
+ // XXX ditch the ref check, simplify the logic
+ if ((UseAwaitSuspendDestroy && !RetType->isVoidType()) ||
+ (!UseAwaitSuspendDestroy && (RetType->isReferenceType() ||
+ (!RetType->isBooleanType() && !RetType->isVoidType())))) {
+ auto DiagCode = UseAwaitSuspendDestroy ?
+ diag::err_await_suspend_destroy_invalid_return_type :
+ diag::err_await_suspend_invalid_return_type;
+ S.Diag(AwaitSuspend->getCalleeDecl()->getLocation(), DiagCode)
+ << RetType;
+ S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required)
+ << AwaitSuspend->getDirectCallee();
+ Calls.IsInvalid = true;
+ } else {
Calls.Results[ACT::ACT_Suspend] =
S.MaybeCreateExprWithCleanups(AwaitSuspend);
+ }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment