Author: Jan Leyonberg Date: 2026-04-01T12:50:09-04:00 New Revision: 91adaeceb162357a33e2ea6155cb13a4198a981a
URL: https://github.com/llvm/llvm-project/commit/91adaeceb162357a33e2ea6155cb13a4198a981a DIFF: https://github.com/llvm/llvm-project/commit/91adaeceb162357a33e2ea6155cb13a4198a981a.diff LOG: [CIR][MLIR][OpenMP] Enable the MarkDeclareTarget pass for ClangIR (#189420) This patch enables the MarkDeclareTarget pass for CIR by adding the pass to the CIR-to-LLVM lowering pipeline and attaching the declare target interface to cir::FuncOp. The MarkDeclareTarget pass is also generalized to work on FunctionOpInterface instead of func::FuncOp, since it needs to be able to handle cir::FuncOp as well. Co-authored-by: Claude Opus 4.6 <[email protected]> Added: clang/include/clang/CIR/Dialect/OpenMP/RegisterOpenMPExtensions.h clang/lib/CIR/Dialect/OpenMP/CMakeLists.txt clang/lib/CIR/Dialect/OpenMP/RegisterOpenMPExtensions.cpp clang/test/CIR/Transforms/omp-mark-declare-target.cir Modified: clang/lib/CIR/CodeGen/CIRGenerator.cpp clang/lib/CIR/CodeGen/CMakeLists.txt clang/lib/CIR/Dialect/CMakeLists.txt clang/lib/CIR/Lowering/DirectToLLVM/CMakeLists.txt clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp clang/tools/cir-opt/CMakeLists.txt clang/tools/cir-opt/cir-opt.cpp clang/tools/cir-translate/CMakeLists.txt clang/tools/cir-translate/cir-translate.cpp mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt mlir/lib/Dialect/OpenMP/Transforms/MarkDeclareTarget.cpp Removed: ################################################################################ diff --git a/clang/include/clang/CIR/Dialect/OpenMP/RegisterOpenMPExtensions.h b/clang/include/clang/CIR/Dialect/OpenMP/RegisterOpenMPExtensions.h new file mode 100644 index 0000000000000..2247025a4433b --- /dev/null +++ b/clang/include/clang/CIR/Dialect/OpenMP/RegisterOpenMPExtensions.h @@ -0,0 +1,22 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef CLANG_CIR_DIALECT_OPENMP_REGISTEROPENMPEXTENSIONS_H +#define CLANG_CIR_DIALECT_OPENMP_REGISTEROPENMPEXTENSIONS_H + +namespace mlir { +class DialectRegistry; +} // namespace mlir + +namespace cir::omp { + +void registerOpenMPExtensions(mlir::DialectRegistry &registry); + +} // namespace cir::omp + +#endif // CLANG_CIR_DIALECT_OPENMP_REGISTEROPENMPEXTENSIONS_H diff --git a/clang/lib/CIR/CodeGen/CIRGenerator.cpp b/clang/lib/CIR/CodeGen/CIRGenerator.cpp index 80f85169b73cb..31d40c21ef6e1 100644 --- a/clang/lib/CIR/CodeGen/CIRGenerator.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenerator.cpp @@ -21,6 +21,7 @@ #include "clang/CIR/CIRGenerator.h" #include "clang/CIR/Dialect/IR/CIRDialect.h" #include "clang/CIR/Dialect/OpenACC/RegisterOpenACCExtensions.h" +#include "clang/CIR/Dialect/OpenMP/RegisterOpenMPExtensions.h" #include "llvm/IR/DataLayout.h" using namespace cir; @@ -56,9 +57,10 @@ void CIRGenerator::Initialize(ASTContext &astContext) { mlirContext->getOrLoadDialect<mlir::acc::OpenACCDialect>(); mlirContext->getOrLoadDialect<mlir::omp::OpenMPDialect>(); - // Register extensions to integrate CIR types with OpenACC. + // Register extensions to integrate CIR types with OpenACC and OpenMP.
mlir::DialectRegistry registry; cir::acc::registerOpenACCExtensions(registry); + cir::omp::registerOpenMPExtensions(registry); mlirContext->appendDialectRegistry(registry); cgm = std::make_unique<clang::CIRGen::CIRGenModule>( diff --git a/clang/lib/CIR/CodeGen/CMakeLists.txt b/clang/lib/CIR/CodeGen/CMakeLists.txt index 0afff8ad7f555..3a2616fcd2526 100644 --- a/clang/lib/CIR/CodeGen/CMakeLists.txt +++ b/clang/lib/CIR/CodeGen/CMakeLists.txt @@ -65,6 +65,7 @@ add_clang_library(clangCIR clangLex ${dialect_libs} CIROpenACCSupport + CIROpenMPSupport MLIRCIR MLIRCIRInterfaces MLIRTargetLLVMIRImport diff --git a/clang/lib/CIR/Dialect/CMakeLists.txt b/clang/lib/CIR/Dialect/CMakeLists.txt index c825a61b2779b..e05c9becebbad 100644 --- a/clang/lib/CIR/Dialect/CMakeLists.txt +++ b/clang/lib/CIR/Dialect/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(IR) add_subdirectory(OpenACC) +add_subdirectory(OpenMP) add_subdirectory(Transforms) diff --git a/clang/lib/CIR/Dialect/OpenMP/CMakeLists.txt b/clang/lib/CIR/Dialect/OpenMP/CMakeLists.txt new file mode 100644 index 0000000000000..f6f4017f0f1f6 --- /dev/null +++ b/clang/lib/CIR/Dialect/OpenMP/CMakeLists.txt @@ -0,0 +1,11 @@ +add_clang_library(CIROpenMPSupport + RegisterOpenMPExtensions.cpp + + DEPENDS + MLIRCIROpsIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRCIR + MLIROpenMPDialect + ) diff --git a/clang/lib/CIR/Dialect/OpenMP/RegisterOpenMPExtensions.cpp b/clang/lib/CIR/Dialect/OpenMP/RegisterOpenMPExtensions.cpp new file mode 100644 index 0000000000000..b5129202e66c4 --- /dev/null +++ b/clang/lib/CIR/Dialect/OpenMP/RegisterOpenMPExtensions.cpp @@ -0,0 +1,26 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Registration for OpenMP extensions as applied to CIR dialect. +// +//===----------------------------------------------------------------------===// + +#include "clang/CIR/Dialect/OpenMP/RegisterOpenMPExtensions.h" +#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h" +#include "clang/CIR/Dialect/IR/CIRDialect.h" + +namespace cir::omp { + +void registerOpenMPExtensions(mlir::DialectRegistry &registry) { + registry.addExtension(+[](mlir::MLIRContext *ctx, cir::CIRDialect *dialect) { + cir::FuncOp::attachInterface< + mlir::omp::DeclareTargetDefaultModel<cir::FuncOp>>(*ctx); + }); +} + +} // namespace cir::omp diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/CMakeLists.txt b/clang/lib/CIR/Lowering/DirectToLLVM/CMakeLists.txt index c7467fe40ba30..021397fee992b 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/CMakeLists.txt +++ b/clang/lib/CIR/Lowering/DirectToLLVM/CMakeLists.txt @@ -22,6 +22,7 @@ add_clang_library(clangCIRLoweringDirectToLLVM MLIRBuiltinToLLVMIRTranslation MLIRLLVMToLLVMIRTranslation MLIROpenMPToLLVMIRTranslation + MLIROpenMPTransforms MLIRIR ) diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index ba89fbe3091bc..149cd90b813ec 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -21,6 +21,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/OpenMP/Transforms/Passes.h" #include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinDialect.h" @@ -4936,6 +4937,7 @@ std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() { void populateCIRToLLVMPasses(mlir::OpPassManager &pm) {
mlir::populateCIRPreLoweringPasses(pm); + pm.addPass(mlir::omp::createMarkDeclareTargetPass()); pm.addPass(createConvertCIRToLLVMPass()); } diff --git a/clang/test/CIR/Transforms/omp-mark-declare-target.cir b/clang/test/CIR/Transforms/omp-mark-declare-target.cir new file mode 100644 index 0000000000000..914589ec65bcf --- /dev/null +++ b/clang/test/CIR/Transforms/omp-mark-declare-target.cir @@ -0,0 +1,53 @@ +// RUN: cir-opt --omp-mark-declare-target %s -o - | FileCheck %s + +// Test that the MarkDeclareTarget pass propagates the declare_target +// attribute from explicitly marked functions to functions they call, +// and from omp.target regions to functions called within them. + +!s32i = !cir.int<s, 32> + +module { + // A helper function with no declare_target attribute initially. + // After the pass, it should be marked because @caller calls it. + // CHECK-LABEL: cir.func @helper + // CHECK-SAME: omp.declare_target = #omp.declaretarget<device_type = (host), capture_clause = (to) + cir.func @helper() { + cir.return + } + + // Explicitly marked as declare_target; calls @helper. + // CHECK-LABEL: cir.func @caller + // CHECK-SAME: omp.declare_target = #omp.declaretarget<device_type = (host), capture_clause = (to)> + cir.func @caller() attributes {omp.declare_target = #omp.declaretarget<device_type = (host), capture_clause = (to)>} { + cir.call @helper() : () -> () + cir.return + } + + // Called from within an omp.target region; should be marked as nohost. + // CHECK-LABEL: cir.func @device_helper + // CHECK-SAME: omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to) + cir.func @device_helper() { + cir.return + } + + // Contains an omp.target region that calls @device_helper. + // The function itself should NOT be marked as declare_target. 
+ // CHECK-LABEL: cir.func @target_caller + // CHECK-NOT: omp.declare_target + // CHECK-SAME: { + cir.func @target_caller() { + omp.target { + cir.call @device_helper() : () -> () + omp.terminator + } + cir.return + } + + // Not called by any declare_target function or target region. + // CHECK-LABEL: cir.func @unrelated + // CHECK-NOT: omp.declare_target + // CHECK-SAME: { + cir.func @unrelated() { + cir.return + } +} diff --git a/clang/tools/cir-opt/CMakeLists.txt b/clang/tools/cir-opt/CMakeLists.txt index cae7de6f056a9..4e9553ed8a7e7 100644 --- a/clang/tools/cir-opt/CMakeLists.txt +++ b/clang/tools/cir-opt/CMakeLists.txt @@ -23,6 +23,7 @@ clang_target_link_libraries(cir-opt PRIVATE clangCIR clangCIRLoweringDirectToLLVM + CIROpenMPSupport MLIRCIR MLIRCIRTransforms ) @@ -35,6 +36,7 @@ target_link_libraries(cir-opt MLIRDialect MLIRIR MLIRMemRefDialect + MLIROpenMPTransforms MLIROptLib MLIRParser MLIRPass diff --git a/clang/tools/cir-opt/cir-opt.cpp b/clang/tools/cir-opt/cir-opt.cpp index a24bf5d581af9..05e3b9ec7e964 100644 --- a/clang/tools/cir-opt/cir-opt.cpp +++ b/clang/tools/cir-opt/cir-opt.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/Dialect/OpenMP/Transforms/Passes.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassOptions.h" @@ -25,6 +26,7 @@ #include "mlir/Tools/mlir-opt/MlirOptMain.h" #include "mlir/Transforms/Passes.h" #include "clang/CIR/Dialect/IR/CIRDialect.h" +#include "clang/CIR/Dialect/OpenMP/RegisterOpenMPExtensions.h" #include "clang/CIR/Dialect/Passes.h" #include "clang/CIR/Passes.h" @@ -37,6 +39,7 @@ int main(int argc, char **argv) { registry.insert<mlir::BuiltinDialect, cir::CIRDialect, mlir::memref::MemRefDialect, mlir::LLVM::LLVMDialect, mlir::DLTIDialect, mlir::omp::OpenMPDialect>(); + cir::omp::registerOpenMPExtensions(registry); ::mlir::registerPass([]() -> 
std::unique_ptr<::mlir::Pass> { return mlir::createCIRCanonicalizePass(); @@ -71,6 +74,7 @@ int main(int argc, char **argv) { return mlir::createCXXABILoweringPass(); }); + mlir::omp::registerOpenMPPasses(); mlir::registerTransformsPasses(); return mlir::asMainReturnCode(MlirOptMain( diff --git a/clang/tools/cir-translate/CMakeLists.txt b/clang/tools/cir-translate/CMakeLists.txt index 21834799ea82f..53e60220b8736 100644 --- a/clang/tools/cir-translate/CMakeLists.txt +++ b/clang/tools/cir-translate/CMakeLists.txt @@ -13,6 +13,7 @@ clang_target_link_libraries(cir-translate PRIVATE clangCIR clangCIRLoweringDirectToLLVM + CIROpenMPSupport MLIRCIR MLIRCIRTransforms ) diff --git a/clang/tools/cir-translate/cir-translate.cpp b/clang/tools/cir-translate/cir-translate.cpp index 2b00d1bd62e4a..997d44dc5a62f 100644 --- a/clang/tools/cir-translate/cir-translate.cpp +++ b/clang/tools/cir-translate/cir-translate.cpp @@ -31,6 +31,7 @@ #include "clang/Basic/DiagnosticOptions.h" #include "clang/Basic/TargetInfo.h" #include "clang/CIR/Dialect/IR/CIRDialect.h" +#include "clang/CIR/Dialect/OpenMP/RegisterOpenMPExtensions.h" #include "clang/CIR/Dialect/Passes.h" #include "clang/CIR/LowerToLLVM.h" #include "clang/CIR/MissingFeatures.h" @@ -169,6 +170,7 @@ void registerToLLVMTranslation() { registry.insert<mlir::DLTIDialect, mlir::func::FuncDialect>(); mlir::registerAllToLLVMIRTranslations(registry); cir::direct::registerCIRDialectTranslation(registry); + cir::omp::registerOpenMPExtensions(registry); }); } diff --git a/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt index a46924cd9878e..9b11d4b87e8df 100644 --- a/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt @@ -6,8 +6,8 @@ add_mlir_dialect_library(MLIROpenMPTransforms MLIROpenMPPassIncGen LINK_LIBS PUBLIC + MLIRFunctionInterfaces MLIRIR - MLIRFuncDialect MLIRLLVMDialect MLIROpenMPDialect MLIRPass diff --git 
a/mlir/lib/Dialect/OpenMP/Transforms/MarkDeclareTarget.cpp b/mlir/lib/Dialect/OpenMP/Transforms/MarkDeclareTarget.cpp index 18a36f73edaf2..e3357e03d9c16 100644 --- a/mlir/lib/Dialect/OpenMP/Transforms/MarkDeclareTarget.cpp +++ b/mlir/lib/Dialect/OpenMP/Transforms/MarkDeclareTarget.cpp @@ -10,10 +10,10 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/IR/Operation.h" #include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/SmallPtrSet.h" @@ -42,28 +42,30 @@ class MarkDeclareTargetPass void processSymbolRef(SymbolRefAttr symRef, ParentInfo parentInfo, llvm::SmallPtrSet<Operation *, 16> visited) { - if (auto currFOp = getOperation().lookupSymbol<func::FuncOp>(symRef)) { - auto current = - llvm::dyn_cast<omp::DeclareTargetInterface>(currFOp.getOperation()); - - if (current.isDeclareTarget()) { - auto currentDt = current.getDeclareTargetDeviceType(); - - // Found the same function twice, with different device_types, - // mark as Any as it belongs to both - if (currentDt != parentInfo.devTy && - currentDt != omp::DeclareTargetDeviceType::any) { - current.setDeclareTarget(omp::DeclareTargetDeviceType::any, - current.getDeclareTargetCaptureClause(), - current.getDeclareTargetAutomap()); - } - } else { - current.setDeclareTarget(parentInfo.devTy, parentInfo.capClause, - parentInfo.automap); - } + Operation *symOp = getOperation().lookupSymbol(symRef); + if (!symOp) + return; + auto current = llvm::dyn_cast<omp::DeclareTargetInterface>(symOp); + if (!current) + return; + + if (current.isDeclareTarget()) { + auto currentDt = current.getDeclareTargetDeviceType(); - markNestedFuncs(parentInfo, currFOp, visited); + // Found the same function twice, with different device_types, + // mark as Any as it belongs to both + if (currentDt !=
parentInfo.devTy && + currentDt != omp::DeclareTargetDeviceType::any) { + current.setDeclareTarget(omp::DeclareTargetDeviceType::any, + current.getDeclareTargetCaptureClause(), + current.getDeclareTargetAutomap()); + } + } else { + current.setDeclareTarget(parentInfo.devTy, parentInfo.capClause, + parentInfo.automap); } + + markNestedFuncs(parentInfo, symOp, visited); } void processReductionRefs(std::optional<mlir::ArrayAttr> symRefs, @@ -138,16 +140,16 @@ class MarkDeclareTargetPass // as implicitly declare target if they are called from within an explicitly // marked declare target function or a target region (TargetOp) void runOnOperation() override { - for (auto functionOp : getOperation().getOps<func::FuncOp>()) { - auto declareTargetOp = llvm::dyn_cast<omp::DeclareTargetInterface>( - functionOp.getOperation()); - if (declareTargetOp.isDeclareTarget()) { - llvm::SmallPtrSet<Operation *, 16> visited; - ParentInfo parentInfo{declareTargetOp.getDeclareTargetDeviceType(), - declareTargetOp.getDeclareTargetCaptureClause(), - declareTargetOp.getDeclareTargetAutomap()}; - markNestedFuncs(parentInfo, functionOp, visited); - } + for (auto funcOp : getOperation().getOps<FunctionOpInterface>()) { + auto declareTargetOp = + llvm::dyn_cast<omp::DeclareTargetInterface>(funcOp.getOperation()); + if (!declareTargetOp || !declareTargetOp.isDeclareTarget()) + continue; + llvm::SmallPtrSet<Operation *, 16> visited; + ParentInfo parentInfo{declareTargetOp.getDeclareTargetDeviceType(), + declareTargetOp.getDeclareTargetCaptureClause(), + declareTargetOp.getDeclareTargetAutomap()}; + markNestedFuncs(parentInfo, funcOp, visited); } // TODO: Extend to work with reverse-offloading, this shouldn't _______________________________________________ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
