ymandel created this revision.
ymandel added a reviewer: ilya-biryukov.
Herald added a project: clang.

This revision adds a new kind of rewrite rule, `CompositeRewriteRule`, which
composes multiple subrules into a single rule that performs ordered choice
among them: the first subrule whose pattern matches is the one applied. With
this feature, users can write later subrules knowing that the patterns of
earlier subrules *have not matched*, freeing them from reasoning about those
cases in the current pattern.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D61335

Files:
  clang/include/clang/Tooling/Refactoring/Transformer.h
  clang/lib/Tooling/Refactoring/Transformer.cpp
  clang/unittests/Tooling/TransformerTest.cpp

Index: clang/unittests/Tooling/TransformerTest.cpp
===================================================================
--- clang/unittests/Tooling/TransformerTest.cpp
+++ clang/unittests/Tooling/TransformerTest.cpp
@@ -116,7 +116,8 @@
     };
   }
 
-  void testRule(RewriteRule Rule, StringRef Input, StringRef Expected) {
+  template <typename R>
+  void testRule(R Rule, StringRef Input, StringRef Expected) {
     Transformer T(std::move(Rule), consumer());
     T.registerMatchers(&MatchFinder);
     compareSnippets(Expected, rewrite(Input));
@@ -375,6 +376,92 @@
            Input, Expected);
 }
 
+TEST_F(TransformerTest, OrderedRuleUnrelated) {
+  StringRef FlagId = "flag";
+  auto FlagType = cxxRecordDecl(hasName("proto::ProtoCommandLineFlag"));
+  auto NotGetProto = unless(callee(cxxMethodDecl(hasName("GetProto"))));
+  RewriteRule FlagRule = makeRule(
+      cxxMemberCallExpr(on(expr(hasType(FlagType)).bind(FlagId)), NotGetProto),
+      change<clang::Expr>(FlagId, "PROTO"));
+
+  const std::string Input = R"cc(
+    proto::ProtoCommandLineFlag flag;
+    int x = flag.foo();
+    int y = flag.GetProto().foo();
+    int f(string s) { return strlen(s.c_str()); }
+  )cc";
+  const std::string Expected = R"cc(
+    proto::ProtoCommandLineFlag flag;
+    int x = PROTO.foo();
+    int y = flag.GetProto().foo();
+    int f(string s) { return REPLACED; }
+  )cc";
+
+  auto Rule = makeOrderedRule({ruleStrlenSize(), FlagRule});
+  testRule(std::move(Rule), Input, Expected);
+}
+
+// Variant of ruleStrlenSize that matches `c_str()` on an expression of any
+// type and inserts "DISTINCT", so the two rules' effects can be told apart.
+RewriteRule ruleStrlenSizeDistinct() {
+  StringRef S = "s";
+  return makeRule(
+      callExpr(callee(functionDecl(hasName("strlen"))),
+               hasArgument(0, cxxMemberCallExpr(
+                                  on(expr().bind(S)),
+                                  callee(cxxMethodDecl(hasName("c_str")))))),
+      change<clang::Expr>("DISTINCT"));
+}
+
+TEST_F(TransformerTest, OrderedRuleRelated) {
+  const std::string Input = R"cc(
+    namespace foo {
+    struct mystring {
+      char* c_str();
+    };
+    int f(mystring s) { return strlen(s.c_str()); }
+    }  // namespace foo
+    int g(string s) { return strlen(s.c_str()); }
+  )cc";
+  const std::string Expected = R"cc(
+    namespace foo {
+    struct mystring {
+      char* c_str();
+    };
+    int f(mystring s) { return DISTINCT; }
+    }  // namespace foo
+    int g(string s) { return REPLACED; }
+  )cc";
+
+  auto Rule = makeOrderedRule({ruleStrlenSize(), ruleStrlenSizeDistinct()});
+  testRule(std::move(Rule), Input, Expected);
+}
+
+// Swapping the rule order lets the first (more general) rule win both calls.
+TEST_F(TransformerTest, OrderedRuleRelatedSwapped) {
+  const std::string Input = R"cc(
+    namespace foo {
+    struct mystring {
+      char* c_str();
+    };
+    int f(mystring s) { return strlen(s.c_str()); }
+    }  // namespace foo
+    int g(string s) { return strlen(s.c_str()); }
+  )cc";
+  const std::string Expected = R"cc(
+    namespace foo {
+    struct mystring {
+      char* c_str();
+    };
+    int f(mystring s) { return DISTINCT; }
+    }  // namespace foo
+    int g(string s) { return DISTINCT; }
+  )cc";
+
+  auto Rule = makeOrderedRule({ruleStrlenSizeDistinct(), ruleStrlenSize()});
+  testRule(std::move(Rule), Input, Expected);
+}
+
 //
 // Negative tests (where we expect no transformation to occur).
 //
