llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clang-codegen

Author: Abhinav Gaba (abhinavgaba)

<details>
<summary>Changes</summary>

These changes have been split out of the codegen PR #153683 to reduce the
size of that PR.

---

Patch is 25.43 KiB, truncated to 20.00 KiB below, full version: 
https://github.com/llvm/llvm-project/pull/155625.diff


5 Files Affected:

- (modified) clang/include/clang/AST/OpenMPClause.h (+95) 
- (modified) clang/include/clang/Basic/OpenMPKinds.h (+8) 
- (modified) clang/lib/AST/OpenMPClause.cpp (+68) 
- (modified) clang/lib/Basic/OpenMPKinds.cpp (+5) 
- (modified) clang/lib/CodeGen/CGOpenMPRuntime.cpp (+358) 


``````````diff
diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h
index 1118d3e062e68..9627e99a306b4 100644
--- a/clang/include/clang/AST/OpenMPClause.h
+++ b/clang/include/clang/AST/OpenMPClause.h
@@ -5815,6 +5815,12 @@ class OMPClauseMappableExprCommon {
     ValueDecl *getAssociatedDeclaration() const {
       return AssociatedDeclaration;
     }
+
+    /// Equality: both the associated declaration and the packed
+    /// associated-expression/non-contiguous field must match.
+    bool operator==(const MappableComponent &Other) const {
+      if (AssociatedDeclaration != Other.AssociatedDeclaration)
+        return false;
+      return AssociatedExpressionNonContiguousPr ==
+             Other.AssociatedExpressionNonContiguousPr;
+    }
   };
 
   // List of components of an expression. This first one is the whole
