[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)

2024-03-07 Thread Matthias Springer via llvm-branch-commits

https://github.com/matthias-springer updated 
https://github.com/llvm/llvm-project/pull/84131

>From a65d640a0ca2c6810da0878ed42db39f27cebfe1 Mon Sep 17 00:00:00 2001
From: Matthias Springer 
Date: Fri, 8 Mar 2024 07:19:33 +
Subject: [PATCH] [mlir][IR] Add listener notifications for pattern begin/end

---
 mlir/include/mlir/IR/PatternMatch.h   | 30 ++---
 .../Transforms/Utils/DialectConversion.cpp| 29 +++-
 .../Utils/GreedyPatternRewriteDriver.cpp  | 33 +--
 3 files changed, 69 insertions(+), 23 deletions(-)

diff --git a/mlir/include/mlir/IR/PatternMatch.h 
b/mlir/include/mlir/IR/PatternMatch.h
index e3500b3f9446d8..49544c42790d4d 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -432,11 +432,22 @@ class RewriterBase : public OpBuilder {
 /// Note: This notification is not triggered when unlinking an operation.
 virtual void notifyOperationErased(Operation *op) {}
 
-/// Notify the listener that the pattern failed to match the given
-/// operation, and provide a callback to populate a diagnostic with the
-/// reason why the failure occurred. This method allows for derived
-/// listeners to optionally hook into the reason why a rewrite failed, and
-/// display it to users.
+/// Notify the listener that the specified pattern is about to be applied
+/// at the specified root operation.
+virtual void notifyPatternBegin(const Pattern , Operation *op) {}
+
+/// Notify the listener that a pattern application finished with the
+/// specified status. "success" indicates that the pattern was applied
+/// successfully. "failure" indicates that the pattern could not be
+/// applied. The pattern may have communicated the reason for the failure
+/// with `notifyMatchFailure`.
+virtual void notifyPatternEnd(const Pattern ,
+  LogicalResult status) {}
+
+/// Notify the listener that the pattern failed to match, and provide a
+/// callback to populate a diagnostic with the reason why the failure
+/// occurred. This method allows for derived listeners to optionally hook
+/// into the reason why a rewrite failed, and display it to users.
 virtual void
 notifyMatchFailure(Location loc,
function_ref reasonCallback) {}
@@ -478,6 +489,15 @@ class RewriterBase : public OpBuilder {
   if (auto *rewriteListener = dyn_cast(listener))
 rewriteListener->notifyOperationErased(op);
 }
+void notifyPatternBegin(const Pattern , Operation *op) override {
+  if (auto *rewriteListener = dyn_cast(listener))
+rewriteListener->notifyPatternBegin(pattern, op);
+}
+void notifyPatternEnd(const Pattern ,
+  LogicalResult status) override {
+  if (auto *rewriteListener = dyn_cast(listener))
+rewriteListener->notifyPatternEnd(pattern, status);
+}
 void notifyMatchFailure(
 Location loc,
 function_ref reasonCallback) override {
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp 
b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index c1a261eab8487d..cd49bd121a62e5 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1856,7 +1856,8 @@ class OperationLegalizer {
   using LegalizationAction = ConversionTarget::LegalizationAction;
 
   OperationLegalizer(const ConversionTarget ,
- const FrozenRewritePatternSet );
+ const FrozenRewritePatternSet ,
+ const ConversionConfig );
 
   /// Returns true if the given operation is known to be illegal on the target.
   bool isIllegal(Operation *op) const;
@@ -1948,12 +1949,16 @@ class OperationLegalizer {
 
   /// The pattern applicator to use for conversions.
   PatternApplicator applicator;
+
+  /// Dialect conversion configuration.
+  const ConversionConfig 
 };
 } // namespace
 
 OperationLegalizer::OperationLegalizer(const ConversionTarget ,
-   const FrozenRewritePatternSet )
-: target(targetInfo), applicator(patterns) {
+   const FrozenRewritePatternSet ,
+   const ConversionConfig )
+: target(targetInfo), applicator(patterns), config(config) {
   // The set of patterns that can be applied to illegal operations to transform
   // them into legal ones.
   DenseMap legalizerPatterns;
@@ -2098,7 +2103,10 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
 
   // Functor that returns if the given pattern may be applied.
   auto canApply = [&](const Pattern ) {
-return canApplyPattern(op, pattern, rewriter);
+bool canApply = canApplyPattern(op, pattern, rewriter);
+if (canApply && config.listener)
+  config.listener->notifyPatternBegin(pattern, op);
+return canApply;
   };
 
   // Functor that 

[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)

2024-03-07 Thread Mehdi Amini via llvm-branch-commits

https://github.com/joker-eph approved this pull request.


https://github.com/llvm/llvm-project/pull/84131
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)

2024-03-07 Thread Mehdi Amini via llvm-branch-commits


@@ -572,20 +571,33 @@ bool GreedyPatternRewriteDriver::processWorklist() {
 logger.getOStream() << ")' {\n";
 logger.indent();
   });
+  if (config.listener)
+config.listener->notifyPatternBegin(pattern, op);
   return true;
 };
-auto onFailure = [&](const Pattern ) {
-  LLVM_DEBUG(logResult("failure", "pattern failed to match"));
-};
-auto onSuccess = [&](const Pattern ) {
-  LLVM_DEBUG(logResult("success", "pattern applied successfully"));
-  return success();
-};
-#else
-function_ref canApply = {};
-function_ref onFailure = {};
-function_ref onSuccess = {};
-#endif
+function_ref onFailure =
+[&](const Pattern ) {
+  LLVM_DEBUG(logResult("failure", "pattern failed to match"));
+  if (config.listener)
+config.listener->notifyPatternEnd(pattern, failure());
+};
+function_ref onSuccess =
+[&](const Pattern ) {
+  LLVM_DEBUG(logResult("success", "pattern applied successfully"));
+  if (config.listener)
+config.listener->notifyPatternEnd(pattern, success());
+  return success();
+};
+
+#ifdef NDEBUG
+// Optimization: PatternApplicator callbacks are not needed when running in
+// optimized mode and without a listener.
+if (!config.listener) {
+  canApply = nullptr;
+  onFailure = nullptr;
+  onSuccess = nullptr;
+}
+#endif // NDEBUG

joker-eph wrote:

Note: I didn't suggest changing this, what you had here was reasonable!

https://github.com/llvm/llvm-project/pull/84131
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)

2024-03-07 Thread Matthias Springer via llvm-branch-commits


@@ -68,9 +68,9 @@ class PatternApplicator {
   ///invalidate the match and try another pattern.
   LogicalResult
   matchAndRewrite(Operation *op, PatternRewriter ,
-  function_ref canApply = {},
-  function_ref onFailure = {},
-  function_ref onSuccess = {});
+  std::function canApply = {},
+  std::function onFailure = {},
+  std::function onSuccess = 
{});

matthias-springer wrote:

That's a good way to think about it.


https://github.com/llvm/llvm-project/pull/84131
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)

2024-03-07 Thread Matthias Springer via llvm-branch-commits

https://github.com/matthias-springer updated 
https://github.com/llvm/llvm-project/pull/84131

>From 24a56caffaa23f6da73b129ca96f28e9a9bbf388 Mon Sep 17 00:00:00 2001
From: Matthias Springer 
Date: Fri, 8 Mar 2024 07:19:33 +
Subject: [PATCH] [mlir][IR] Add listener notifications for pattern begin/end

---
 mlir/include/mlir/IR/PatternMatch.h   | 30 ++---
 .../Transforms/Utils/DialectConversion.cpp| 29 +
 .../Utils/GreedyPatternRewriteDriver.cpp  | 42 ---
 3 files changed, 73 insertions(+), 28 deletions(-)

diff --git a/mlir/include/mlir/IR/PatternMatch.h 
b/mlir/include/mlir/IR/PatternMatch.h
index e3500b3f9446d8..49544c42790d4d 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -432,11 +432,22 @@ class RewriterBase : public OpBuilder {
 /// Note: This notification is not triggered when unlinking an operation.
 virtual void notifyOperationErased(Operation *op) {}
 
-/// Notify the listener that the pattern failed to match the given
-/// operation, and provide a callback to populate a diagnostic with the
-/// reason why the failure occurred. This method allows for derived
-/// listeners to optionally hook into the reason why a rewrite failed, and
-/// display it to users.
+/// Notify the listener that the specified pattern is about to be applied
+/// at the specified root operation.
+virtual void notifyPatternBegin(const Pattern , Operation *op) {}
+
+/// Notify the listener that a pattern application finished with the
+/// specified status. "success" indicates that the pattern was applied
+/// successfully. "failure" indicates that the pattern could not be
+/// applied. The pattern may have communicated the reason for the failure
+/// with `notifyMatchFailure`.
+virtual void notifyPatternEnd(const Pattern ,
+  LogicalResult status) {}
+
+/// Notify the listener that the pattern failed to match, and provide a
+/// callback to populate a diagnostic with the reason why the failure
+/// occurred. This method allows for derived listeners to optionally hook
+/// into the reason why a rewrite failed, and display it to users.
 virtual void
 notifyMatchFailure(Location loc,
function_ref reasonCallback) {}
@@ -478,6 +489,15 @@ class RewriterBase : public OpBuilder {
   if (auto *rewriteListener = dyn_cast(listener))
 rewriteListener->notifyOperationErased(op);
 }
+void notifyPatternBegin(const Pattern , Operation *op) override {
+  if (auto *rewriteListener = dyn_cast(listener))
+rewriteListener->notifyPatternBegin(pattern, op);
+}
+void notifyPatternEnd(const Pattern ,
+  LogicalResult status) override {
+  if (auto *rewriteListener = dyn_cast(listener))
+rewriteListener->notifyPatternEnd(pattern, status);
+}
 void notifyMatchFailure(
 Location loc,
 function_ref reasonCallback) override {
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp 
b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index c1a261eab8487d..cd49bd121a62e5 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1856,7 +1856,8 @@ class OperationLegalizer {
   using LegalizationAction = ConversionTarget::LegalizationAction;
 
   OperationLegalizer(const ConversionTarget ,
- const FrozenRewritePatternSet );
+ const FrozenRewritePatternSet ,
+ const ConversionConfig );
 
   /// Returns true if the given operation is known to be illegal on the target.
   bool isIllegal(Operation *op) const;
@@ -1948,12 +1949,16 @@ class OperationLegalizer {
 
   /// The pattern applicator to use for conversions.
   PatternApplicator applicator;
+
+  /// Dialect conversion configuration.
+  const ConversionConfig 
 };
 } // namespace
 
 OperationLegalizer::OperationLegalizer(const ConversionTarget ,
-   const FrozenRewritePatternSet )
-: target(targetInfo), applicator(patterns) {
+   const FrozenRewritePatternSet ,
+   const ConversionConfig )
+: target(targetInfo), applicator(patterns), config(config) {
   // The set of patterns that can be applied to illegal operations to transform
   // them into legal ones.
   DenseMap legalizerPatterns;
@@ -2098,7 +2103,10 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
 
   // Functor that returns if the given pattern may be applied.
   auto canApply = [&](const Pattern ) {
-return canApplyPattern(op, pattern, rewriter);
+bool canApply = canApplyPattern(op, pattern, rewriter);
+if (canApply && config.listener)
+  config.listener->notifyPatternBegin(pattern, op);
+return canApply;
   };
 
   // Functor that cleans up 

[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)

2024-03-07 Thread Mehdi Amini via llvm-branch-commits

https://github.com/joker-eph edited 
https://github.com/llvm/llvm-project/pull/84131
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)

2024-03-07 Thread Mehdi Amini via llvm-branch-commits

https://github.com/joker-eph edited 
https://github.com/llvm/llvm-project/pull/84131
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)

2024-03-07 Thread Mehdi Amini via llvm-branch-commits


@@ -68,9 +68,9 @@ class PatternApplicator {
   ///invalidate the match and try another pattern.
   LogicalResult
   matchAndRewrite(Operation *op, PatternRewriter ,
-  function_ref canApply = {},
-  function_ref onFailure = {},
-  function_ref onSuccess = {});
+  std::function canApply = {},
+  std::function onFailure = {},
+  std::function onSuccess = 
{});

joker-eph wrote:

> What are you referring to with this function?

Where this comment thread is anchored: `matchAndRewrite`

> The problem here is really just caused by the fact that the canApply = 
> assignment is inside of a nested scope. And the lambda object is dead by the 
> time matcher.matchAndRewrite is called. I.e., the canApply function_ref 
> points to an already free'd lambda. At least that's my understanding.

Yes, but that's a problem for the call-site; I don't quite see how you make 
the connection to the signature of `matchAndRewrite`.


https://github.com/llvm/llvm-project/pull/84131
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)

2024-03-07 Thread Matthias Springer via llvm-branch-commits

https://github.com/matthias-springer edited 
https://github.com/llvm/llvm-project/pull/84131
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)

2024-03-07 Thread Matthias Springer via llvm-branch-commits


@@ -68,9 +68,9 @@ class PatternApplicator {
   ///invalidate the match and try another pattern.
   LogicalResult
   matchAndRewrite(Operation *op, PatternRewriter ,
-  function_ref canApply = {},
-  function_ref onFailure = {},
-  function_ref onSuccess = {});
+  std::function canApply = {},
+  std::function onFailure = {},
+  std::function onSuccess = 
{});

matthias-springer wrote:

What are you referring to with `this function`?

The problem here is really just caused by the fact that the `canApply =` 
assignment is inside of a nested scope, so the lambda object is dead by the 
time `matcher.matchAndRewrite` is called. I.e., the `canApply` `function_ref` 
points to an already-destroyed lambda. At least that's my understanding.

What are the C++ guidelines regarding `std::function` vs. `function_ref`? This 
is the first time I have run into such an issue, and assigning lambdas to 
`function_ref` feels "dangerous" to me now. When using `std::function`, I don't 
have to think about the lifetime of the callable object.


https://github.com/llvm/llvm-project/pull/84131
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)

2024-03-07 Thread Mehdi Amini via llvm-branch-commits


@@ -68,9 +68,9 @@ class PatternApplicator {
   ///invalidate the match and try another pattern.
   LogicalResult
   matchAndRewrite(Operation *op, PatternRewriter ,
-  function_ref canApply = {},
-  function_ref onFailure = {},
-  function_ref onSuccess = {});
+  std::function canApply = {},
+  std::function onFailure = {},
+  std::function onSuccess = 
{});

joker-eph wrote:

That can explain why you changed it at the call-site, but I'm puzzled about 
this function: it does not capture the callback as far as I can tell.

https://github.com/llvm/llvm-project/pull/84131
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)

2024-03-07 Thread Matthias Springer via llvm-branch-commits


@@ -68,9 +68,9 @@ class PatternApplicator {
   ///invalidate the match and try another pattern.
   LogicalResult
   matchAndRewrite(Operation *op, PatternRewriter ,
-  function_ref canApply = {},
-  function_ref onFailure = {},
-  function_ref onSuccess = {});
+  std::function canApply = {},
+  std::function onFailure = {},
+  std::function onSuccess = 
{});

matthias-springer wrote:

(The existing code seemed to care about performance here; There were different 
`canApply` etc. depending on `NDEBUG`.)


https://github.com/llvm/llvm-project/pull/84131
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)

2024-03-07 Thread Matthias Springer via llvm-branch-commits

https://github.com/matthias-springer edited 
https://github.com/llvm/llvm-project/pull/84131
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)

2024-03-07 Thread Matthias Springer via llvm-branch-commits

https://github.com/matthias-springer edited 
https://github.com/llvm/llvm-project/pull/84131
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)

2024-03-07 Thread Matthias Springer via llvm-branch-commits


@@ -68,9 +68,9 @@ class PatternApplicator {
   ///invalidate the match and try another pattern.
   LogicalResult
   matchAndRewrite(Operation *op, PatternRewriter ,
-  function_ref canApply = {},
-  function_ref onFailure = {},
-  function_ref onSuccess = {});
+  std::function canApply = {},
+  std::function onFailure = {},
+  std::function onSuccess = 
{});

matthias-springer wrote:

There is a `stack-use-after-scope` (reported by ASAN, also crashes in opt mode) 
with `function_ref`. The callback in the greedy pattern rewriter is defined 
inside of an `if` check:
```c++
function_ref canApply = {};
function_ref onFailure = {};
function_ref onSuccess = {};
bool debugBuild = false;
#ifndef NDEBUG
debugBuild = true;
#endif // NDEBUG
if (debugBuild || config.listener) {
  canApply = [&](const Pattern ) {
if (this->config.listener) { /* ... */ }
// ...
  }
  // ...
}

// `canApply` points to a lambda that is out of scope.
LogicalResult matchResult =
matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess);
```

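`function_ref` is a "non-owning function wrapper", but the lambda captures 
`this`, so it is a stateful object that must stay alive until it is called. 
The temporary lambda is destroyed before `matchAndRewrite` runs, which leaves 
the `function_ref` dangling.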

Changing to `std::function` is one way to fix it. I could also just always pass 
a lambda. That would actually be my preferred solution, but there is a slight 
overhead when running in opt mode and without a listener, because the callback 
would always be called (even if it does not do anything):
```c++
LogicalResult matchResult = matcher.matchAndRewrite(
op, *this,
/*canApply=*/[&](const Pattern ) {
  if (this->listener) { /* ... */ }
  // ...
},
/*onFailure=*/[&](const Pattern ) { /* ... */},
/*onSuccess=*/[&](const Pattern ) { /* ... */});
```

What do you think?


https://github.com/llvm/llvm-project/pull/84131
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)

2024-03-07 Thread Mehdi Amini via llvm-branch-commits


@@ -562,30 +562,39 @@ bool GreedyPatternRewriteDriver::processWorklist() {
 // Try to match one of the patterns. The rewriter is automatically
 // notified of any necessary changes, so there is nothing else to do
 // here.
+std::function canApply = nullptr;
+std::function onFailure = nullptr;
+std::function onSuccess = nullptr;
+bool debugBuild = false;
 #ifndef NDEBUG
-auto canApply = [&](const Pattern ) {
-  LLVM_DEBUG({
-logger.getOStream() << "\n";
-logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '"
-   << op->getName() << " -> (";
-llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
-logger.getOStream() << ")' {\n";
-logger.indent();
-  });
-  return true;
-};
-auto onFailure = [&](const Pattern ) {
-  LLVM_DEBUG(logResult("failure", "pattern failed to match"));
-};
-auto onSuccess = [&](const Pattern ) {
-  LLVM_DEBUG(logResult("success", "pattern applied successfully"));
-  return success();
-};
-#else
-function_ref canApply = {};
-function_ref onFailure = {};
-function_ref onSuccess = {};
-#endif
+debugBuild = true;

joker-eph wrote:

Oh never mind I see!

https://github.com/llvm/llvm-project/pull/84131
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)

2024-03-07 Thread Mehdi Amini via llvm-branch-commits


@@ -562,30 +562,39 @@ bool GreedyPatternRewriteDriver::processWorklist() {
 // Try to match one of the patterns. The rewriter is automatically
 // notified of any necessary changes, so there is nothing else to do
 // here.
+std::function canApply = nullptr;
+std::function onFailure = nullptr;
+std::function onSuccess = nullptr;
+bool debugBuild = false;
 #ifndef NDEBUG
-auto canApply = [&](const Pattern ) {
-  LLVM_DEBUG({
-logger.getOStream() << "\n";
-logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '"
-   << op->getName() << " -> (";
-llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
-logger.getOStream() << ")' {\n";
-logger.indent();
-  });
-  return true;
-};
-auto onFailure = [&](const Pattern ) {
-  LLVM_DEBUG(logResult("failure", "pattern failed to match"));
-};
-auto onSuccess = [&](const Pattern ) {
-  LLVM_DEBUG(logResult("success", "pattern applied successfully"));
-  return success();
-};
-#else
-function_ref canApply = {};
-function_ref onFailure = {};
-function_ref onSuccess = {};
-#endif
+debugBuild = true;

joker-eph wrote:

Why change the structure of the code with this variable?

https://github.com/llvm/llvm-project/pull/84131
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)

2024-03-07 Thread Mehdi Amini via llvm-branch-commits


@@ -68,9 +68,9 @@ class PatternApplicator {
   ///invalidate the match and try another pattern.
   LogicalResult
   matchAndRewrite(Operation *op, PatternRewriter ,
-  function_ref canApply = {},
-  function_ref onFailure = {},
-  function_ref onSuccess = {});
+  std::function canApply = {},
+  std::function onFailure = {},
+  std::function onSuccess = 
{});

joker-eph wrote:

Why this change?

https://github.com/llvm/llvm-project/pull/84131
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)

2024-03-07 Thread Matthias Springer via llvm-branch-commits

https://github.com/matthias-springer updated 
https://github.com/llvm/llvm-project/pull/84131

>From 0aef4b91f6aad0335e7eae2849edffd4338f4c40 Mon Sep 17 00:00:00 2001
From: Matthias Springer 
Date: Fri, 8 Mar 2024 03:42:15 +
Subject: [PATCH] [mlir][IR] Add listener notifications for pattern begin/end

---
 mlir/include/mlir/IR/PatternMatch.h   | 30 --
 mlir/include/mlir/Rewrite/PatternApplicator.h |  6 +-
 mlir/lib/Rewrite/PatternApplicator.cpp|  6 +-
 .../Transforms/Utils/DialectConversion.cpp| 29 +++---
 .../Utils/GreedyPatternRewriteDriver.cpp  | 57 +++
 5 files changed, 85 insertions(+), 43 deletions(-)

diff --git a/mlir/include/mlir/IR/PatternMatch.h 
b/mlir/include/mlir/IR/PatternMatch.h
index e3500b3f9446d8..49544c42790d4d 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -432,11 +432,22 @@ class RewriterBase : public OpBuilder {
 /// Note: This notification is not triggered when unlinking an operation.
 virtual void notifyOperationErased(Operation *op) {}
 
-/// Notify the listener that the pattern failed to match the given
-/// operation, and provide a callback to populate a diagnostic with the
-/// reason why the failure occurred. This method allows for derived
-/// listeners to optionally hook into the reason why a rewrite failed, and
-/// display it to users.
+/// Notify the listener that the specified pattern is about to be applied
+/// at the specified root operation.
+virtual void notifyPatternBegin(const Pattern , Operation *op) {}
+
+/// Notify the listener that a pattern application finished with the
+/// specified status. "success" indicates that the pattern was applied
+/// successfully. "failure" indicates that the pattern could not be
+/// applied. The pattern may have communicated the reason for the failure
+/// with `notifyMatchFailure`.
+virtual void notifyPatternEnd(const Pattern ,
+  LogicalResult status) {}
+
+/// Notify the listener that the pattern failed to match, and provide a
+/// callback to populate a diagnostic with the reason why the failure
+/// occurred. This method allows for derived listeners to optionally hook
+/// into the reason why a rewrite failed, and display it to users.
 virtual void
 notifyMatchFailure(Location loc,
function_ref reasonCallback) {}
@@ -478,6 +489,15 @@ class RewriterBase : public OpBuilder {
   if (auto *rewriteListener = dyn_cast(listener))
 rewriteListener->notifyOperationErased(op);
 }
+void notifyPatternBegin(const Pattern , Operation *op) override {
+  if (auto *rewriteListener = dyn_cast(listener))
+rewriteListener->notifyPatternBegin(pattern, op);
+}
+void notifyPatternEnd(const Pattern ,
+  LogicalResult status) override {
+  if (auto *rewriteListener = dyn_cast(listener))
+rewriteListener->notifyPatternEnd(pattern, status);
+}
 void notifyMatchFailure(
 Location loc,
 function_ref reasonCallback) override {
diff --git a/mlir/include/mlir/Rewrite/PatternApplicator.h 
b/mlir/include/mlir/Rewrite/PatternApplicator.h
index f7871f819a273b..c767bf8fee9073 100644
--- a/mlir/include/mlir/Rewrite/PatternApplicator.h
+++ b/mlir/include/mlir/Rewrite/PatternApplicator.h
@@ -68,9 +68,9 @@ class PatternApplicator {
   ///invalidate the match and try another pattern.
   LogicalResult
   matchAndRewrite(Operation *op, PatternRewriter ,
-  function_ref canApply = {},
-  function_ref onFailure = {},
-  function_ref onSuccess = {});
+  std::function canApply = {},
+  std::function onFailure = {},
+  std::function onSuccess = 
{});
 
   /// Apply a cost model to the patterns within this applicator.
   void applyCostModel(CostModel model);
diff --git a/mlir/lib/Rewrite/PatternApplicator.cpp 
b/mlir/lib/Rewrite/PatternApplicator.cpp
index ea43f8a147d479..fecfb030a77fbf 100644
--- a/mlir/lib/Rewrite/PatternApplicator.cpp
+++ b/mlir/lib/Rewrite/PatternApplicator.cpp
@@ -129,9 +129,9 @@ void PatternApplicator::walkAllPatterns(
 
 LogicalResult PatternApplicator::matchAndRewrite(
 Operation *op, PatternRewriter ,
-function_ref canApply,
-function_ref onFailure,
-function_ref onSuccess) {
+std::function canApply,
+std::function onFailure,
+std::function onSuccess) {
   // Before checking native patterns, first match against the bytecode. This
   // won't automatically perform any rewrites so there is no need to worry 
about
   // conflicts.
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp 
b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index c1a261eab8487d..cd49bd121a62e5 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ 

[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)

2024-03-07 Thread Matthias Springer via llvm-branch-commits

https://github.com/matthias-springer updated 
https://github.com/llvm/llvm-project/pull/84131

>From 407c7f7c24a87f409a921328605cc93637386d19 Mon Sep 17 00:00:00 2001
From: Matthias Springer 
Date: Fri, 8 Mar 2024 02:01:50 +
Subject: [PATCH] [mlir][IR] Add listener notifications for pattern begin/end

---
 mlir/include/mlir/IR/PatternMatch.h   | 30 +--
 .../Transforms/Utils/DialectConversion.cpp| 29 +++---
 .../Utils/GreedyPatternRewriteDriver.cpp  | 53 +++
 3 files changed, 77 insertions(+), 35 deletions(-)

diff --git a/mlir/include/mlir/IR/PatternMatch.h 
b/mlir/include/mlir/IR/PatternMatch.h
index e3500b3f9446d8..49544c42790d4d 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -432,11 +432,22 @@ class RewriterBase : public OpBuilder {
 /// Note: This notification is not triggered when unlinking an operation.
 virtual void notifyOperationErased(Operation *op) {}
 
-/// Notify the listener that the pattern failed to match the given
-/// operation, and provide a callback to populate a diagnostic with the
-/// reason why the failure occurred. This method allows for derived
-/// listeners to optionally hook into the reason why a rewrite failed, and
-/// display it to users.
+/// Notify the listener that the specified pattern is about to be applied
+/// at the specified root operation.
+virtual void notifyPatternBegin(const Pattern , Operation *op) {}
+
+/// Notify the listener that a pattern application finished with the
+/// specified status. "success" indicates that the pattern was applied
+/// successfully. "failure" indicates that the pattern could not be
+/// applied. The pattern may have communicated the reason for the failure
+/// with `notifyMatchFailure`.
+virtual void notifyPatternEnd(const Pattern ,
+  LogicalResult status) {}
+
+/// Notify the listener that the pattern failed to match, and provide a
+/// callback to populate a diagnostic with the reason why the failure
+/// occurred. This method allows for derived listeners to optionally hook
+/// into the reason why a rewrite failed, and display it to users.
 virtual void
 notifyMatchFailure(Location loc,
function_ref reasonCallback) {}
@@ -478,6 +489,15 @@ class RewriterBase : public OpBuilder {
   if (auto *rewriteListener = dyn_cast(listener))
 rewriteListener->notifyOperationErased(op);
 }
+void notifyPatternBegin(const Pattern , Operation *op) override {
+  if (auto *rewriteListener = dyn_cast(listener))
+rewriteListener->notifyPatternBegin(pattern, op);
+}
+void notifyPatternEnd(const Pattern ,
+  LogicalResult status) override {
+  if (auto *rewriteListener = dyn_cast(listener))
+rewriteListener->notifyPatternEnd(pattern, status);
+}
 void notifyMatchFailure(
 Location loc,
 function_ref reasonCallback) override {
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp 
b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index c1a261eab8487d..cd49bd121a62e5 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1856,7 +1856,8 @@ class OperationLegalizer {
   using LegalizationAction = ConversionTarget::LegalizationAction;
 
   OperationLegalizer(const ConversionTarget ,
- const FrozenRewritePatternSet );
+ const FrozenRewritePatternSet ,
+ const ConversionConfig );
 
   /// Returns true if the given operation is known to be illegal on the target.
   bool isIllegal(Operation *op) const;
@@ -1948,12 +1949,16 @@ class OperationLegalizer {
 
   /// The pattern applicator to use for conversions.
   PatternApplicator applicator;
+
+  /// Dialect conversion configuration.
+  const ConversionConfig 
 };
 } // namespace
 
 OperationLegalizer::OperationLegalizer(const ConversionTarget ,
-   const FrozenRewritePatternSet )
-: target(targetInfo), applicator(patterns) {
+   const FrozenRewritePatternSet ,
+   const ConversionConfig )
+: target(targetInfo), applicator(patterns), config(config) {
   // The set of patterns that can be applied to illegal operations to transform
   // them into legal ones.
   DenseMap legalizerPatterns;
@@ -2098,7 +2103,10 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
 
   // Functor that returns if the given pattern may be applied.
   auto canApply = [&](const Pattern ) {
-return canApplyPattern(op, pattern, rewriter);
+bool canApply = canApplyPattern(op, pattern, rewriter);
+if (canApply && config.listener)
+  config.listener->notifyPatternBegin(pattern, op);
+return canApply;
   };
 
   // Functor that cleans up the 

[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)

2024-03-06 Thread via llvm-branch-commits

llvmbot wrote:




@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)


Changes

This commit adds two new notifications to `RewriterBase::Listener`:
* `notifyPatternBegin`: Called when a pattern application begins during a 
greedy pattern rewrite or dialect conversion.
* `notifyPatternEnd`: Called when a pattern application finishes during a 
greedy pattern rewrite or dialect conversion.

The listener infrastructure already provides a `notifyMatchFailure` callback 
that reports the reason for a pattern match failure. The two new 
notifications provide additional information about pattern applications.

This change is in preparation for improving the handle update mechanism in the 
`apply_conversion_patterns` transform op.


---
Full diff: https://github.com/llvm/llvm-project/pull/84131.diff


3 Files Affected:

- (modified) mlir/include/mlir/IR/PatternMatch.h (+25-5) 
- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+21-8) 
- (modified) mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp (+31-22) 


```diff
diff --git a/mlir/include/mlir/IR/PatternMatch.h 
b/mlir/include/mlir/IR/PatternMatch.h
index f8d22cfb22afd0..838b4947648f5e 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -432,11 +432,22 @@ class RewriterBase : public OpBuilder {
 /// Note: This notification is not triggered when unlinking an operation.
 virtual void notifyOperationErased(Operation *op) {}
 
-/// Notify the listener that the pattern failed to match the given
-/// operation, and provide a callback to populate a diagnostic with the
-/// reason why the failure occurred. This method allows for derived
-/// listeners to optionally hook into the reason why a rewrite failed, and
-/// display it to users.
+/// Notify the listener that the specified pattern is about to be applied
+/// at the specified root operation.
+virtual void notifyPatternBegin(const Pattern , Operation *op) {}
+
+/// Notify the listener that a pattern application finished with the
+/// specified status. "success" indicates that the pattern was applied
+/// successfully. "failure" indicates that the pattern could not be
+/// applied. The pattern may have communicated the reason for the failure
+/// with `notifyMatchFailure`.
+virtual void notifyPatternEnd(const Pattern ,
+  LogicalResult status) {}
+
+/// Notify the listener that the pattern failed to match, and provide a
+/// callback to populate a diagnostic with the reason why the failure
+/// occurred. This method allows for derived listeners to optionally hook
+/// into the reason why a rewrite failed, and display it to users.
 virtual void
 notifyMatchFailure(Location loc,
function_ref reasonCallback) {}
@@ -478,6 +489,15 @@ class RewriterBase : public OpBuilder {
   if (auto *rewriteListener = dyn_cast(listener))
 rewriteListener->notifyOperationErased(op);
 }
+void notifyPatternBegin(const Pattern , Operation *op) override {
+  if (auto *rewriteListener = dyn_cast(listener))
+rewriteListener->notifyPatternBegin(pattern, op);
+}
+void notifyPatternEnd(const Pattern ,
+  LogicalResult status) override {
+  if (auto *rewriteListener = dyn_cast(listener))
+rewriteListener->notifyPatternEnd(pattern, status);
+}
 void notifyMatchFailure(
 Location loc,
 function_ref reasonCallback) override {
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp 
b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index a5145246bc30c4..587fbe209b58af 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1863,7 +1863,8 @@ class OperationLegalizer {
   using LegalizationAction = ConversionTarget::LegalizationAction;
 
   OperationLegalizer(const ConversionTarget ,
- const FrozenRewritePatternSet );
+ const FrozenRewritePatternSet ,
+ const ConversionConfig );
 
   /// Returns true if the given operation is known to be illegal on the target.
   bool isIllegal(Operation *op) const;
@@ -1955,12 +1956,16 @@ class OperationLegalizer {
 
   /// The pattern applicator to use for conversions.
   PatternApplicator applicator;
+
+  /// Dialect conversion configuration.
+  const ConversionConfig 
 };
 } // namespace
 
 OperationLegalizer::OperationLegalizer(const ConversionTarget ,
-   const FrozenRewritePatternSet )
-: target(targetInfo), applicator(patterns) {
+   const FrozenRewritePatternSet ,
+   const ConversionConfig )
+: target(targetInfo), applicator(patterns), config(config) {
   // The set of patterns that can be applied to illegal operations to 

[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)

2024-03-06 Thread Matthias Springer via llvm-branch-commits

https://github.com/matthias-springer created 
https://github.com/llvm/llvm-project/pull/84131

This commit adds two new notifications to `RewriterBase::Listener`:
* `notifyPatternBegin`: Called when a pattern application begins during a 
greedy pattern rewrite or dialect conversion.
* `notifyPatternEnd`: Called when a pattern application finishes during a 
greedy pattern rewrite or dialect conversion.

The listener infrastructure already provides a `notifyMatchFailure` callback 
that reports the reason for a pattern match failure. The two new 
notifications provide additional information about pattern applications.

This change is in preparation for improving the handle update mechanism in the 
`apply_conversion_patterns` transform op.


>From 3ecec094d43a4eb9405ad057c32e4579f1fbe680 Mon Sep 17 00:00:00 2001
From: Matthias Springer 
Date: Wed, 6 Mar 2024 08:00:35 +
Subject: [PATCH] [mlir][IR] Add listener notifications for pattern begin/end

---
 mlir/include/mlir/IR/PatternMatch.h   | 30 +--
 .../Transforms/Utils/DialectConversion.cpp| 29 +++---
 .../Utils/GreedyPatternRewriteDriver.cpp  | 53 +++
 3 files changed, 77 insertions(+), 35 deletions(-)

diff --git a/mlir/include/mlir/IR/PatternMatch.h 
b/mlir/include/mlir/IR/PatternMatch.h
index f8d22cfb22afd0..838b4947648f5e 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -432,11 +432,22 @@ class RewriterBase : public OpBuilder {
 /// Note: This notification is not triggered when unlinking an operation.
 virtual void notifyOperationErased(Operation *op) {}
 
-/// Notify the listener that the pattern failed to match the given
-/// operation, and provide a callback to populate a diagnostic with the
-/// reason why the failure occurred. This method allows for derived
-/// listeners to optionally hook into the reason why a rewrite failed, and
-/// display it to users.
+/// Notify the listener that the specified pattern is about to be applied
+/// at the specified root operation.
+virtual void notifyPatternBegin(const Pattern , Operation *op) {}
+
+/// Notify the listener that a pattern application finished with the
+/// specified status. "success" indicates that the pattern was applied
+/// successfully. "failure" indicates that the pattern could not be
+/// applied. The pattern may have communicated the reason for the failure
+/// with `notifyMatchFailure`.
+virtual void notifyPatternEnd(const Pattern ,
+  LogicalResult status) {}
+
+/// Notify the listener that the pattern failed to match, and provide a
+/// callback to populate a diagnostic with the reason why the failure
+/// occurred. This method allows for derived listeners to optionally hook
+/// into the reason why a rewrite failed, and display it to users.
 virtual void
 notifyMatchFailure(Location loc,
function_ref reasonCallback) {}
@@ -478,6 +489,15 @@ class RewriterBase : public OpBuilder {
   if (auto *rewriteListener = dyn_cast(listener))
 rewriteListener->notifyOperationErased(op);
 }
+void notifyPatternBegin(const Pattern , Operation *op) override {
+  if (auto *rewriteListener = dyn_cast(listener))
+rewriteListener->notifyPatternBegin(pattern, op);
+}
+void notifyPatternEnd(const Pattern ,
+  LogicalResult status) override {
+  if (auto *rewriteListener = dyn_cast(listener))
+rewriteListener->notifyPatternEnd(pattern, status);
+}
 void notifyMatchFailure(
 Location loc,
 function_ref reasonCallback) override {
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp 
b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index a5145246bc30c4..587fbe209b58af 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1863,7 +1863,8 @@ class OperationLegalizer {
   using LegalizationAction = ConversionTarget::LegalizationAction;
 
   OperationLegalizer(const ConversionTarget ,
- const FrozenRewritePatternSet );
+ const FrozenRewritePatternSet ,
+ const ConversionConfig );
 
   /// Returns true if the given operation is known to be illegal on the target.
   bool isIllegal(Operation *op) const;
@@ -1955,12 +1956,16 @@ class OperationLegalizer {
 
   /// The pattern applicator to use for conversions.
   PatternApplicator applicator;
+
+  /// Dialect conversion configuration.
+  const ConversionConfig 
 };
 } // namespace
 
 OperationLegalizer::OperationLegalizer(const ConversionTarget ,
-   const FrozenRewritePatternSet )
-: target(targetInfo), applicator(patterns) {
+   const FrozenRewritePatternSet ,
+   const ConversionConfig )
+