llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

Before this change: `notifyOperationReplaced` was triggered when calling 
`RewriterBase::replaceOp`.
After this change: `notifyOperationReplaced` is triggered when 
`RewriterBase::replaceAllOpUsesWith` or `RewriterBase::replaceOp` is called.

Until now, every `notifyOperationReplaced` was sent together with a 
`notifyOperationErased`, which made that accompanying `notifyOperationErased` 
callback redundant. More importantly, when a user called 
`RewriterBase::replaceAllOpUsesWith`+`RewriterBase::eraseOp` instead of 
`RewriterBase::replaceOp`, no `notifyOperationReplaced` callback was sent, even 
though the two notations are semantically equivalent. As an example, this can 
be a problem when applying patterns with the transform dialect because the 
`TrackingListener` will only see the `notifyOperationErased` callback and the 
payload op is dropped from the mappings.

Note: It is still possible to write semantically equivalent code that does not 
trigger a `notifyOperationReplaced` (e.g., when op results are replaced 
one-by-one), but this commit already improves the situation a lot.

---
Full diff: https://github.com/llvm/llvm-project/pull/84721.diff


3 Files Affected:

- (modified) mlir/include/mlir/IR/PatternMatch.h (+17-12) 
- (modified) mlir/lib/IR/PatternMatch.cpp (+16-8) 
- (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+4-1) 


``````````diff
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 8d84ab6100007e..c1408c3f90a53b 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -409,9 +409,9 @@ class RewriterBase : public OpBuilder {
     /// Notify the listener that the specified operation was modified in-place.
     virtual void notifyOperationModified(Operation *op) {}
 
-    /// Notify the listener that the specified operation is about to be replaced
-    /// with another operation. This is called before the uses of the old
-    /// operation have been changed.
+    /// Notify the listener that all uses of the specified operation's results
+    /// are about to be replaced with the results of another operation. This is
+    /// called before the uses of the old operation have been changed.
     ///
     /// By default, this function calls the "operation replaced with values"
     /// notification.
@@ -420,9 +420,10 @@ class RewriterBase : public OpBuilder {
       notifyOperationReplaced(op, replacement->getResults());
     }
 
-    /// Notify the listener that the specified operation is about to be replaced
-    /// with the a range of values, potentially produced by other operations.
-    /// This is called before the uses of the operation have been changed.
+    /// Notify the listener that all uses of the specified operation's results
+    /// are about to be replaced with a range of values, potentially
+    /// produced by other operations. This is called before the uses of the
+    /// operation have been changed.
     virtual void notifyOperationReplaced(Operation *op,
                                          ValueRange replacement) {}
 
@@ -628,12 +629,16 @@ class RewriterBase : public OpBuilder {
     for (auto it : llvm::zip(from, to))
       replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
   }
-  // Note: This function cannot be called `replaceAllUsesWith` because the
-  // overload resolution, when called with an op that can be implicitly
-  // converted to a Value, would be ambiguous.
-  void replaceAllOpUsesWith(Operation *from, ValueRange to) {
-    replaceAllUsesWith(from->getResults(), to);
-  }
+
+  /// Find uses of `from` and replace them with `to`. Also notify the listener
+  /// about every in-place op modification (for every use that was replaced)
+  /// and that the `from` operation is about to be replaced.
+  ///
+  /// Note: This function cannot be called `replaceAllUsesWith` because the
+  /// overload resolution, when called with an op that can be implicitly
+  /// converted to a Value, would be ambiguous.
+  void replaceAllOpUsesWith(Operation *from, ValueRange to);
+  void replaceAllOpUsesWith(Operation *from, Operation *to);
 
   /// Find uses of `from` and replace them with `to` if the `functor` returns
   /// true. Also notify the listener about every in-place op modification (for
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 4079ccc7567256..5944a0ea46a143 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -110,6 +110,22 @@ RewriterBase::~RewriterBase() {
   // Out of line to provide a vtable anchor for the class.
 }
 
+void RewriterBase::replaceAllOpUsesWith(Operation *from, ValueRange to) {
+  // Notify the listener that we're about to replace this op.
+  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
+    rewriteListener->notifyOperationReplaced(from, to);
+
+  replaceAllUsesWith(from->getResults(), to);
+}
+
+void RewriterBase::replaceAllOpUsesWith(Operation *from, Operation *to) {
+  // Notify the listener that we're about to replace this op.
+  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
+    rewriteListener->notifyOperationReplaced(from, to);
+
+  replaceAllUsesWith(from->getResults(), to->getResults());
+}
+
 /// This method replaces the results of the operation with the specified list of
 /// values. The number of provided values must match the number of results of
 /// the operation. The replaced op is erased.
@@ -117,10 +133,6 @@ void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
   assert(op->getNumResults() == newValues.size() &&
          "incorrect # of replacement values");
 
-  // Notify the listener that we're about to replace this op.
-  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
-    rewriteListener->notifyOperationReplaced(op, newValues);
-
   // Replace all result uses. Also notifies the listener of modifications.
   replaceAllOpUsesWith(op, newValues);
 
@@ -136,10 +148,6 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
   assert(op->getNumResults() == newOp->getNumResults() &&
          "ops have different number of results");
 
-  // Notify the listener that we're about to replace this op.
-  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
-    rewriteListener->notifyOperationReplaced(op, newOp);
-
   // Replace all result uses. Also notifies the listener of modifications.
   replaceAllOpUsesWith(op, newOp->getResults());
 
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 2da184bc3d85ba..76dc825fe44515 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -489,7 +489,10 @@ struct TestStrictPatternDriver
             OperationName("test.new_op", op->getContext()).getIdentifier(),
             op->getOperands(), op->getResultTypes());
       }
-      rewriter.replaceOp(op, newOp->getResults());
+      // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp".
+      // A "notifyOperationReplaced" callback is triggered in either case.
+      rewriter.replaceAllOpUsesWith(op, newOp->getResults());
+      rewriter.eraseOp(op);
       return success();
     }
   };

``````````

</details>


https://github.com/llvm/llvm-project/pull/84721
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to