@@ -5828,6 +5834,95 @@ class OMPClauseMappableExprCommon {
   using MappableExprComponentLists = SmallVector<MappableExprComponentList, 8>;
   using MappableExprComponentListsRef = ArrayRef<MappableExprComponentList>;
 
+  // Hash support for MappableComponent (found via ADL as `hash_value`) so it
+  // can be used as a key in hashed containers such as DenseMap. It hashes the
+  // associated expression, declaration and non-contiguous flag.
+  friend llvm::hash_code hash_value(const MappableComponent &MC) {
+    const auto *AssocExpr = MC.getAssociatedExpression();
+    const auto *AssocDecl = MC.getAssociatedDeclaration();
+    auto NonContig = MC.isNonContiguous();
+    return llvm::hash_combine(AssocExpr, AssocDecl, NonContig);
+  }
+
+public:
+  /// Get the type of an element of a ComponentList Expr \p Exp.
+  ///
+  /// For something like the following:
+  /// ```c
+  ///  int *p, **pp;
+  /// ```
+  /// The types for the following Exprs would be:
+  ///   Expr     | Type
+  ///   ---------|-----------
+  ///   p        | int *
+  ///   *p       | int
+  ///   p[0]     | int
+  ///   p[0:1]   | int
+  ///   pp       | int **
+  ///   pp[0]    | int *
+  ///   pp[0:1]  | int *
+  /// The returned type has references stripped and is canonical.
+  /// Note: this assumes that if \p Exp is an array-section, it is contiguous.
+  static QualType getComponentExprElementType(const Expr *Exp);
+
+  /// Find the attach pointer expression from a list of mappable expression
+  /// components.
+  ///
+  /// This function traverses the component list to find the first
+  /// expression that has a pointer type, which represents the attach
+  /// base pointer expr for the current component-list.
+  ///
+  /// For example, given the following:
+  ///
+  /// ```c
+  ///   struct S {
+  ///     int a;
+  ///     int b[10];
+  ///     int c[10][10];
+  ///     int *p;
+  ///     int **pp;
+  ///   };
+  ///   S s, *ps, **pps, *(pas[10]), ***ppps;
+  ///   int i;
+  /// ```
+  ///
+  /// The base-pointers for the following map operands would be:
+  ///   map list-item   | attach base-pointer   | attach base-pointer
+  ///                   | for directives except | target_update (if
+  ///                   | target_update         | different)
+  ///   ----------------|-----------------------|---------------------
+  ///   s               | N/A                   |
+  ///   s.a             | N/A                   |
+  ///   s.p             | N/A                   |
+  ///   ps              | N/A                   |
+  ///   ps->p           | ps                    |
+  ///   ps[1]           | ps                    |
+  ///   *(ps + 1)       | ps                    |
+  ///   (ps + 1)[1]     | ps                    |
+  ///   ps[1:10]        | ps                    |
+  ///   ps->b[10]       | ps                    |
+  ///   ps->p[10]       | ps->p                 |
+  ///   ps->c[1][2]     | ps                    |
+  ///   ps->c[1:2][2]   | (error diagnostic)    | N/A, TODO: ps
+  ///   ps->c[1:1][2]   | ps                    | N/A, TODO: ps
+  ///   pps[1][2]       | pps[1]                |
+  ///   pps[1:1][2]     | pps[1:1]              | N/A, TODO: pps[1:1]
+  ///   pps[1:i][2]     | pps[1:i]              | N/A, TODO: pps[1:i]
+  ///   pps[1:2][2]     | (error diagnostic)    | N/A
+  ///   pps[1]->p       | pps[1]                |
+  ///   pps[1]->p[10]   | pps[1]->p             |
+  ///   pas[1]          | N/A                   |
+  ///   pas[1][2]       | pas[1]                |
+  ///   ppps[1][2]      | ppps[1]               |
+  ///   ppps[1][2][3]   | ppps[1][2]            |
+  ///   ppps[1][2:1][3] | ppps[1][2:1]          | N/A, TODO: ppps[1][2:1]
+  ///   ppps[1][2:2][3] | (error diagnostic)    | N/A
+  ///
+  /// Returns a pair of the attach pointer expression and its depth in the
+  /// component list, i.e. the number of components from the attach pointer
+  /// to the end of \p Components (inclusive). Returns
+  /// `{nullptr, std::nullopt}` when there is no attach pointer.
+  /// TODO: This may need to be updated to handle ref_ptr/ptee cases for byref
+  /// map operands.
+  /// TODO: Handle cases for target-update, where the list-item is a
+  /// non-contiguous array-section that still has a base-pointer.
+  static std::pair<const Expr *, std::optional<size_t>>
+  findAttachPtrExpr(MappableExprComponentListRef Components,
+                    OpenMPDirectiveKind CurDirKind);
+
 protected:
   // Return the total number of elements in a list of component lists.
   static unsigned
diff --git a/clang/include/clang/Basic/OpenMPKinds.h b/clang/include/clang/Basic/OpenMPKinds.h
index f40db4c13c55a..e37887e8b86ba 100644
--- a/clang/include/clang/Basic/OpenMPKinds.h
+++ b/clang/include/clang/Basic/OpenMPKinds.h
@@ -301,6 +301,14 @@ bool isOpenMPTargetExecutionDirective(OpenMPDirectiveKind DKind);
 /// otherwise - false.
 bool isOpenMPTargetDataManagementDirective(OpenMPDirectiveKind DKind);
 
+/// Checks if the specified directive is a map-entering target directive.
+/// \param DKind Specified directive.
+/// \return true - the directive is a map-entering target directive, i.e.
+/// 'omp target data', 'omp target enter data', or any target execution
+/// directive such as 'omp target' or 'omp target parallel' (but not
+/// 'omp target exit data' or 'omp target update'), otherwise - false.
+bool isOpenMPTargetMapEnteringDirective(OpenMPDirectiveKind DKind);
+
 /// Checks if the specified composite/combined directive constitutes a teams
 /// directive in the outermost nest.  For example
 /// 'omp teams distribute' or 'omp teams distribute parallel for'.
diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp
index 588b0dcc6d7b8..eff897a1a33b2 100644
--- a/clang/lib/AST/OpenMPClause.cpp
+++ b/clang/lib/AST/OpenMPClause.cpp
@@ -15,6 +15,7 @@
 #include "clang/AST/Attr.h"
 #include "clang/AST/Decl.h"
 #include "clang/AST/DeclOpenMP.h"
+#include "clang/AST/ExprOpenMP.h"
 #include "clang/Basic/LLVM.h"
 #include "clang/Basic/OpenMPKinds.h"
 #include "clang/Basic/TargetInfo.h"
@@ -1156,6 +1157,73 @@ unsigned OMPClauseMappableExprCommon::getUniqueDeclarationsTotalNumber(
   return UniqueDecls.size();
 }
 
+QualType
+OMPClauseMappableExprCommon::getComponentExprElementType(const Expr *Exp) {
+  assert(!isa<OMPArrayShapingExpr>(Exp) &&
+         "Cannot get element-type from array-shaping expr.");
+
+  // For anything other than an array-section (including array-subscripts and
+  // derefs), the expression's own type is already the element type.
+  const auto *OASE = dyn_cast<ArraySectionExpr>(Exp);
+  if (!OASE)
+    return Exp->getType().getNonReferenceType().getCanonicalType();
+
+  // For an array-section, the element type is the element type (for an
+  // array base) or pointee type (for a pointer base) of the section's
+  // original base type, e.g. `int` for `p[0:1]` with `int *p`.
+  QualType BaseType = ArraySectionExpr::getBaseOriginalType(OASE->getBase());
+
+  QualType ElemTy;
+  if (const auto *ATy = BaseType->getAsArrayTypeUnsafe())
+    ElemTy = ATy->getElementType();
+  else
+    ElemTy = BaseType->getPointeeType();
+
+  return ElemTy.getNonReferenceType().getCanonicalType();
+}
+
+std::pair<const Expr *, std::optional<size_t>>
+OMPClauseMappableExprCommon::findAttachPtrExpr(
+    MappableExprComponentListRef Components, OpenMPDirectiveKind CurDirKind) {
+  // Result returned whenever the component-list has no attach base-pointer.
+  const std::pair<const Expr *, std::optional<size_t>> NoAttachPtr = {
+      nullptr, std::nullopt};
+
+  // A lone component (a map like "map(p)") has nothing to attach to.
+  if (Components.size() < 2)
+    return NoAttachPtr;
+
+  // Non-contiguous sections only matter for target_update; on other
+  // constructs array-sections are assumed contiguous, even when that cannot
+  // be proven at compile-time (e.g. a[1:x][2]).
+  if (CurDirKind == OMPD_target_update && Components.back().isNonContiguous())
+    return NoAttachPtr;
+
+  // Peel components one at a time starting from the second one. The first
+  // non-binary-operator Expr whose element type is a pointer is the attach
+  // base-pointer for this component-list.
+  const size_t NumComponents = Components.size();
+  for (size_t Idx = 1; Idx != NumComponents; ++Idx) {
+    const Expr *Candidate = Components[Idx].getAssociatedExpression();
+    // A component without an expression ends the search.
+    if (!Candidate)
+      return NoAttachPtr;
+
+    // Skip pointer arithmetic such as `p + 10`; we keep looking for `p`.
+    if (isa<BinaryOperator>(Candidate))
+      continue;
+
+    if (getComponentExprElementType(Candidate)->isPointerType())
+      return {Candidate, NumComponents - Idx};
+  }
+
+  return NoAttachPtr;
+}
+
 OMPMapClause *OMPMapClause::Create(
     const ASTContext &C, const OMPVarListLocTy &Locs, ArrayRef<Expr *> Vars,
     ArrayRef<ValueDecl *> Declarations,
diff --git a/clang/lib/Basic/OpenMPKinds.cpp b/clang/lib/Basic/OpenMPKinds.cpp
index 220b31b0f19bc..2f2a5b66e4ca5 100644
--- a/clang/lib/Basic/OpenMPKinds.cpp
+++ b/clang/lib/Basic/OpenMPKinds.cpp
@@ -650,6 +650,11 @@ bool clang::isOpenMPTargetDataManagementDirective(OpenMPDirectiveKind DKind) {
          DKind == OMPD_target_exit_data || DKind == OMPD_target_update;
 }
 
+bool clang::isOpenMPTargetMapEnteringDirective(OpenMPDirectiveKind DKind) {
+  // All target execution directives are map-entering, as are the two
+  // data-entering constructs 'omp target data' and 'omp target enter data'.
+  if (isOpenMPTargetExecutionDirective(DKind))
+    return true;
+  return DKind == OMPD_target_data || DKind == OMPD_target_enter_data;
+}
+
 bool clang::isOpenMPNestingTeamsDirective(OpenMPDirectiveKind DKind) {
   if (DKind == OMPD_teams)
     return true;
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
index f98339d472fa9..d592c29a412a9 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -6765,12 +6765,256 @@ llvm::Value *CGOpenMPRuntime::emitNumThreadsForTargetDirective(
 namespace {
 LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
 
+/// Utility to compare expression locations.
+/// Returns true if the expr-loc of \p LHS orders before that of \p RHS.
+/// The comparison uses SourceLocation::operator<, which orders by the raw
+/// location encoding: it is deterministic within a compilation, but it is
+/// not guaranteed to match textual order across files or macro expansions.
+/// Both expressions must be non-null and have valid expr-locations
+/// (asserted).
+static bool compareExprLocs(const Expr *LHS, const Expr *RHS) {
+  assert(LHS && "LHS expression cannot be null");
+  assert(RHS && "RHS expression cannot be null");
+
+  SourceLocation LocLHS = LHS->getExprLoc();
+  SourceLocation LocRHS = RHS->getExprLoc();
+
+  assert(LocLHS.isValid() && "LHS expression must have valid source location");
+  assert(LocRHS.isValid() && "RHS expression must have valid source location");
+
+  // Any deterministic strict order suffices here; raw-encoding order is cheap.
+  return LocLHS < LocRHS;
+}
+
 // Utility to handle information from clauses associated with a given
 // construct that use mappable expressions (e.g. 'map' clause, 'to' clause).
 // It provides a convenient interface to obtain the information and generate
 // code for that information.
 class MappableExprsHandler {
 public:
+  /// Custom comparator for attach-pointer expressions that compares them by
+  /// complexity (i.e. their component-depth) first, then by their expr-locs if
+  /// they are semantically different.
+  ///
+  /// NOTE(review): semantically-equal expressions compare as equivalent,
+  /// while other same-depth expressions are ordered by location. If A and B
+  /// are equal but located on either side of some C, then A < C < B while
+  /// A ~ B, which is not a strict weak ordering. Confirm that callers never
+  /// put such a mix into one ordered container.
+  struct AttachPtrExprComparator {
+    const MappableExprsHandler *Handler;
+    // Cache of previous equality comparison results.
+    mutable llvm::DenseMap<std::pair<const Expr *, const Expr *>, bool>
+        CachedEqualityComparisons;
+
+    AttachPtrExprComparator(const MappableExprsHandler *H) : Handler(H) {}
+
+    // Return true iff LHS is "less than" RHS.
+    bool operator()(const Expr *LHS, const Expr *RHS) const {
+      if (LHS == RHS)
+        return false;
+
+      // First, compare by complexity (depth)
+      auto ItLHS = Handler->AttachPtrComponentDepthMap.find(LHS);
+      auto ItRHS = Handler->AttachPtrComponentDepthMap.find(RHS);
+
+      std::optional<size_t> DepthLHS =
+          (ItLHS != Handler->AttachPtrComponentDepthMap.end()) ? ItLHS->second
+                                                               : std::nullopt;
+      std::optional<size_t> DepthRHS =
+          (ItRHS != Handler->AttachPtrComponentDepthMap.end()) ? ItRHS->second
+                                                               : std::nullopt;
+
+      // std::nullopt (no attach pointer) has lowest complexity
+      if (!DepthLHS.has_value() && !DepthRHS.has_value()) {
+        // Both have same complexity, now check semantic equality
+        if (areEqual(LHS, RHS))
+          return false;
+        // Different semantically, compare by location
+        return compareExprLocs(LHS, RHS);
+      }
+      if (!DepthLHS.has_value())
+        return true; // LHS has lower complexity
+      if (!DepthRHS.has_value())
+        return false; // RHS has lower complexity
+
+      // Both have values, compare by depth (lower depth = lower complexity)
+      if (DepthLHS.value() != DepthRHS.value())
+        return DepthLHS.value() < DepthRHS.value();
+
+      // Same complexity, now check semantic equality
+      if (areEqual(LHS, RHS))
+        return false;
+      // Different semantically, compare by location
+      return compareExprLocs(LHS, RHS);
+    }
+
+  public:
+    /// Return true if \p LHS and \p RHS are semantically equal. Uses
+    /// pre-cached results, if available, otherwise does a recursive semantic
+    /// comparison.
+    bool areEqual(const Expr *LHS, const Expr *RHS) const {
+      // Check cache first for faster lookup
+      auto CachedResultIt = CachedEqualityComparisons.find({LHS, RHS});
+      if (CachedResultIt != CachedEqualityComparisons.end())
+        return CachedResultIt->second;
+
+      bool ComparisonResult = areSemanticallyEqual(LHS, RHS);
+
+      // Cache the result for future lookups (both orders since semantic
+      // equality is commutative)
+      CachedEqualityComparisons[{LHS, RHS}] = ComparisonResult;
+      CachedEqualityComparisons[{RHS, LHS}] = ComparisonResult;
+      return ComparisonResult;
+    }
+
+  private:
+    /// Helper function to compare attach-pointer expressions semantically.
+    /// This function handles various expression types that can be part of an
+    /// attach-pointer.
+    /// TODO: Not urgent, but we should ideally return true when comparing
+    /// `p[10]`, `*(p + 10)`,  `*(p + 5 + 5)`, `p[10:1]` etc.
+    bool areSemanticallyEqual(const Expr *LHS, const Expr *RHS) const {
+      if (LHS == RHS)
+        return true;
+
+      // If only one is null, they aren't equal
+      if (!LHS || !RHS)
+        return false;
+
+      ASTContext &Ctx = Handler->CGF.getContext();
+      // Strip away parentheses and no-op casts to get to the core expression
+      LHS = LHS->IgnoreParenNoopCasts(Ctx);
+      RHS = RHS->IgnoreParenNoopCasts(Ctx);
+
+      // Direct pointer comparison of the underlying expressions
+      if (LHS == RHS)
+        return true;
+
+      // Check if the expression classes match
+      if (LHS->getStmtClass() != RHS->getStmtClass())
+        return false;
+
+      // Handle DeclRefExpr (variable references)
+      if (const auto *LD = dyn_cast<DeclRefExpr>(LHS)) {
+        const auto *RD = dyn_cast<DeclRefExpr>(RHS);
+        if (!RD)
+          return false;
+        return LD->getDecl()->getCanonicalDecl() ==
+               RD->getDecl()->getCanonicalDecl();
+      }
+
+      // Handle ArraySubscriptExpr (array indexing like a[i])
+      if (const auto *LA = dyn_cast<ArraySubscriptExpr>(LHS)) {
+        const auto *RA = dyn_cast<ArraySubscriptExpr>(RHS);
+        if (!RA)
+          return false;
+        return areSemanticallyEqual(LA->getBase(), RA->getBase()) &&
+               areSemanticallyEqual(LA->getIdx(), RA->getIdx());
+      }
+
+      // Handle MemberExpr (member access like s.m or p->m)
+      if (const auto *LM = dyn_cast<MemberExpr>(LHS)) {
+        const auto *RM = dyn_cast<MemberExpr>(RHS);
+        if (!RM)
+          return false;
+        if (LM->getMemberDecl()->getCanonicalDecl() !=
+            RM->getMemberDecl()->getCanonicalDecl())
+          return false;
+        return areSemanticallyEqual(LM->getBase(), RM->getBase());
+      }
+
+      // Handle UnaryOperator (unary operations like *p, &x, etc.)
+      if (const auto *LU = dyn_cast<UnaryOperator>(LHS)) {
+        const auto *RU = dyn_cast<UnaryOperator>(RHS);
+        if (!RU)
+          return false;
+        if (LU->getOpcode() != RU->getOpcode())
+          return false;
+        return areSemanticallyEqual(LU->getSubExpr(), RU->getSubExpr());
+      }
+
+      // Handle BinaryOperator (binary operations like p + offset)
+      if (const auto *LB = dyn_cast<BinaryOperator>(LHS)) {
+        const auto *RB = dyn_cast<BinaryOperator>(RHS);
+        if (!RB)
+          return false;
+        if (LB->getOpcode() != RB->getOpcode())
+          return false;
+        return areSemanticallyEqual(LB->getLHS(), RB->getLHS()) &&
+               areSemanticallyEqual(LB->getRHS(), RB->getRHS());
+      }
+
+      // Handle ArraySectionExpr (array sections like a[0:1])
+      // Attach pointers should not contain array-sections, but currently we
+      // don't emit an error.
+      if (const auto *LAS = dyn_cast<ArraySectionExpr>(LHS)) {
+        const auto *RAS = dyn_cast<ArraySectionExpr>(RHS);
+        if (!RAS)
+          return false;
+        return areSemanticallyEqual(LAS->getBase(), RAS->getBase()) &&
+               areSemanticallyEqual(LAS->getLowerBound(),
+                                    RAS->getLowerBound()) &&
+               areSemanticallyEqual(LAS->getLength(), RAS->getLength());
+      }
+
+      // Handle CastExpr (explicit casts)
+      if (const auto *LC = dyn_cast<CastExpr>(LHS)) {
+        const auto *RC = dyn_cast<CastExpr>(RHS);
+        if (!RC)
+          return false;
+        if (LC->getCastKind() != RC->getCastKind())
+          return false;
+        return areSemanticallyEqual(LC->getSubExpr(), RC->getSubExpr());
+      }
+
+      // Handle CXXThisExpr (this pointer)
+      if (isa<CXXThisExpr>(LHS) && isa<CXXThisExpr>(RHS))
+        return true;
+
+      // Handle IntegerLiteral (integer constants). The two literals may have
+      // different bit-widths (e.g. `10` vs `10L`), which APInt::operator==
+      // asserts against, so compare the values width-independently.
+      if (const auto *LI = dyn_cast<IntegerLiteral>(LHS)) {
+        const auto *RI = dyn_cast<IntegerLiteral>(RHS);
+        if (!RI)
+          return false;
+        return llvm::APInt::isSameValue(LI->getValue(), RI->getValue());
+      }
+
+      // Handle CharacterLiteral (character constants)
+      if (const auto *LC = dyn_cast<CharacterLiteral>(LHS)) {
+        const auto *RC = dyn_cast<CharacterLiteral>(RHS);
+        if (!RC)
+          return false;
+        return LC->getValue() == RC->getValue();
+      }
+
+      // Handle FloatingLiteral (floating point constants)
+      if (const auto *LF = dyn_cast<FloatingLiteral>(LHS)) {
+        const auto *RF = dyn_cast<FloatingLiteral>(RHS);
+        if (!RF)
+          return false;
+        // Use bitwise comparison for floating point literals
+        return LF->getValue().bitwiseIsEqual(RF->getValue());
+      }
+
+      // Handle StringLiteral (string constants). getString() only supports
+      // literals with 1-byte characters, so compare the literal kind and the
+      // raw bytes instead, which also works for wide/UTF-16/UTF-32 literals.
+      if (const auto *LS = dyn_cast<StringLiteral>(LHS)) {
+        const auto *RS = dyn_cast<StringLiteral>(RHS);
+        if (!RS)
+          return false;
+        return LS->getKind() == RS->getKind() &&
+               LS->getBytes() == RS->getBytes();
+      }
+
+      // Handle CXXNullPtrLiteralExpr (nullptr)
+      if (isa<CXXNullPtrLiteralExpr>(LHS) && isa<CXXNullPtrLiteralExpr>(RHS))
+        return true;
+
+      // Handle CXXBoolLiteralExpr (true/false)
+      if (const auto *LB = dyn_cast<CXXBoolLiteralExpr>(LHS)) {
+        const auto *RB = dyn_cast<CXXBoolLiteralExpr>(RHS);
+        if (!RB)
+          return false;
+        return LB->getValue() == RB->getValue();
+      }
+
+      // Fallback for other forms - use the existing comparison method
+      return Expr::isSameComparisonOperand(LHS, RHS);
+    }
+  };
+
   /// Get the offset of the OMP_MAP_MEMBER_OF field.
   static unsigned getFlagMemberOffset() {
     unsigned Offset = 0;
@@ -6846,8 +7090,42 @@ class MappableExprsHandler {
     Address LB = Address::invalid();
     bool IsArraySection = false;
     bool HasCompleteRecord = false;
+    // ATTACH information for delaye...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/155625
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to