Index: clang/lib/Tooling/Refactoring/Transformer.cpp
===================================================================
--- clang/lib/Tooling/Refactoring/Transformer.cpp
+++ clang/lib/Tooling/Refactoring/Transformer.cpp
@@ -28,6 +28,7 @@
 using namespace tooling;
 
 using ast_matchers::MatchFinder;
+using ast_matchers::internal::DynTypedMatcher;
 using ast_type_traits::ASTNodeKind;
 using ast_type_traits::DynTypedNode;
 using llvm::Error;
@@ -171,7 +172,7 @@
   return Transformations;
 }
 
-RewriteRule tooling::makeRule(ast_matchers::internal::DynTypedMatcher M,
+RewriteRule tooling::makeRule(DynTypedMatcher M,
                               SmallVector<ASTEdit, 1> Edits) {
   M.setAllowBind(true);
   // `tryBind` is guaranteed to succeed, because `AllowBind` was set to true.
@@ -181,6 +182,80 @@
 
 constexpr llvm::StringLiteral RewriteRule::RootId;
 
+// True if a matcher of kind A converts implicitly to a matcher of kind B,
+// as Matcher<> allows (Base -> Derived, and Type -> QualType).
+static bool isHigher(ASTNodeKind A, ASTNodeKind B) {
+  if (A.isBaseOf(B))
+    return true;
+  static const auto TypeKind = ASTNodeKind::getFromNodeKind<Type>();
+  static const auto QualKind = ASTNodeKind::getFromNodeKind<QualType>();
+  return A.isSame(TypeKind) && B.isSame(QualKind);
+}
+
+// Returns the most derived node kind to which every rule's matcher can be
+// converted, or a none kind if `Rules` is empty or two of the matchers' kinds
+// are unrelated (neither converts to the other).
+static ASTNodeKind findCommonKind(const std::vector<RewriteRule> &Rules) {
+  assert(!Rules.empty());
+  // Also guard release builds, where `Rules[0]` below would be out of bounds.
+  if (Rules.empty())
+    return ASTNodeKind();
+  ASTNodeKind JoinKind = Rules[0].Matcher.getSupportedKind();
+  for (const auto &R : Rules) {
+    ASTNodeKind K = R.Matcher.getSupportedKind();
+    if (isHigher(JoinKind, K)) {
+      // Every matcher seen so far also converts to K, which is lower.
+      JoinKind = K;
+    } else if (!K.isSame(JoinKind) && !isHigher(K, JoinKind)) {
+      // K and JoinKind are unrelated -- there is no least common kind.
+      return ASTNodeKind();
+    }
+  }
+  return JoinKind;
+}
+
+// Returns the matchers of `Rules` in their original order, the I-th one bound
+// to the ID `TagBase` followed by I. The IDs are unique and deterministic,
+// so a match result can be traced back to the rule whose matcher produced it.
+static std::vector<DynTypedMatcher>
+taggedMatchers(StringRef TagBase, const std::vector<RewriteRule> &Rules) {
+  std::vector<DynTypedMatcher> Result;
+  Result.reserve(Rules.size());
+  for (size_t I = 0, N = Rules.size(); I < N; ++I) {
+    std::string Tag = (TagBase + Twine(I)).str();
+    auto Bound = Rules[I].Matcher.tryBind(Tag);
+    assert(Bound && "RewriteRule matchers should be bindable.");
+    Result.push_back(*std::move(Bound));
+  }
+
+  return Result;
+}
+
+CompositeRewriteRule tooling::makeOrderedRule(std::vector<RewriteRule> Rules) {
+  ASTNodeKind CommonKind = findCommonKind(Rules);
+  assert(!CommonKind.isNone() && "Rules must have compatible matchers.");
+  // Build the tagged `anyOf` before `Rules` is moved into the result.
+  DynTypedMatcher Combined = DynTypedMatcher::constructVariadic(
+      DynTypedMatcher::VO_AnyOf, CommonKind, taggedMatchers("Tag", Rules));
+  return CompositeRewriteRule{std::move(Combined), std::move(Rules)};
+}
+
+// Returns the subrule whose matcher produced `Result`. A single-rule composite
+// (e.g. one built from a plain RewriteRule) has no tags, so it is returned
+// directly; otherwise the "Tag<I>" IDs bound by makeOrderedRule are checked.
+const RewriteRule &tooling::findSelectedRule(const CompositeRewriteRule &Rule,
+                                             const MatchResult &Result) {
+  const std::vector<RewriteRule> &Subrules = Rule.Rules;
+  if (Subrules.size() == 1)
+    return Subrules.front();
+  const auto &Bound = Result.Nodes.getMap();
+  for (size_t I = 0; I != Subrules.size(); ++I) {
+    if (Bound.count(("Tag" + Twine(I)).str()))
+      return Subrules[I];
+  }
+  llvm_unreachable("No tag found for rule set.");
+}
+
 void Transformer::registerMatchers(MatchFinder *MatchFinder) {
   MatchFinder->addDynamicMatcher(Rule.Matcher, this);
 }
@@ -197,7 +272,8 @@
       Root->second.getSourceRange().getBegin());
   assert(RootLoc.isValid() && "Invalid location for Root node of match.");
 
-  auto Transformations = translateEdits(Result, Rule.Edits);
+  auto Transformations =
+      translateEdits(Result, findSelectedRule(Rule, Result).Edits);
   if (!Transformations) {
     Consumer(Transformations.takeError());
     return;
Index: clang/include/clang/Tooling/Refactoring/Transformer.h
===================================================================
--- clang/include/clang/Tooling/Refactoring/Transformer.h
+++ clang/include/clang/Tooling/Refactoring/Transformer.h
@@ -221,6 +221,26 @@
 translateEdits(const ast_matchers::MatchFinder::MatchResult &Result,
                llvm::ArrayRef<ASTEdit> Edits);
 
+/// Composes multiple rules into a single object that can be registered with a
+/// single matcher. Upon a match, the tags bound by that matcher let
+/// \c findSelectedRule determine which rule in \c Rules to apply.
+struct CompositeRewriteRule {
+  /// Matcher that multiplexes the composed rules.
+  ast_matchers::internal::DynTypedMatcher Matcher;
+  /// The composed rules, in order of precedence (earlier rules win).
+  std::vector<RewriteRule> Rules;
+};
+
+/// Creates a composite rule that applies the first rule in \c Rules whose
+/// pattern matches. For any two rules, their matchers' node kinds must be the
+/// same or one a base of the other (e.g. \c Stmt and \c Expr), not siblings.
+CompositeRewriteRule makeOrderedRule(std::vector<RewriteRule> Rules);
+
+/// Returns the subrule of \p Rule whose matcher produced \p Result.
+const RewriteRule &
+findSelectedRule(const CompositeRewriteRule &Rule,
+                 const ast_matchers::MatchFinder::MatchResult &Result);
+
 /// Handles the matcher and callback registration for a single rewrite rule, as
 /// defined by the arguments of the constructor.
 class Transformer : public ast_matchers::MatchFinder::MatchCallback {
@@ -233,9 +253,13 @@
   /// because of macros, but doesn't fail.  Note that clients are responsible
   /// for handling the case that independent \c AtomicChanges conflict with each
   /// other.
-  Transformer(RewriteRule Rule, ChangeConsumer Consumer)
+  Transformer(CompositeRewriteRule Rule, ChangeConsumer Consumer)
       : Rule(std::move(Rule)), Consumer(std::move(Consumer)) {}
 
+  /// Wraps a single rule; its matcher is registered as-is, without tags.
+  Transformer(RewriteRule Rule, ChangeConsumer Consumer)
+      : Rule{Rule.Matcher, {std::move(Rule)}}, Consumer(std::move(Consumer)) {}
+
   /// N.B. Passes `this` pointer to `MatchFinder`.  So, this object should not
   /// be moved after this call.
   void registerMatchers(ast_matchers::MatchFinder *MatchFinder);
@@ -245,7 +269,7 @@
   void run(const ast_matchers::MatchFinder::MatchResult &Result) override;
 
 private:
-  RewriteRule Rule;
+  CompositeRewriteRule Rule;
   /// Receives each successful rewrites as an \c AtomicChange.
   ChangeConsumer Consumer;
 };
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to