llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-clang-codegen Author: Abhinav Gaba (abhinavgaba) <details> <summary>Changes</summary> These have been pulled out of the codegen PR #<!-- -->153683, to reduce the size of that PR. --- Patch is 25.43 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/155625.diff 5 Files Affected: - (modified) clang/include/clang/AST/OpenMPClause.h (+95) - (modified) clang/include/clang/Basic/OpenMPKinds.h (+8) - (modified) clang/lib/AST/OpenMPClause.cpp (+68) - (modified) clang/lib/Basic/OpenMPKinds.cpp (+5) - (modified) clang/lib/CodeGen/CGOpenMPRuntime.cpp (+358) ``````````diff diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h index 1118d3e062e68..9627e99a306b4 100644 --- a/clang/include/clang/AST/OpenMPClause.h +++ b/clang/include/clang/AST/OpenMPClause.h @@ -5815,6 +5815,12 @@ class OMPClauseMappableExprCommon { ValueDecl *getAssociatedDeclaration() const { return AssociatedDeclaration; } + + bool operator==(const MappableComponent &Other) const { + return AssociatedExpressionNonContiguousPr == + Other.AssociatedExpressionNonContiguousPr && + AssociatedDeclaration == Other.AssociatedDeclaration; + } }; // List of components of an expression. This first one is the whole @@ -5828,6 +5834,95 @@ class OMPClauseMappableExprCommon { using MappableExprComponentLists = SmallVector<MappableExprComponentList, 8>; using MappableExprComponentListsRef = ArrayRef<MappableExprComponentList>; + // Hash function to allow usage as DenseMap keys. + friend llvm::hash_code hash_value(const MappableComponent &MC) { + return llvm::hash_combine(MC.getAssociatedExpression(), + MC.getAssociatedDeclaration(), + MC.isNonContiguous()); + } + +public: + /// Get the type of an element of a ComponentList Expr \p Exp. 
+ /// + /// For something like the following: + /// ```c + /// int *p, **pp; + /// ``` + /// The types for the following Exprs would be: + /// Expr | Type + /// ---------|----------- + /// p | int * + /// *p | int + /// p[0] | int + /// p[0:1] | int + /// pp | int ** + /// pp[0] | int * + /// pp[0:1] | int * + /// Note: this assumes that if \p Exp is an array-section, it is contiguous. + static QualType getComponentExprElementType(const Expr *Exp); + + /// Find the attach pointer expression from a list of mappable expression + /// components. + /// + /// This function traverses the component list to find the first + /// expression that has a pointer type, which represents the attach + /// base pointer expr for the current component-list. + /// + /// For example, given the following: + /// + /// ```c + /// struct S { + /// int a; + /// int b[10]; + /// int c[10][10]; + /// int *p; + /// int **pp; + /// }; + /// S s, *ps, **pps, *(pas[10]), ***ppps; + /// int i; + /// ``` + /// + /// The base-pointers for the following map operands would be: + /// map list-item | attach base-pointer | attach base-pointer + /// | for directives except | target_update (if + /// | target_update | different) + /// ----------------|-----------------------|--------------------- + /// s | N/A | + /// s.a | N/A | + /// s.p | N/A | + /// ps | N/A | + /// ps->p | ps | + /// ps[1] | ps | + /// *(ps + 1) | ps | + /// (ps + 1)[1] | ps | + /// ps[1:10] | ps | + /// ps->b[10] | ps | + /// ps->p[10] | ps->p | + /// ps->c[1][2] | ps | + /// ps->c[1:2][2] | (error diagnostic) | N/A, TODO: ps + /// ps->c[1:1][2] | ps | N/A, TODO: ps + /// pps[1][2] | pps[1] | + /// pps[1:1][2] | pps[1:1] | N/A, TODO: pps[1:1] + /// pps[1:i][2] | pps[1:i] | N/A, TODO: pps[1:i] + /// pps[1:2][2] | (error diagnostic) | N/A + /// pps[1]->p | pps[1] | + /// pps[1]->p[10] | pps[1]->p | + /// pas[1] | N/A | + /// pas[1][2] | pas[1] | + /// ppps[1][2] | ppps[1] | + /// ppps[1][2][3] | ppps[1][2] | + /// ppps[1][2:1][3] | 
ppps[1][2:1] | N/A, TODO: ppps[1][2:1] + /// ppps[1][2:2][3] | (error diagnostic) | N/A + /// Returns a pair of the attach pointer expression and its depth in the + /// component list. + /// TODO: This may need to be updated to handle ref_ptr/ptee cases for byref + /// map operands. + /// TODO: Handle cases for target-update, where the list-item is a + /// non-contiguous array-section that still has a base-pointer. + static std::pair<const Expr *, std::optional<size_t>> + findAttachPtrExpr(MappableExprComponentListRef Components, + OpenMPDirectiveKind CurDirKind); + protected: // Return the total number of elements in a list of component lists. static unsigned diff --git a/clang/include/clang/Basic/OpenMPKinds.h b/clang/include/clang/Basic/OpenMPKinds.h index f40db4c13c55a..e37887e8b86ba 100644 --- a/clang/include/clang/Basic/OpenMPKinds.h +++ b/clang/include/clang/Basic/OpenMPKinds.h @@ -301,6 +301,14 @@ bool isOpenMPTargetExecutionDirective(OpenMPDirectiveKind DKind); /// otherwise - false. bool isOpenMPTargetDataManagementDirective(OpenMPDirectiveKind DKind); +/// Checks if the specified directive is a map-entering target directive. +/// \param DKind Specified directive. +/// \return true - the directive is a map-entering target directive like +/// 'omp target', 'omp target data', 'omp target enter data', +/// 'omp target parallel', etc. (excludes 'omp target exit data', 'omp target +/// update') otherwise - false. +bool isOpenMPTargetMapEnteringDirective(OpenMPDirectiveKind DKind); + /// Checks if the specified composite/combined directive constitutes a teams /// directive in the outermost nest. For example /// 'omp teams distribute' or 'omp teams distribute parallel for'. 
diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp index 588b0dcc6d7b8..eff897a1a33b2 100644 --- a/clang/lib/AST/OpenMPClause.cpp +++ b/clang/lib/AST/OpenMPClause.cpp @@ -15,6 +15,7 @@ #include "clang/AST/Attr.h" #include "clang/AST/Decl.h" #include "clang/AST/DeclOpenMP.h" +#include "clang/AST/ExprOpenMP.h" #include "clang/Basic/LLVM.h" #include "clang/Basic/OpenMPKinds.h" #include "clang/Basic/TargetInfo.h" @@ -1156,6 +1157,73 @@ unsigned OMPClauseMappableExprCommon::getUniqueDeclarationsTotalNumber( return UniqueDecls.size(); } +QualType +OMPClauseMappableExprCommon::getComponentExprElementType(const Expr *Exp) { + assert(!isa<OMPArrayShapingExpr>(Exp) && + "Cannot get element-type from array-shaping expr."); + + // For everything other than array-sections (including array-subscripts + // and derefs), we can rely on getType. + if (!isa<ArraySectionExpr>(Exp)) + return Exp->getType().getNonReferenceType().getCanonicalType(); + + // For array-sections, we need to find the type of one element of + // the section. + const auto *OASE = cast<ArraySectionExpr>(Exp); + + QualType BaseType = ArraySectionExpr::getBaseOriginalType(OASE->getBase()); + + QualType ElemTy; + if (const auto *ATy = BaseType->getAsArrayTypeUnsafe()) + ElemTy = ATy->getElementType(); + else + ElemTy = BaseType->getPointeeType(); + + ElemTy = ElemTy.getNonReferenceType().getCanonicalType(); + return ElemTy; +} + +std::pair<const Expr *, std::optional<size_t>> +OMPClauseMappableExprCommon::findAttachPtrExpr( + MappableExprComponentListRef Components, OpenMPDirectiveKind CurDirKind) { + + // If we only have a single component, we have a map like "map(p)", which + // cannot have a base-pointer. + if (Components.size() < 2) + return {nullptr, std::nullopt}; + + // Only check for non-contiguous sections on target_update, since we can + // assume array-sections are contiguous on maps on other constructs, even if + // we are not sure of it at compile-time, like for a[1:x][2].
+ if (Components.back().isNonContiguous() && CurDirKind == OMPD_target_update) + return {nullptr, std::nullopt}; + + // To find the attach base-pointer, we start with the second component, + // stripping away one component at a time, until we reach a pointer Expr + // (that is not a binary operator). The first such pointer should be the + // attach base-pointer for the component list. + for (size_t I = 1; I < Components.size(); ++I) { + const Expr *CurExpr = Components[I].getAssociatedExpression(); + if (!CurExpr) + break; + + // If CurExpr is something like `p + 10`, we need to ignore it, since + // we are looking for `p`. + if (isa<BinaryOperator>(CurExpr)) + continue; + + // Keep going until we reach an Expr of pointer type. + QualType CurType = getComponentExprElementType(CurExpr); + if (!CurType->isPointerType()) + continue; + + // We have found a pointer Expr. This must be the attach pointer. + return {CurExpr, Components.size() - I}; + } + + return {nullptr, std::nullopt}; +} + OMPMapClause *OMPMapClause::Create( const ASTContext &C, const OMPVarListLocTy &Locs, ArrayRef<Expr *> Vars, ArrayRef<ValueDecl *> Declarations, diff --git a/clang/lib/Basic/OpenMPKinds.cpp b/clang/lib/Basic/OpenMPKinds.cpp index 220b31b0f19bc..2f2a5b66e4ca5 100644 --- a/clang/lib/Basic/OpenMPKinds.cpp +++ b/clang/lib/Basic/OpenMPKinds.cpp @@ -650,6 +650,11 @@ bool clang::isOpenMPTargetDataManagementDirective(OpenMPDirectiveKind DKind) { DKind == OMPD_target_exit_data || DKind == OMPD_target_update; } +bool clang::isOpenMPTargetMapEnteringDirective(OpenMPDirectiveKind DKind) { + return DKind == OMPD_target_data || DKind == OMPD_target_enter_data || + isOpenMPTargetExecutionDirective(DKind); +} + bool clang::isOpenMPNestingTeamsDirective(OpenMPDirectiveKind DKind) { if (DKind == OMPD_teams) return true; diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp index f98339d472fa9..d592c29a412a9 100644 --- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp +++ 
b/clang/lib/CodeGen/CGOpenMPRuntime.cpp @@ -6765,12 +6765,256 @@ llvm::Value *CGOpenMPRuntime::emitNumThreadsForTargetDirective( namespace { LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE(); +/// Utility to compare expression locations. +/// Returns true if expr-loc of LHS is less-than that of RHS. +/// This function asserts that both expressions have valid expr-locations. +static bool compareExprLocs(const Expr *LHS, const Expr *RHS) { + // Assert that neither LHS nor RHS can be null + assert(LHS && "LHS expression cannot be null"); + assert(RHS && "RHS expression cannot be null"); + + // Get source locations + SourceLocation LocLHS = LHS->getExprLoc(); + SourceLocation LocRHS = RHS->getExprLoc(); + + // Assert that we have valid source locations + assert(LocLHS.isValid() && "LHS expression must have valid source location"); + assert(LocRHS.isValid() && "RHS expression must have valid source location"); + + // Compare source locations for deterministic ordering + bool result = LocLHS < LocRHS; + return result; +} + // Utility to handle information from clauses associated with a given // construct that use mappable expressions (e.g. 'map' clause, 'to' clause). // It provides a convenient interface to obtain the information and generate // code for that information. class MappableExprsHandler { public: + /// Custom comparator for attach-pointer expressions that compares them by + /// complexity (i.e. their component-depth) first, then by their expr-locs if + /// they are semantically different. + struct AttachPtrExprComparator { + const MappableExprsHandler *Handler; + // Cache of previous equality comparison results. + mutable llvm::DenseMap<std::pair<const Expr *, const Expr *>, bool> + CachedEqualityComparisons; + + AttachPtrExprComparator(const MappableExprsHandler *H) : Handler(H) {} + + // Return true iff LHS is "less than" RHS. 
+ bool operator()(const Expr *LHS, const Expr *RHS) const { + if (LHS == RHS) + return false; + + // First, compare by complexity (depth) + auto ItLHS = Handler->AttachPtrComponentDepthMap.find(LHS); + auto ItRHS = Handler->AttachPtrComponentDepthMap.find(RHS); + + std::optional<size_t> DepthLHS = + (ItLHS != Handler->AttachPtrComponentDepthMap.end()) ? ItLHS->second + : std::nullopt; + std::optional<size_t> DepthRHS = + (ItRHS != Handler->AttachPtrComponentDepthMap.end()) ? ItRHS->second + : std::nullopt; + + // std::nullopt (no attach pointer) has lowest complexity + if (!DepthLHS.has_value() && !DepthRHS.has_value()) { + // Both have same complexity, now check semantic equality + if (areEqual(LHS, RHS)) + return false; + // Different semantically, compare by location + return compareExprLocs(LHS, RHS); + } + if (!DepthLHS.has_value()) + return true; // LHS has lower complexity + if (!DepthRHS.has_value()) + return false; // RHS has lower complexity + + // Both have values, compare by depth (lower depth = lower complexity) + if (DepthLHS.value() != DepthRHS.value()) + return DepthLHS.value() < DepthRHS.value(); + + // Same complexity, now check semantic equality + if (areEqual(LHS, RHS)) + return false; + // Different semantically, compare by location + return compareExprLocs(LHS, RHS); + } + + public: + /// Return true if \p LHS and \p RHS are semantically equal. Uses pre-cached + /// results, if available, otherwise does a recursive semantic comparison. 
+ bool areEqual(const Expr *LHS, const Expr *RHS) const { + // Check cache first for faster lookup + auto CachedResultIt = CachedEqualityComparisons.find({LHS, RHS}); + if (CachedResultIt != CachedEqualityComparisons.end()) + return CachedResultIt->second; + + bool ComparisonResult = areSemanticallyEqual(LHS, RHS); + + // Cache the result for future lookups (both orders since semantic + // equality is commutative) + CachedEqualityComparisons[{LHS, RHS}] = ComparisonResult; + CachedEqualityComparisons[{RHS, LHS}] = ComparisonResult; + return ComparisonResult; + } + + private: + /// Helper function to compare attach-pointer expressions semantically. + /// This function handles various expression types that can be part of an + /// attach-pointer. + /// TODO: Not urgent, but we should ideally return true when comparing + /// `p[10]`, `*(p + 10)`, `*(p + 5 + 5)`, `p[10:1]` etc. + bool areSemanticallyEqual(const Expr *LHS, const Expr *RHS) const { + if (LHS == RHS) + return true; + + // If only one is null, they aren't equal + if (!LHS || !RHS) + return false; + + ASTContext &Ctx = Handler->CGF.getContext(); + // Strip away parentheses and no-op casts to get to the core expression + LHS = LHS->IgnoreParenNoopCasts(Ctx); + RHS = RHS->IgnoreParenNoopCasts(Ctx); + + // Direct pointer comparison of the underlying expressions + if (LHS == RHS) + return true; + + // Check if the expression classes match + if (LHS->getStmtClass() != RHS->getStmtClass()) + return false; + + // Handle DeclRefExpr (variable references) + if (const auto *LD = dyn_cast<DeclRefExpr>(LHS)) { + const auto *RD = dyn_cast<DeclRefExpr>(RHS); + if (!RD) + return false; + return LD->getDecl()->getCanonicalDecl() == + RD->getDecl()->getCanonicalDecl(); + } + + // Handle ArraySubscriptExpr (array indexing like a[i]) + if (const auto *LA = dyn_cast<ArraySubscriptExpr>(LHS)) { + const auto *RA = dyn_cast<ArraySubscriptExpr>(RHS); + if (!RA) + return false; + return areSemanticallyEqual(LA->getBase(), 
RA->getBase()) && + areSemanticallyEqual(LA->getIdx(), RA->getIdx()); + } + + // Handle MemberExpr (member access like s.m or p->m) + if (const auto *LM = dyn_cast<MemberExpr>(LHS)) { + const auto *RM = dyn_cast<MemberExpr>(RHS); + if (!RM) + return false; + if (LM->getMemberDecl()->getCanonicalDecl() != + RM->getMemberDecl()->getCanonicalDecl()) + return false; + return areSemanticallyEqual(LM->getBase(), RM->getBase()); + } + + // Handle UnaryOperator (unary operations like *p, &x, etc.) + if (const auto *LU = dyn_cast<UnaryOperator>(LHS)) { + const auto *RU = dyn_cast<UnaryOperator>(RHS); + if (!RU) + return false; + if (LU->getOpcode() != RU->getOpcode()) + return false; + return areSemanticallyEqual(LU->getSubExpr(), RU->getSubExpr()); + } + + // Handle BinaryOperator (binary operations like p + offset) + if (const auto *LB = dyn_cast<BinaryOperator>(LHS)) { + const auto *RB = dyn_cast<BinaryOperator>(RHS); + if (!RB) + return false; + if (LB->getOpcode() != RB->getOpcode()) + return false; + return areSemanticallyEqual(LB->getLHS(), RB->getLHS()) && + areSemanticallyEqual(LB->getRHS(), RB->getRHS()); + } + + // Handle ArraySectionExpr (array sections like a[0:1]) + // Attach pointers should not contain array-sections, but currently we + // don't emit an error. 
+ if (const auto *LAS = dyn_cast<ArraySectionExpr>(LHS)) { + const auto *RAS = dyn_cast<ArraySectionExpr>(RHS); + if (!RAS) + return false; + return areSemanticallyEqual(LAS->getBase(), RAS->getBase()) && + areSemanticallyEqual(LAS->getLowerBound(), + RAS->getLowerBound()) && + areSemanticallyEqual(LAS->getLength(), RAS->getLength()); + } + + // Handle CastExpr (explicit casts) + if (const auto *LC = dyn_cast<CastExpr>(LHS)) { + const auto *RC = dyn_cast<CastExpr>(RHS); + if (!RC) + return false; + if (LC->getCastKind() != RC->getCastKind()) + return false; + return areSemanticallyEqual(LC->getSubExpr(), RC->getSubExpr()); + } + + // Handle CXXThisExpr (this pointer) + if (isa<CXXThisExpr>(LHS) && isa<CXXThisExpr>(RHS)) + return true; + + // Handle IntegerLiteral (integer constants) + if (const auto *LI = dyn_cast<IntegerLiteral>(LHS)) { + const auto *RI = dyn_cast<IntegerLiteral>(RHS); + if (!RI) + return false; + return LI->getValue() == RI->getValue(); + } + + // Handle CharacterLiteral (character constants) + if (const auto *LC = dyn_cast<CharacterLiteral>(LHS)) { + const auto *RC = dyn_cast<CharacterLiteral>(RHS); + if (!RC) + return false; + return LC->getValue() == RC->getValue(); + } + + // Handle FloatingLiteral (floating point constants) + if (const auto *LF = dyn_cast<FloatingLiteral>(LHS)) { + const auto *RF = dyn_cast<FloatingLiteral>(RHS); + if (!RF) + return false; + // Use bitwise comparison for floating point literals + return LF->getValue().bitwiseIsEqual(RF->getValue()); + } + + // Handle StringLiteral (string constants) + if (const auto *LS = dyn_cast<StringLiteral>(LHS)) { + const auto *RS = dyn_cast<StringLiteral>(RHS); + if (!RS) + return false; + return LS->getString() == RS->getString(); + } + + // Handle CXXNullPtrLiteralExpr (nullptr) + if (isa<CXXNullPtrLiteralExpr>(LHS) && isa<CXXNullPtrLiteralExpr>(RHS)) + return true; + + // Handle CXXBoolLiteralExpr (true/false) + if (const auto *LB = dyn_cast<CXXBoolLiteralExpr>(LHS)) { + const auto 
*RB = dyn_cast<CXXBoolLiteralExpr>(RHS); + if (!RB) + return false; + return LB->getValue() == RB->getValue(); + } + + // Fallback for other forms - use the existing comparison method + return Expr::isSameComparisonOperand(LHS, RHS); + } + }; + /// Get the offset of the OMP_MAP_MEMBER_OF field. static unsigned getFlagMemberOffset() { unsigned Offset = 0; @@ -6846,8 +7090,42 @@ class MappableExprsHandler { Address LB = Address::invalid(); bool IsArraySection = false; bool HasCompleteRecord = false; + // ATTACH information for delaye... [truncated] `````````` </details> https://github.com/llvm/llvm-project/pull/155625 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits