================
@@ -2770,6 +2770,442 @@ void x() {
   EXPECT_TRUE(notMatchesWithOpenMP(Source2, Matcher));
 }
 
+TEST(ASTMatchersTestOpenMP, OMPTargetUpdateDirective_From) {
+  // A `target update` directive whose `from` clause names a strided array
+  // section should be found by a plain statement matcher.
+  StringRef Code = R"(
+    void foo() {
+      int arr[8];
+      #pragma omp target update from(arr[0:8:2])
+      ;
+    }
+  )";
+  auto DirectiveMatcher = stmt(ompTargetUpdateDirective());
+  EXPECT_TRUE(matchesWithOpenMP(Code, DirectiveMatcher));
+}
+
+TEST(ASTMatchersTestOpenMP, OMPTargetUpdateDirective_To) {
+  // A `target update` directive whose `to` clause names a strided array
+  // section should be found by a plain statement matcher.
+  StringRef Code = R"(
+    void foo() {
+      int arr[8];
+      #pragma omp target update to(arr[0:8:2])
+      ;
+    }
+  )";
+  auto DirectiveMatcher = stmt(ompTargetUpdateDirective());
+  EXPECT_TRUE(matchesWithOpenMP(Code, DirectiveMatcher));
+}
+
+TEST(ASTMatchersTestOpenMP, OMPFromClause) {
+  // Checks that `ompFromClause()` finds the directive, then inspects the AST
+  // to confirm the clause's single list item is the array section
+  // `arr[0:8:2]` with the expected base, lower bound, length and stride.
+  auto Matcher = ompTargetUpdateDirective(hasAnyClause(ompFromClause()));
+
+  StringRef Source0 = R"(
+    void foo() {
+      int arr[8];
+      #pragma omp target update from(arr[0:8:2])
+      ;
+    }
+  )";
+  EXPECT_TRUE(matchesWithOpenMP(Source0, Matcher));
+
+  auto AST = tooling::buildASTFromCodeWithArgs(Source0, {"-fopenmp"});
+  ASSERT_TRUE(AST);
+
+  auto Results = match(ompTargetUpdateDirective().bind("directive"),
+                       AST->getASTContext());
+  ASSERT_EQ(Results.size(), 1u);
+
+  const auto *Directive =
+      Results[0].getNodeAs<OMPTargetUpdateDirective>("directive");
+  ASSERT_TRUE(Directive);
+
+  const OMPFromClause *FromClause = nullptr;
+  for (const OMPClause *Clause : Directive->clauses()) {
+    if ((FromClause = dyn_cast<OMPFromClause>(Clause)))
+      break;
+  }
+  ASSERT_TRUE(FromClause);
+
+  // The clause lists exactly one item. Requiring it to be an array section
+  // keeps the checks below from passing vacuously.
+  ASSERT_EQ(FromClause->varlist_size(), 1u);
+  const auto *ArraySection =
+      dyn_cast<ArraySectionExpr>(*FromClause->varlist_begin());
+  ASSERT_TRUE(ArraySection);
+
+  // base (arr); implicit casts, if any, are looked through.
+  const Expr *Base = ArraySection->getBase();
+  ASSERT_TRUE(Base);
+  const auto *BaseRef = dyn_cast<DeclRefExpr>(Base->IgnoreParenImpCasts());
+  ASSERT_TRUE(BaseRef);
+  EXPECT_EQ(BaseRef->getDecl()->getName(), "arr");
+
+  // The bounds below use dyn_cast + ASSERT_TRUE so that a node of the wrong
+  // kind fails the test instead of being silently skipped.
+
+  // lower bound (0)
+  const Expr *LowerBound = ArraySection->getLowerBound();
+  ASSERT_TRUE(LowerBound);
+  const auto *LowerBoundLiteral = dyn_cast<IntegerLiteral>(LowerBound);
+  ASSERT_TRUE(LowerBoundLiteral);
+  EXPECT_EQ(LowerBoundLiteral->getValue().getZExtValue(), 0u);
+
+  // length (8)
+  const Expr *Length = ArraySection->getLength();
+  ASSERT_TRUE(Length);
+  const auto *LengthLiteral = dyn_cast<IntegerLiteral>(Length);
+  ASSERT_TRUE(LengthLiteral);
+  EXPECT_EQ(LengthLiteral->getValue().getZExtValue(), 8u);
+
+  // stride (2)
+  const Expr *Stride = ArraySection->getStride();
+  ASSERT_TRUE(Stride);
+  const auto *StrideLiteral = dyn_cast<IntegerLiteral>(Stride);
+  ASSERT_TRUE(StrideLiteral);
+  EXPECT_EQ(StrideLiteral->getValue().getZExtValue(), 2u);
+}
+
+TEST(ASTMatchersTestOpenMP, OMPToClause) {
+  // Checks that `ompToClause()` finds the directive, then inspects the AST
+  // to confirm the clause's single list item is the array section
+  // `arr[0:8:2]` with the expected base, lower bound, length and stride.
+  auto Matcher = ompTargetUpdateDirective(hasAnyClause(ompToClause()));
+
+  StringRef Source0 = R"(
+    void foo() {
+      int arr[8];
+      #pragma omp target update to(arr[0:8:2])
+      ;
+    }
+  )";
+  EXPECT_TRUE(matchesWithOpenMP(Source0, Matcher));
+
+  auto AST = tooling::buildASTFromCodeWithArgs(Source0, {"-fopenmp"});
+  ASSERT_TRUE(AST);
+
+  auto Results = match(ompTargetUpdateDirective().bind("directive"),
+                       AST->getASTContext());
+  ASSERT_EQ(Results.size(), 1u);
+
+  const auto *Directive =
+      Results[0].getNodeAs<OMPTargetUpdateDirective>("directive");
+  ASSERT_TRUE(Directive);
+
+  const OMPToClause *ToClause = nullptr;
+  for (const OMPClause *Clause : Directive->clauses()) {
+    if ((ToClause = dyn_cast<OMPToClause>(Clause)))
+      break;
+  }
+  ASSERT_TRUE(ToClause);
+
+  // The clause lists exactly one item. Requiring it to be an array section
+  // keeps the checks below from passing vacuously.
+  ASSERT_EQ(ToClause->varlist_size(), 1u);
+  const auto *ArraySection =
+      dyn_cast<ArraySectionExpr>(*ToClause->varlist_begin());
+  ASSERT_TRUE(ArraySection);
+
+  // base (arr); implicit casts, if any, are looked through.
+  const Expr *Base = ArraySection->getBase();
+  ASSERT_TRUE(Base);
+  const auto *BaseRef = dyn_cast<DeclRefExpr>(Base->IgnoreParenImpCasts());
+  ASSERT_TRUE(BaseRef);
+  EXPECT_EQ(BaseRef->getDecl()->getName(), "arr");
+
+  // The bounds below use dyn_cast + ASSERT_TRUE so that a node of the wrong
+  // kind fails the test instead of being silently skipped.
+
+  // lower bound (0)
+  const Expr *LowerBound = ArraySection->getLowerBound();
+  ASSERT_TRUE(LowerBound);
+  const auto *LowerBoundLiteral = dyn_cast<IntegerLiteral>(LowerBound);
+  ASSERT_TRUE(LowerBoundLiteral);
+  EXPECT_EQ(LowerBoundLiteral->getValue().getZExtValue(), 0u);
+
+  // length (8)
+  const Expr *Length = ArraySection->getLength();
+  ASSERT_TRUE(Length);
+  const auto *LengthLiteral = dyn_cast<IntegerLiteral>(Length);
+  ASSERT_TRUE(LengthLiteral);
+  EXPECT_EQ(LengthLiteral->getValue().getZExtValue(), 8u);
+
+  // stride (2)
+  const Expr *Stride = ArraySection->getStride();
+  ASSERT_TRUE(Stride);
+  const auto *StrideLiteral = dyn_cast<IntegerLiteral>(Stride);
+  ASSERT_TRUE(StrideLiteral);
+  EXPECT_EQ(StrideLiteral->getValue().getZExtValue(), 2u);
+}
+
+TEST(ASTMatchersTestOpenMP, OMPFromClause_DoesNotMatchToClause) {
+  // The directive below carries only a `to` clause, so a matcher that
+  // requires a `from` clause must reject it.
+  StringRef ToOnlySource = R"(
+    void foo() {
+      int arr[8];
+      #pragma omp target update to(arr[0:8:2])
+      ;
+    }
+  )";
+  auto RequiresFromClause =
+      ompTargetUpdateDirective(hasAnyClause(ompFromClause()));
+  EXPECT_TRUE(notMatchesWithOpenMP(ToOnlySource, RequiresFromClause));
+}
+
+TEST(ASTMatchersTestOpenMP, OMPToClause_DoesNotMatchFromClause) {
+  // The directive below carries only a `from` clause, so a matcher that
+  // requires a `to` clause must reject it.
+  StringRef FromOnlySource = R"(
+    void foo() {
+      int arr[8];
+      #pragma omp target update from(arr[0:8:2])
+      ;
+    }
+  )";
+  auto RequiresToClause =
+      ompTargetUpdateDirective(hasAnyClause(ompToClause()));
+  EXPECT_TRUE(notMatchesWithOpenMP(FromOnlySource, RequiresToClause));
+}
+
+TEST(ASTMatchersTestOpenMP, OMPFromClause_ArraySection_DifferentOffsetValue) {
+  // A non-zero lower bound in a `from` clause array section must appear in
+  // the AST as the integer literal 7.
+  StringRef Source0 = R"(
+    void foo() {
+      int arr[20];
+      #pragma omp target update from(arr[7:8:2])
+      ;
+    }
+  )";
+
+  auto AST = tooling::buildASTFromCodeWithArgs(Source0, {"-fopenmp"});
+  ASSERT_TRUE(AST);
+
+  auto Results = match(ompTargetUpdateDirective().bind("directive"),
+                       AST->getASTContext());
+  ASSERT_EQ(Results.size(), 1u);
+
+  const auto *Directive =
+      Results[0].getNodeAs<OMPTargetUpdateDirective>("directive");
+  ASSERT_TRUE(Directive);
+
+  const OMPFromClause *FromClause = nullptr;
+  for (const OMPClause *Clause : Directive->clauses()) {
+    if ((FromClause = dyn_cast<OMPFromClause>(Clause)))
+      break;
+  }
+  ASSERT_TRUE(FromClause);
+
+  // Exactly one list item is expected. Requiring it to be an array section
+  // keeps the test from passing vacuously.
+  ASSERT_EQ(FromClause->varlist_size(), 1u);
+  const auto *ArraySection =
+      dyn_cast<ArraySectionExpr>(*FromClause->varlist_begin());
+  ASSERT_TRUE(ArraySection);
+
+  // lower bound (7, not the 0 used by the baseline test)
+  const Expr *LowerBound = ArraySection->getLowerBound();
+  ASSERT_TRUE(LowerBound);
+  // cast<> checks the type only via an assertion, which release builds
+  // compile out; dyn_cast + ASSERT_TRUE fails the test cleanly instead.
+  const auto *LowerBoundLiteral = dyn_cast<IntegerLiteral>(LowerBound);
+  ASSERT_TRUE(LowerBoundLiteral);
+  EXPECT_EQ(LowerBoundLiteral->getValue().getZExtValue(), 7u);
+}
+
+TEST(ASTMatchersTestOpenMP, OMPFromClause_ArraySection_DifferentLengthValue) {
+  // A length other than 8 in a `from` clause array section must appear in
+  // the AST as the integer literal 15.
+  StringRef Source0 = R"(
+    void foo() {
+      int arr[20];
+      #pragma omp target update from(arr[0:15:2])
+      ;
+    }
+  )";
+
+  auto AST = tooling::buildASTFromCodeWithArgs(Source0, {"-fopenmp"});
+  ASSERT_TRUE(AST);
+
+  auto Results = match(ompTargetUpdateDirective().bind("directive"),
+                       AST->getASTContext());
+  ASSERT_EQ(Results.size(), 1u);
+
+  const auto *Directive =
+      Results[0].getNodeAs<OMPTargetUpdateDirective>("directive");
+  ASSERT_TRUE(Directive);
+
+  const OMPFromClause *FromClause = nullptr;
+  for (const OMPClause *Clause : Directive->clauses()) {
+    if ((FromClause = dyn_cast<OMPFromClause>(Clause)))
+      break;
+  }
+  ASSERT_TRUE(FromClause);
+
+  // Exactly one list item is expected. Requiring it to be an array section
+  // keeps the test from passing vacuously.
+  ASSERT_EQ(FromClause->varlist_size(), 1u);
+  const auto *ArraySection =
+      dyn_cast<ArraySectionExpr>(*FromClause->varlist_begin());
+  ASSERT_TRUE(ArraySection);
+
+  // length (15, not the 8 used by the baseline test)
+  const Expr *Length = ArraySection->getLength();
+  ASSERT_TRUE(Length);
+  // cast<> checks the type only via an assertion, which release builds
+  // compile out; dyn_cast + ASSERT_TRUE fails the test cleanly instead.
+  const auto *LengthLiteral = dyn_cast<IntegerLiteral>(Length);
+  ASSERT_TRUE(LengthLiteral);
+  EXPECT_EQ(LengthLiteral->getValue().getZExtValue(), 15u);
+}
+
+TEST(ASTMatchersTestOpenMP, OMPFromClause_ArraySection_DifferentStrideValue) {
+  // A stride other than 2 in a `from` clause array section must appear in
+  // the AST as the integer literal 5.
+  StringRef Source0 = R"(
+    void foo() {
+      int arr[20];
+      #pragma omp target update from(arr[0:8:5])
+      ;
+    }
+  )";
+
+  auto AST = tooling::buildASTFromCodeWithArgs(Source0, {"-fopenmp"});
+  ASSERT_TRUE(AST);
+
+  auto Results = match(ompTargetUpdateDirective().bind("directive"),
+                       AST->getASTContext());
+  ASSERT_EQ(Results.size(), 1u);
+
+  const auto *Directive =
+      Results[0].getNodeAs<OMPTargetUpdateDirective>("directive");
+  ASSERT_TRUE(Directive);
+
+  const OMPFromClause *FromClause = nullptr;
+  for (const OMPClause *Clause : Directive->clauses()) {
+    if ((FromClause = dyn_cast<OMPFromClause>(Clause)))
+      break;
+  }
+  ASSERT_TRUE(FromClause);
+
+  // Exactly one list item is expected. Requiring it to be an array section
+  // keeps the test from passing vacuously.
+  ASSERT_EQ(FromClause->varlist_size(), 1u);
+  const auto *ArraySection =
+      dyn_cast<ArraySectionExpr>(*FromClause->varlist_begin());
+  ASSERT_TRUE(ArraySection);
+
+  // stride (5, not the 2 used by the baseline test)
+  const Expr *Stride = ArraySection->getStride();
+  ASSERT_TRUE(Stride);
+  // cast<> checks the type only via an assertion, which release builds
+  // compile out; dyn_cast + ASSERT_TRUE fails the test cleanly instead.
+  const auto *StrideLiteral = dyn_cast<IntegerLiteral>(Stride);
+  ASSERT_TRUE(StrideLiteral);
+  EXPECT_EQ(StrideLiteral->getValue().getZExtValue(), 5u);
+}
+
+TEST(ASTMatchersTestOpenMP, OMPToClause_ArraySection_DifferentOffsetValue) {
+  // A non-zero lower bound in a `to` clause array section must appear in
+  // the AST as the integer literal 4.
+  StringRef Source0 = R"(
+    void foo() {
+      int arr[20];
+      #pragma omp target update to(arr[4:8:2])
+      ;
+    }
+  )";
+
+  auto AST = tooling::buildASTFromCodeWithArgs(Source0, {"-fopenmp"});
+  ASSERT_TRUE(AST);
+
+  auto Results = match(ompTargetUpdateDirective().bind("directive"),
+                       AST->getASTContext());
+  ASSERT_EQ(Results.size(), 1u);
+
+  const auto *Directive =
+      Results[0].getNodeAs<OMPTargetUpdateDirective>("directive");
+  ASSERT_TRUE(Directive);
+
+  const OMPToClause *ToClause = nullptr;
+  for (const OMPClause *Clause : Directive->clauses()) {
+    if ((ToClause = dyn_cast<OMPToClause>(Clause)))
+      break;
+  }
+  ASSERT_TRUE(ToClause);
+
+  // Exactly one list item is expected. Requiring it to be an array section
+  // keeps the test from passing vacuously.
+  ASSERT_EQ(ToClause->varlist_size(), 1u);
+  const auto *ArraySection =
+      dyn_cast<ArraySectionExpr>(*ToClause->varlist_begin());
+  ASSERT_TRUE(ArraySection);
+
+  // lower bound (4, not the 0 used by the baseline test)
+  const Expr *LowerBound = ArraySection->getLowerBound();
+  ASSERT_TRUE(LowerBound);
+  // cast<> checks the type only via an assertion, which release builds
+  // compile out; dyn_cast + ASSERT_TRUE fails the test cleanly instead.
+  const auto *LowerBoundLiteral = dyn_cast<IntegerLiteral>(LowerBound);
+  ASSERT_TRUE(LowerBoundLiteral);
+  EXPECT_EQ(LowerBoundLiteral->getValue().getZExtValue(), 4u);
+}
+
+TEST(ASTMatchersTestOpenMP, OMPToClause_ArraySection_DifferentLengthValue) {
+  // A length other than 8 in a `to` clause array section must appear in the
+  // AST as the integer literal 20.
+  StringRef Source0 = R"(
+    void foo() {
+      int arr[20];
+      #pragma omp target update to(arr[0:20:2])
+      ;
+    }
+  )";
+
+  auto AST = tooling::buildASTFromCodeWithArgs(Source0, {"-fopenmp"});
+  ASSERT_TRUE(AST);
+
+  auto Results = match(ompTargetUpdateDirective().bind("directive"),
+                       AST->getASTContext());
+  ASSERT_EQ(Results.size(), 1u);
+
+  const auto *Directive =
+      Results[0].getNodeAs<OMPTargetUpdateDirective>("directive");
+  ASSERT_TRUE(Directive);
+
+  const OMPToClause *ToClause = nullptr;
+  for (const OMPClause *Clause : Directive->clauses()) {
+    if ((ToClause = dyn_cast<OMPToClause>(Clause)))
+      break;
+  }
+  ASSERT_TRUE(ToClause);
+
+  // Exactly one list item is expected. Requiring it to be an array section
+  // keeps the test from passing vacuously.
+  ASSERT_EQ(ToClause->varlist_size(), 1u);
+  const auto *ArraySection =
+      dyn_cast<ArraySectionExpr>(*ToClause->varlist_begin());
+  ASSERT_TRUE(ArraySection);
+
+  // length (20, not the 8 used by the baseline test)
+  const Expr *Length = ArraySection->getLength();
+  ASSERT_TRUE(Length);
+  // cast<> checks the type only via an assertion, which release builds
+  // compile out; dyn_cast + ASSERT_TRUE fails the test cleanly instead.
+  const auto *LengthLiteral = dyn_cast<IntegerLiteral>(Length);
+  ASSERT_TRUE(LengthLiteral);
+  EXPECT_EQ(LengthLiteral->getValue().getZExtValue(), 20u);
+}
+
+TEST(ASTMatchersTestOpenMP, OMPToClause_ArraySection_DifferentStrideValue) {
+  StringRef Source0 = R"(
+    void foo() {
+      int arr[20];
+      #pragma omp target update to(arr[0:8:6])
+      ;
+    }
+  )";
+
+  auto astUnit = tooling::buildASTFromCodeWithArgs(Source0, {"-fopenmp"});
+  ASSERT_TRUE(astUnit);
+
+  auto Results = match(ompTargetUpdateDirective().bind("directive"),
+                       astUnit->getASTContext());
+  ASSERT_FALSE(Results.empty());
+
+  const auto *Directive =
+      Results[0].getNodeAs<OMPTargetUpdateDirective>("directive");
+  ASSERT_TRUE(Directive);
+
+  OMPToClause *ToClause = nullptr;
+  for (auto *Clause : Directive->clauses()) {
+    if ((ToClause = dyn_cast<OMPToClause>(Clause))) {
+      break;
+    }
+  }
+  ASSERT_TRUE(ToClause);
+
+  for (const auto *VarExpr : ToClause->varlist()) {
+    const auto *ArraySection = dyn_cast<ArraySectionExpr>(VarExpr);
+    if (!ArraySection)
+      continue;
+
+    // stride (6, not 2)
+    const Expr *Stride = ArraySection->getStride();
+    ASSERT_TRUE(Stride);
+    ASSERT_TRUE(isa<IntegerLiteral>(Stride));
----------------
shiltian wrote:

This assertion is redundant as well, because `cast<...>` will crash if the type 
doesn't match.

https://github.com/llvm/llvm-project/pull/150580
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to