llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-clangir Author: adams381 <details> <summary>Changes</summary> Implement variadic thunk emission via musttail and null-check pointer returns in covariant thunk adjustment, matching classic codegen behavior. Adds musttail UnitAttr to cir.call/cir.try_call with lowering to LLVM::MustTail. Made with [Cursor](https://cursor.com) --- Full diff: https://github.com/llvm/llvm-project/pull/191255.diff 6 Files Affected: - (modified) clang/include/clang/CIR/Dialect/IR/CIRDialect.td (+1) - (modified) clang/include/clang/CIR/Dialect/IR/CIROps.td (+1) - (modified) clang/lib/CIR/CodeGen/CIRGenVTables.cpp (+55-14) - (modified) clang/lib/CIR/Dialect/IR/CIRDialect.cpp (+8) - (modified) clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp (+4-1) - (modified) clang/test/CIR/CodeGen/thunks.cpp (+63) ``````````diff diff --git a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td index f1f94c868e5b0..199bc0a6f5670 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td +++ b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td @@ -75,6 +75,7 @@ def CIR_Dialect : Dialect { static llvm::StringRef getDefaultFuncAttrsAttrName() { return "default_func_attrs"; } static llvm::StringRef getResAttrsAttrName() { return "res_attrs"; } static llvm::StringRef getArgAttrsAttrName() { return "arg_attrs"; } + static llvm::StringRef getMustTailAttrName() { return "musttail"; } static llvm::StringRef getAMDGPUCodeObjectVersionAttrName() { return "cir.amdhsa_code_object_version"; } static llvm::StringRef getAMDGPUPrintfKindAttrName() { return "cir.amdgpu_printf_kind"; } diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 93ca0172b2a7f..f12e7662952f7 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -3920,6 +3920,7 @@ class CIR_CallOpBase<string mnemonic, list<Trait> extra_traits = []> dag 
commonArgs = (ins OptionalAttr<FlatSymbolRefAttr>:$callee, Variadic<CIR_AnyType>:$args, UnitAttr:$nothrow, + UnitAttr:$musttail, DefaultValuedAttr<CIR_SideEffect, "SideEffect::All">:$side_effect, OptionalAttr<DictArrayAttr>:$arg_attrs, OptionalAttr<DictArrayAttr>:$res_attrs diff --git a/clang/lib/CIR/CodeGen/CIRGenVTables.cpp b/clang/lib/CIR/CodeGen/CIRGenVTables.cpp index 56839ca03dbb1..936e9b6487514 100644 --- a/clang/lib/CIR/CodeGen/CIRGenVTables.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenVTables.cpp @@ -559,26 +559,41 @@ uint64_t CIRGenVTables::getSecondaryVirtualPointerIndex(const CXXRecordDecl *rd, static RValue performReturnAdjustment(CIRGenFunction &cgf, QualType resultType, RValue rv, const ThunkInfo &thunk) { - // Emit the return adjustment. + // Emit the return adjustment. For non-reference pointer returns, match + // classic codegen: skip the adjustment when the returned pointer is null. bool nullCheckValue = !resultType->isReferenceType(); - mlir::Value returnValue = rv.getValue(); - if (nullCheckValue) - cgf.cgm.errorNYI( - "return adjustment with null check for non-reference types"); - const CXXRecordDecl *classDecl = resultType->getPointeeType()->getAsCXXRecordDecl(); CharUnits classAlign = cgf.cgm.getClassPointerAlignment(classDecl); mlir::Type pointeeType = cgf.convertTypeForMem(resultType->getPointeeType()); - returnValue = cgf.cgm.getCXXABI().performReturnAdjustment( - cgf, Address(returnValue, pointeeType, classAlign), classDecl, - thunk.Return); + CIRGenBuilderTy &builder = cgf.getBuilder(); + mlir::Location loc = returnValue.getLoc(); + + if (!nullCheckValue) { + returnValue = cgf.cgm.getCXXABI().performReturnAdjustment( + cgf, Address(returnValue, pointeeType, classAlign), classDecl, + thunk.Return); + return RValue::get(returnValue); + } - if (nullCheckValue) - cgf.cgm.errorNYI( - "return adjustment with null check for non-reference types"); + mlir::Value isNotNull = builder.createPtrIsNotNull(returnValue); + returnValue = + 
cir::TernaryOp::create( + builder, loc, isNotNull, + [&](mlir::OpBuilder &, mlir::Location) { + mlir::Value adjusted = cgf.cgm.getCXXABI().performReturnAdjustment( + cgf, Address(returnValue, pointeeType, classAlign), classDecl, + thunk.Return); + builder.createYield(loc, adjusted); + }, + [&](mlir::OpBuilder &, mlir::Location) { + mlir::Value nullVal = + builder.getNullPtr(returnValue.getType(), loc).getResult(); + builder.createYield(loc, nullVal); + }) + .getResult(); return RValue::get(returnValue); } @@ -743,8 +758,34 @@ void CIRGenFunction::emitCallAndReturnForThunk(cir::FuncOp callee, void CIRGenFunction::emitMustTailThunk(GlobalDecl gd, mlir::Value adjustedThisPtr, cir::FuncOp callee) { - assert(!cir::MissingFeatures::opCallMustTail()); - cgm.errorNYI("musttail thunk"); + // Forward all function arguments, replacing 'this' with the adjusted pointer. + // The call is marked musttail so varargs are forwarded correctly. + auto thunkFn = cast<cir::FuncOp>(curFn); + mlir::Block &entryBlock = thunkFn.getBody().front(); + SmallVector<mlir::Value> args; + for (mlir::BlockArgument arg : entryBlock.getArguments()) + args.push_back(arg); + + // Replace the 'this' argument (first arg) with the adjusted pointer. 
+ assert(!args.empty() && "thunk must have at least 'this' argument"); + if (adjustedThisPtr.getType() != args[0].getType()) + adjustedThisPtr = builder.createBitcast(adjustedThisPtr, args[0].getType()); + args[0] = adjustedThisPtr; + + mlir::Location loc = thunkFn.getLoc(); + cir::FuncType calleeTy = callee.getFunctionType(); + mlir::Type retTy = calleeTy.getReturnType(); + + cir::CallOp call = builder.createCallOp(loc, callee, args); + call->setAttr(cir::CIRDialect::getMustTailAttrName(), + mlir::UnitAttr::get(builder.getContext())); + + if (isa<cir::VoidType>(retTy)) + cir::ReturnOp::create(builder, loc); + else + cir::ReturnOp::create(builder, loc, call->getResult(0)); + + finishThunk(); } void CIRGenFunction::generateThunk(cir::FuncOp fn, diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index e99faf840c15a..305d7c5389d8c 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -918,6 +918,10 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser, return ::mlir::failure(); } + if (parser.parseOptionalKeyword("musttail").succeeded()) + result.addAttribute(CIRDialect::getMustTailAttrName(), + mlir::UnitAttr::get(parser.getContext())); + if (parser.parseOptionalKeyword("nothrow").succeeded()) result.addAttribute(CIRDialect::getNoThrowAttrName(), mlir::UnitAttr::get(parser.getContext())); @@ -1020,6 +1024,9 @@ printCallCommon(mlir::Operation *op, mlir::FlatSymbolRefAttr calleeSym, printer << tryCall.getUnwindDest(); } + if (op->hasAttr(CIRDialect::getMustTailAttrName())) + printer << " musttail"; + if (isNothrow) printer << " nothrow"; @@ -1031,6 +1038,7 @@ printCallCommon(mlir::Operation *op, mlir::FlatSymbolRefAttr calleeSym, llvm::SmallVector<::llvm::StringRef> elidedAttrs = { CIRDialect::getCalleeAttrName(), + CIRDialect::getMustTailAttrName(), CIRDialect::getNoThrowAttrName(), CIRDialect::getSideEffectAttrName(), 
CIRDialect::getOperandSegmentSizesAttrName(), diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 2117dd5903ec4..6f5b78dd70c3a 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -1642,7 +1642,8 @@ static void lowerCallAttributes(cir::CIRCallOpInterface op, attr.getName() == CIRDialect::getSideEffectAttrName() || attr.getName() == CIRDialect::getNoThrowAttrName() || attr.getName() == CIRDialect::getNoUnwindAttrName() || - attr.getName() == CIRDialect::getNoReturnAttrName()) + attr.getName() == CIRDialect::getNoReturnAttrName() || + attr.getName() == CIRDialect::getMustTailAttrName()) continue; assert(!cir::MissingFeatures::opFuncExtraAttrs()); @@ -1743,6 +1744,8 @@ rewriteCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands, newOp.setNoUnwind(noUnwind); newOp.setWillReturn(willReturn); newOp.setNoreturn(noReturn); + if (op->hasAttr(CIRDialect::getMustTailAttrName())) + newOp.setTailCallKind(mlir::LLVM::TailCallKind::MustTail); } return mlir::success(); diff --git a/clang/test/CIR/CodeGen/thunks.cpp b/clang/test/CIR/CodeGen/thunks.cpp index 15c4810738420..b36e8a9805516 100644 --- a/clang/test/CIR/CodeGen/thunks.cpp +++ b/clang/test/CIR/CodeGen/thunks.cpp @@ -91,6 +91,36 @@ void C::g(int x) {} } // namespace Test4 +namespace CovariantReturn { +// Covariant return with virtual inheritance: return-adjusting thunks use a +// null check for pointer returns (classic PerformReturnAdjustment). +struct A { + virtual A *f(); +}; +struct B : virtual A { + virtual A *f(); +}; +struct C : B { + virtual C *f(); +}; +C *C::f() { return 0; } +} // namespace CovariantReturn + +namespace VarargThunk { +// Variadic this-adjusting thunk. Both classic codegen and CIR adjust 'this' +// and forward the variadic arguments to the target with a musttail call.
+struct A { + virtual void f(int x, ...); +}; +struct B { + virtual void f(int x, ...); +}; +struct C : A, B { + void f(int x, ...) override; +}; +void C::f(int x, ...) {} +} // namespace VarargThunk + // In CIR, all globals are emitted before functions. // Test1 vtable: C's vtable references the thunk for B's entry. @@ -183,6 +213,23 @@ void C::g(int x) {} // CIR: cir.call @_ZN5Test41C1gEi(%[[T4_RESULT]], %[[T4_ARG]]) // CIR: cir.return +// --- CovariantReturn: return adjustment with null check on pointer return --- + +// CIR-LABEL: cir.func {{.*}} @_ZTch0_v0_n32_N15CovariantReturn1C1fEv +// CIR: cir.call @_ZN15CovariantReturn1C1fEv +// CIR: cir.ternary + +// --- VarargThunk: variadic this-adjusting thunk --- + +// CIR: cir.func {{.*}} @_ZThn8_N11VarargThunk1C1fEiz(%arg0: !cir.ptr< +// CIR: %[[VT_THIS:.*]] = cir.load +// CIR: %[[VT_CAST:.*]] = cir.cast bitcast %[[VT_THIS]] : !cir.ptr<{{.*}}> -> !cir.ptr<!u8i> +// CIR: %[[VT_OFFSET:.*]] = cir.const #cir.int<-8> : !s64i +// CIR: %[[VT_ADJUSTED:.*]] = cir.ptr_stride %[[VT_CAST]], %[[VT_OFFSET]] : (!cir.ptr<!u8i>, !s64i) -> !cir.ptr<!u8i> +// CIR: %[[VT_RESULT:.*]] = cir.cast bitcast %[[VT_ADJUSTED]] : !cir.ptr<!u8i> -> !cir.ptr< +// CIR: cir.call @_ZN11VarargThunk1C1fEiz(%[[VT_RESULT]], %arg1) musttail +// CIR: cir.return + // --- LLVM checks --- // LLVM: @_ZTVN5Test11CE = global { [3 x ptr], [3 x ptr] } { @@ -231,6 +278,14 @@ void C::g(int x) {} // LLVM: %[[L4_ARG:.*]] = load i32, ptr // LLVM: call void @_ZN5Test41C1gEi(ptr{{.*}} %[[L4_ADJ]], i32{{.*}} %[[L4_ARG]]) +// LLVM-LABEL: define {{.*}} @_ZTch0_v0_n32_N15CovariantReturn1C1fEv +// LLVM: call {{.*}} @_ZN15CovariantReturn1C1fEv +// LLVM: phi ptr + +// LLVM-LABEL: define {{.*}} void @_ZThn8_N11VarargThunk1C1fEiz(ptr{{.*}}, i32{{.*}}, ...) +// LLVM: getelementptr i8, ptr {{.*}}, i64 -8 +// LLVM: musttail call void (ptr, i32, ...) @_ZN11VarargThunk1C1fEiz(ptr{{.*}}, i32{{.*}}, ...) 
+ // --- OGCG checks --- // OGCG: @_ZTVN5Test11CE = unnamed_addr constant { [3 x ptr], [3 x ptr] } { @@ -278,3 +333,11 @@ void C::g(int x) {} // OGCG: %[[O4_ADJ:.*]] = getelementptr inbounds i8, ptr %[[O4_THIS]], i64 -8 // OGCG: %[[O4_ARG:.*]] = load i32, ptr // OGCG: {{.*}}call void @_ZN5Test41C1gEi(ptr{{.*}} %[[O4_ADJ]], i32{{.*}} %[[O4_ARG]]) + +// OGCG-LABEL: define {{.*}} @_ZTch0_v0_n32_N15CovariantReturn1C1fEv +// OGCG: {{.*}}call {{.*}} @_ZN15CovariantReturn1C1fEv +// OGCG: phi ptr + +// OGCG-LABEL: define {{.*}} void @_ZThn8_N11VarargThunk1C1fEiz(ptr{{.*}}, i32{{.*}}, ...) +// OGCG: getelementptr inbounds i8, ptr {{.*}}, i64 -8 +// OGCG: musttail call void (ptr, i32, ...) @_ZN11VarargThunk1C1fEiz(ptr{{.*}}, i32{{.*}}, ...) `````````` </details> https://github.com/llvm/llvm-project/pull/191255 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
