[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/84131 >From a65d640a0ca2c6810da0878ed42db39f27cebfe1 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Fri, 8 Mar 2024 07:19:33 + Subject: [PATCH] [mlir][IR] Add listener notifications for pattern begin/end --- mlir/include/mlir/IR/PatternMatch.h | 30 ++--- .../Transforms/Utils/DialectConversion.cpp| 29 +++- .../Utils/GreedyPatternRewriteDriver.cpp | 33 +-- 3 files changed, 69 insertions(+), 23 deletions(-) diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index e3500b3f9446d8..49544c42790d4d 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -432,11 +432,22 @@ class RewriterBase : public OpBuilder { /// Note: This notification is not triggered when unlinking an operation. virtual void notifyOperationErased(Operation *op) {} -/// Notify the listener that the pattern failed to match the given -/// operation, and provide a callback to populate a diagnostic with the -/// reason why the failure occurred. This method allows for derived -/// listeners to optionally hook into the reason why a rewrite failed, and -/// display it to users. +/// Notify the listener that the specified pattern is about to be applied +/// at the specified root operation. +virtual void notifyPatternBegin(const Pattern , Operation *op) {} + +/// Notify the listener that a pattern application finished with the +/// specified status. "success" indicates that the pattern was applied +/// successfully. "failure" indicates that the pattern could not be +/// applied. The pattern may have communicated the reason for the failure +/// with `notifyMatchFailure`. +virtual void notifyPatternEnd(const Pattern , + LogicalResult status) {} + +/// Notify the listener that the pattern failed to match, and provide a +/// callback to populate a diagnostic with the reason why the failure +/// occurred. 
This method allows for derived listeners to optionally hook +/// into the reason why a rewrite failed, and display it to users. virtual void notifyMatchFailure(Location loc, function_ref reasonCallback) {} @@ -478,6 +489,15 @@ class RewriterBase : public OpBuilder { if (auto *rewriteListener = dyn_cast(listener)) rewriteListener->notifyOperationErased(op); } +void notifyPatternBegin(const Pattern , Operation *op) override { + if (auto *rewriteListener = dyn_cast(listener)) +rewriteListener->notifyPatternBegin(pattern, op); +} +void notifyPatternEnd(const Pattern , + LogicalResult status) override { + if (auto *rewriteListener = dyn_cast(listener)) +rewriteListener->notifyPatternEnd(pattern, status); +} void notifyMatchFailure( Location loc, function_ref reasonCallback) override { diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index c1a261eab8487d..cd49bd121a62e5 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1856,7 +1856,8 @@ class OperationLegalizer { using LegalizationAction = ConversionTarget::LegalizationAction; OperationLegalizer(const ConversionTarget , - const FrozenRewritePatternSet ); + const FrozenRewritePatternSet , + const ConversionConfig ); /// Returns true if the given operation is known to be illegal on the target. bool isIllegal(Operation *op) const; @@ -1948,12 +1949,16 @@ class OperationLegalizer { /// The pattern applicator to use for conversions. PatternApplicator applicator; + + /// Dialect conversion configuration. 
+ const ConversionConfig }; } // namespace OperationLegalizer::OperationLegalizer(const ConversionTarget , - const FrozenRewritePatternSet ) -: target(targetInfo), applicator(patterns) { + const FrozenRewritePatternSet , + const ConversionConfig ) +: target(targetInfo), applicator(patterns), config(config) { // The set of patterns that can be applied to illegal operations to transform // them into legal ones. DenseMap legalizerPatterns; @@ -2098,7 +2103,10 @@ OperationLegalizer::legalizeWithPattern(Operation *op, // Functor that returns if the given pattern may be applied. auto canApply = [&](const Pattern ) { -return canApplyPattern(op, pattern, rewriter); +bool canApply = canApplyPattern(op, pattern, rewriter); +if (canApply && config.listener) + config.listener->notifyPatternBegin(pattern, op); +return canApply; }; // Functor that
[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)
https://github.com/joker-eph approved this pull request. https://github.com/llvm/llvm-project/pull/84131 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)
@@ -572,20 +571,33 @@ bool GreedyPatternRewriteDriver::processWorklist() { logger.getOStream() << ")' {\n"; logger.indent(); }); + if (config.listener) +config.listener->notifyPatternBegin(pattern, op); return true; }; -auto onFailure = [&](const Pattern ) { - LLVM_DEBUG(logResult("failure", "pattern failed to match")); -}; -auto onSuccess = [&](const Pattern ) { - LLVM_DEBUG(logResult("success", "pattern applied successfully")); - return success(); -}; -#else -function_ref canApply = {}; -function_ref onFailure = {}; -function_ref onSuccess = {}; -#endif +function_ref onFailure = +[&](const Pattern ) { + LLVM_DEBUG(logResult("failure", "pattern failed to match")); + if (config.listener) +config.listener->notifyPatternEnd(pattern, failure()); +}; +function_ref onSuccess = +[&](const Pattern ) { + LLVM_DEBUG(logResult("success", "pattern applied successfully")); + if (config.listener) +config.listener->notifyPatternEnd(pattern, success()); + return success(); +}; + +#ifdef NDEBUG +// Optimization: PatternApplicator callbacks are not needed when running in +// optimized mode and without a listener. +if (!config.listener) { + canApply = nullptr; + onFailure = nullptr; + onSuccess = nullptr; +} +#endif // NDEBUG joker-eph wrote: Note: I didn't suggest changing this, what you had here was reasonable! https://github.com/llvm/llvm-project/pull/84131 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)
@@ -68,9 +68,9 @@ class PatternApplicator { ///invalidate the match and try another pattern. LogicalResult matchAndRewrite(Operation *op, PatternRewriter , - function_ref canApply = {}, - function_ref onFailure = {}, - function_ref onSuccess = {}); + std::function canApply = {}, + std::function onFailure = {}, + std::function onSuccess = {}); matthias-springer wrote: That's a good way to think about it. https://github.com/llvm/llvm-project/pull/84131 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/84131 >From 24a56caffaa23f6da73b129ca96f28e9a9bbf388 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Fri, 8 Mar 2024 07:19:33 + Subject: [PATCH] [mlir][IR] Add listener notifications for pattern begin/end --- mlir/include/mlir/IR/PatternMatch.h | 30 ++--- .../Transforms/Utils/DialectConversion.cpp| 29 + .../Utils/GreedyPatternRewriteDriver.cpp | 42 --- 3 files changed, 73 insertions(+), 28 deletions(-) diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index e3500b3f9446d8..49544c42790d4d 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -432,11 +432,22 @@ class RewriterBase : public OpBuilder { /// Note: This notification is not triggered when unlinking an operation. virtual void notifyOperationErased(Operation *op) {} -/// Notify the listener that the pattern failed to match the given -/// operation, and provide a callback to populate a diagnostic with the -/// reason why the failure occurred. This method allows for derived -/// listeners to optionally hook into the reason why a rewrite failed, and -/// display it to users. +/// Notify the listener that the specified pattern is about to be applied +/// at the specified root operation. +virtual void notifyPatternBegin(const Pattern , Operation *op) {} + +/// Notify the listener that a pattern application finished with the +/// specified status. "success" indicates that the pattern was applied +/// successfully. "failure" indicates that the pattern could not be +/// applied. The pattern may have communicated the reason for the failure +/// with `notifyMatchFailure`. +virtual void notifyPatternEnd(const Pattern , + LogicalResult status) {} + +/// Notify the listener that the pattern failed to match, and provide a +/// callback to populate a diagnostic with the reason why the failure +/// occurred. 
This method allows for derived listeners to optionally hook +/// into the reason why a rewrite failed, and display it to users. virtual void notifyMatchFailure(Location loc, function_ref reasonCallback) {} @@ -478,6 +489,15 @@ class RewriterBase : public OpBuilder { if (auto *rewriteListener = dyn_cast(listener)) rewriteListener->notifyOperationErased(op); } +void notifyPatternBegin(const Pattern , Operation *op) override { + if (auto *rewriteListener = dyn_cast(listener)) +rewriteListener->notifyPatternBegin(pattern, op); +} +void notifyPatternEnd(const Pattern , + LogicalResult status) override { + if (auto *rewriteListener = dyn_cast(listener)) +rewriteListener->notifyPatternEnd(pattern, status); +} void notifyMatchFailure( Location loc, function_ref reasonCallback) override { diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index c1a261eab8487d..cd49bd121a62e5 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1856,7 +1856,8 @@ class OperationLegalizer { using LegalizationAction = ConversionTarget::LegalizationAction; OperationLegalizer(const ConversionTarget , - const FrozenRewritePatternSet ); + const FrozenRewritePatternSet , + const ConversionConfig ); /// Returns true if the given operation is known to be illegal on the target. bool isIllegal(Operation *op) const; @@ -1948,12 +1949,16 @@ class OperationLegalizer { /// The pattern applicator to use for conversions. PatternApplicator applicator; + + /// Dialect conversion configuration. 
+ const ConversionConfig }; } // namespace OperationLegalizer::OperationLegalizer(const ConversionTarget , - const FrozenRewritePatternSet ) -: target(targetInfo), applicator(patterns) { + const FrozenRewritePatternSet , + const ConversionConfig ) +: target(targetInfo), applicator(patterns), config(config) { // The set of patterns that can be applied to illegal operations to transform // them into legal ones. DenseMap legalizerPatterns; @@ -2098,7 +2103,10 @@ OperationLegalizer::legalizeWithPattern(Operation *op, // Functor that returns if the given pattern may be applied. auto canApply = [&](const Pattern ) { -return canApplyPattern(op, pattern, rewriter); +bool canApply = canApplyPattern(op, pattern, rewriter); +if (canApply && config.listener) + config.listener->notifyPatternBegin(pattern, op); +return canApply; }; // Functor that cleans up
[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)
https://github.com/joker-eph edited https://github.com/llvm/llvm-project/pull/84131 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)
https://github.com/joker-eph edited https://github.com/llvm/llvm-project/pull/84131 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)
@@ -68,9 +68,9 @@ class PatternApplicator { ///invalidate the match and try another pattern. LogicalResult matchAndRewrite(Operation *op, PatternRewriter , - function_ref canApply = {}, - function_ref onFailure = {}, - function_ref onSuccess = {}); + std::function canApply = {}, + std::function onFailure = {}, + std::function onSuccess = {}); joker-eph wrote: > What are you referring to with this function? Where this comment thread is anchored: `matchAndRewrite` > The problem here is really just caused by the fact that the canApply = > assignment is inside of a nested scope. And the lambda object is dead by the > time matcher.matchAndRewrite is called. I.e., the canApply function_ref > points to an already free'd lambda. At least that's my understanding. Yes, but that's a problem for the call-site, I don't quite see where you make the connection to the signature of `matchAndRewrite`? https://github.com/llvm/llvm-project/pull/84131 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)
https://github.com/matthias-springer edited https://github.com/llvm/llvm-project/pull/84131 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)
@@ -68,9 +68,9 @@ class PatternApplicator { ///invalidate the match and try another pattern. LogicalResult matchAndRewrite(Operation *op, PatternRewriter , - function_ref canApply = {}, - function_ref onFailure = {}, - function_ref onSuccess = {}); + std::function canApply = {}, + std::function onFailure = {}, + std::function onSuccess = {}); matthias-springer wrote: What are you referring to with `this function`? The problem here is really just caused by the fact that the `canApply =` assignment is inside of a nested scope. And the lambda object is dead by the time `matcher.matchAndRewrite` is called. I.e., the `canApply` function_ref points to an already freed lambda. At least that's my understanding. What are the C++ guidelines wrt. `function` vs. `function_ref`? This is the first time I ran into such an issue, and assigning lambdas to `function_ref` feels "dangerous" to me now. When using `function`, I don't have to think about the lifetime of an object. https://github.com/llvm/llvm-project/pull/84131 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)
@@ -68,9 +68,9 @@ class PatternApplicator { ///invalidate the match and try another pattern. LogicalResult matchAndRewrite(Operation *op, PatternRewriter , - function_ref canApply = {}, - function_ref onFailure = {}, - function_ref onSuccess = {}); + std::function canApply = {}, + std::function onFailure = {}, + std::function onSuccess = {}); joker-eph wrote: That can explain why you changed it at the call-site, but I'm puzzled about this function: it does not capture the callback as far as I can tell. https://github.com/llvm/llvm-project/pull/84131 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)
@@ -68,9 +68,9 @@ class PatternApplicator { ///invalidate the match and try another pattern. LogicalResult matchAndRewrite(Operation *op, PatternRewriter , - function_ref canApply = {}, - function_ref onFailure = {}, - function_ref onSuccess = {}); + std::function canApply = {}, + std::function onFailure = {}, + std::function onSuccess = {}); matthias-springer wrote: (The existing code seemed to care about performance here; There were different `canApply` etc. depending on `NDEBUG`.) https://github.com/llvm/llvm-project/pull/84131 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)
https://github.com/matthias-springer edited https://github.com/llvm/llvm-project/pull/84131 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)
https://github.com/matthias-springer edited https://github.com/llvm/llvm-project/pull/84131 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)
@@ -68,9 +68,9 @@ class PatternApplicator { ///invalidate the match and try another pattern. LogicalResult matchAndRewrite(Operation *op, PatternRewriter , - function_ref canApply = {}, - function_ref onFailure = {}, - function_ref onSuccess = {}); + std::function canApply = {}, + std::function onFailure = {}, + std::function onSuccess = {}); matthias-springer wrote: There is a `stack-use-after-scope` (reported by ASAN, also crashes in opt mode) with `function_ref`. The callback in the greedy pattern rewriter is defined inside of an `if` check: ```c++ function_ref canApply = {}; function_ref onFailure = {}; function_ref onSuccess = {}; bool debugBuild = false; #ifndef NDEBUG debugBuild = true; #endif // NDEBUG if (debugBuild || config.listener) { canApply = [&](const Pattern ) { if (this->config.listener) { /* ... */ } // ... } // ... } // `canApply` points to a lambda that is out of scope. LogicalResult matchResult = matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess); ``` `function_ref` is a "non-owning function wrapper", but the lambda captures `this`. Changing to `std::function` is one way to fix it. I could also just always pass a lambda. That would actually be my preferred solution, but there is a slight overhead when running in opt mode and without listener because the callback would always be called (even if it does not do anything): ```c++ LogicalResult matchResult = matcher.matchAndRewrite( op, *this, /*canApply=*/[&](const Pattern ) { if (this->listener) { /* ... */ } // ... }, /*onFailure=*/[&](const Pattern ) { /* ... */}, /*onSuccess=*/[&](const Pattern ) { /* ... */}); ``` What do you think? https://github.com/llvm/llvm-project/pull/84131 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)
@@ -562,30 +562,39 @@ bool GreedyPatternRewriteDriver::processWorklist() { // Try to match one of the patterns. The rewriter is automatically // notified of any necessary changes, so there is nothing else to do // here. +std::function canApply = nullptr; +std::function onFailure = nullptr; +std::function onSuccess = nullptr; +bool debugBuild = false; #ifndef NDEBUG -auto canApply = [&](const Pattern ) { - LLVM_DEBUG({ -logger.getOStream() << "\n"; -logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '" - << op->getName() << " -> ("; -llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream()); -logger.getOStream() << ")' {\n"; -logger.indent(); - }); - return true; -}; -auto onFailure = [&](const Pattern ) { - LLVM_DEBUG(logResult("failure", "pattern failed to match")); -}; -auto onSuccess = [&](const Pattern ) { - LLVM_DEBUG(logResult("success", "pattern applied successfully")); - return success(); -}; -#else -function_ref canApply = {}; -function_ref onFailure = {}; -function_ref onSuccess = {}; -#endif +debugBuild = true; joker-eph wrote: Oh never mind I see! https://github.com/llvm/llvm-project/pull/84131 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)
@@ -562,30 +562,39 @@ bool GreedyPatternRewriteDriver::processWorklist() { // Try to match one of the patterns. The rewriter is automatically // notified of any necessary changes, so there is nothing else to do // here. +std::function canApply = nullptr; +std::function onFailure = nullptr; +std::function onSuccess = nullptr; +bool debugBuild = false; #ifndef NDEBUG -auto canApply = [&](const Pattern ) { - LLVM_DEBUG({ -logger.getOStream() << "\n"; -logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '" - << op->getName() << " -> ("; -llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream()); -logger.getOStream() << ")' {\n"; -logger.indent(); - }); - return true; -}; -auto onFailure = [&](const Pattern ) { - LLVM_DEBUG(logResult("failure", "pattern failed to match")); -}; -auto onSuccess = [&](const Pattern ) { - LLVM_DEBUG(logResult("success", "pattern applied successfully")); - return success(); -}; -#else -function_ref canApply = {}; -function_ref onFailure = {}; -function_ref onSuccess = {}; -#endif +debugBuild = true; joker-eph wrote: Why change the structure of the code with this variable? https://github.com/llvm/llvm-project/pull/84131 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)
@@ -68,9 +68,9 @@ class PatternApplicator { ///invalidate the match and try another pattern. LogicalResult matchAndRewrite(Operation *op, PatternRewriter , - function_ref canApply = {}, - function_ref onFailure = {}, - function_ref onSuccess = {}); + std::function canApply = {}, + std::function onFailure = {}, + std::function onSuccess = {}); joker-eph wrote: Why this change? https://github.com/llvm/llvm-project/pull/84131 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/84131 >From 0aef4b91f6aad0335e7eae2849edffd4338f4c40 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Fri, 8 Mar 2024 03:42:15 + Subject: [PATCH] [mlir][IR] Add listener notifications for pattern begin/end --- mlir/include/mlir/IR/PatternMatch.h | 30 -- mlir/include/mlir/Rewrite/PatternApplicator.h | 6 +- mlir/lib/Rewrite/PatternApplicator.cpp| 6 +- .../Transforms/Utils/DialectConversion.cpp| 29 +++--- .../Utils/GreedyPatternRewriteDriver.cpp | 57 +++ 5 files changed, 85 insertions(+), 43 deletions(-) diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index e3500b3f9446d8..49544c42790d4d 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -432,11 +432,22 @@ class RewriterBase : public OpBuilder { /// Note: This notification is not triggered when unlinking an operation. virtual void notifyOperationErased(Operation *op) {} -/// Notify the listener that the pattern failed to match the given -/// operation, and provide a callback to populate a diagnostic with the -/// reason why the failure occurred. This method allows for derived -/// listeners to optionally hook into the reason why a rewrite failed, and -/// display it to users. +/// Notify the listener that the specified pattern is about to be applied +/// at the specified root operation. +virtual void notifyPatternBegin(const Pattern , Operation *op) {} + +/// Notify the listener that a pattern application finished with the +/// specified status. "success" indicates that the pattern was applied +/// successfully. "failure" indicates that the pattern could not be +/// applied. The pattern may have communicated the reason for the failure +/// with `notifyMatchFailure`. 
+virtual void notifyPatternEnd(const Pattern , + LogicalResult status) {} + +/// Notify the listener that the pattern failed to match, and provide a +/// callback to populate a diagnostic with the reason why the failure +/// occurred. This method allows for derived listeners to optionally hook +/// into the reason why a rewrite failed, and display it to users. virtual void notifyMatchFailure(Location loc, function_ref reasonCallback) {} @@ -478,6 +489,15 @@ class RewriterBase : public OpBuilder { if (auto *rewriteListener = dyn_cast(listener)) rewriteListener->notifyOperationErased(op); } +void notifyPatternBegin(const Pattern , Operation *op) override { + if (auto *rewriteListener = dyn_cast(listener)) +rewriteListener->notifyPatternBegin(pattern, op); +} +void notifyPatternEnd(const Pattern , + LogicalResult status) override { + if (auto *rewriteListener = dyn_cast(listener)) +rewriteListener->notifyPatternEnd(pattern, status); +} void notifyMatchFailure( Location loc, function_ref reasonCallback) override { diff --git a/mlir/include/mlir/Rewrite/PatternApplicator.h b/mlir/include/mlir/Rewrite/PatternApplicator.h index f7871f819a273b..c767bf8fee9073 100644 --- a/mlir/include/mlir/Rewrite/PatternApplicator.h +++ b/mlir/include/mlir/Rewrite/PatternApplicator.h @@ -68,9 +68,9 @@ class PatternApplicator { ///invalidate the match and try another pattern. LogicalResult matchAndRewrite(Operation *op, PatternRewriter , - function_ref canApply = {}, - function_ref onFailure = {}, - function_ref onSuccess = {}); + std::function canApply = {}, + std::function onFailure = {}, + std::function onSuccess = {}); /// Apply a cost model to the patterns within this applicator. 
void applyCostModel(CostModel model); diff --git a/mlir/lib/Rewrite/PatternApplicator.cpp b/mlir/lib/Rewrite/PatternApplicator.cpp index ea43f8a147d479..fecfb030a77fbf 100644 --- a/mlir/lib/Rewrite/PatternApplicator.cpp +++ b/mlir/lib/Rewrite/PatternApplicator.cpp @@ -129,9 +129,9 @@ void PatternApplicator::walkAllPatterns( LogicalResult PatternApplicator::matchAndRewrite( Operation *op, PatternRewriter , -function_ref canApply, -function_ref onFailure, -function_ref onSuccess) { +std::function canApply, +std::function onFailure, +std::function onSuccess) { // Before checking native patterns, first match against the bytecode. This // won't automatically perform any rewrites so there is no need to worry about // conflicts. diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index c1a261eab8487d..cd49bd121a62e5 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++
[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/84131 >From 407c7f7c24a87f409a921328605cc93637386d19 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Fri, 8 Mar 2024 02:01:50 + Subject: [PATCH] [mlir][IR] Add listener notifications for pattern begin/end --- mlir/include/mlir/IR/PatternMatch.h | 30 +-- .../Transforms/Utils/DialectConversion.cpp| 29 +++--- .../Utils/GreedyPatternRewriteDriver.cpp | 53 +++ 3 files changed, 77 insertions(+), 35 deletions(-) diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index e3500b3f9446d8..49544c42790d4d 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -432,11 +432,22 @@ class RewriterBase : public OpBuilder { /// Note: This notification is not triggered when unlinking an operation. virtual void notifyOperationErased(Operation *op) {} -/// Notify the listener that the pattern failed to match the given -/// operation, and provide a callback to populate a diagnostic with the -/// reason why the failure occurred. This method allows for derived -/// listeners to optionally hook into the reason why a rewrite failed, and -/// display it to users. +/// Notify the listener that the specified pattern is about to be applied +/// at the specified root operation. +virtual void notifyPatternBegin(const Pattern , Operation *op) {} + +/// Notify the listener that a pattern application finished with the +/// specified status. "success" indicates that the pattern was applied +/// successfully. "failure" indicates that the pattern could not be +/// applied. The pattern may have communicated the reason for the failure +/// with `notifyMatchFailure`. +virtual void notifyPatternEnd(const Pattern , + LogicalResult status) {} + +/// Notify the listener that the pattern failed to match, and provide a +/// callback to populate a diagnostic with the reason why the failure +/// occurred. 
This method allows for derived listeners to optionally hook +/// into the reason why a rewrite failed, and display it to users. virtual void notifyMatchFailure(Location loc, function_ref reasonCallback) {} @@ -478,6 +489,15 @@ class RewriterBase : public OpBuilder { if (auto *rewriteListener = dyn_cast(listener)) rewriteListener->notifyOperationErased(op); } +void notifyPatternBegin(const Pattern , Operation *op) override { + if (auto *rewriteListener = dyn_cast(listener)) +rewriteListener->notifyPatternBegin(pattern, op); +} +void notifyPatternEnd(const Pattern , + LogicalResult status) override { + if (auto *rewriteListener = dyn_cast(listener)) +rewriteListener->notifyPatternEnd(pattern, status); +} void notifyMatchFailure( Location loc, function_ref reasonCallback) override { diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index c1a261eab8487d..cd49bd121a62e5 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1856,7 +1856,8 @@ class OperationLegalizer { using LegalizationAction = ConversionTarget::LegalizationAction; OperationLegalizer(const ConversionTarget , - const FrozenRewritePatternSet ); + const FrozenRewritePatternSet , + const ConversionConfig ); /// Returns true if the given operation is known to be illegal on the target. bool isIllegal(Operation *op) const; @@ -1948,12 +1949,16 @@ class OperationLegalizer { /// The pattern applicator to use for conversions. PatternApplicator applicator; + + /// Dialect conversion configuration. 
+ const ConversionConfig }; } // namespace OperationLegalizer::OperationLegalizer(const ConversionTarget , - const FrozenRewritePatternSet ) -: target(targetInfo), applicator(patterns) { + const FrozenRewritePatternSet , + const ConversionConfig ) +: target(targetInfo), applicator(patterns), config(config) { // The set of patterns that can be applied to illegal operations to transform // them into legal ones. DenseMap legalizerPatterns; @@ -2098,7 +2103,10 @@ OperationLegalizer::legalizeWithPattern(Operation *op, // Functor that returns if the given pattern may be applied. auto canApply = [&](const Pattern ) { -return canApplyPattern(op, pattern, rewriter); +bool canApply = canApplyPattern(op, pattern, rewriter); +if (canApply && config.listener) + config.listener->notifyPatternBegin(pattern, op); +return canApply; }; // Functor that cleans up the
[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)
llvmbot wrote: @llvm/pr-subscribers-mlir-core Author: Matthias Springer (matthias-springer) Changes This commit adds two new notifications to `RewriterBase::Listener`: * `notifyPatternBegin`: Called when a pattern application begins during a greedy pattern rewrite or dialect conversion. * `notifyPatternEnd`: Called when a pattern application finishes during a greedy pattern rewrite or dialect conversion. The listener infrastructure already provides a `notifyMatchFailure` callback that notifies about the reason for a pattern match failure. The two new notifications provide additional information about pattern applications. This change is in preparation for improving the handle update mechanism in the `apply_conversion_patterns` transform op. --- Full diff: https://github.com/llvm/llvm-project/pull/84131.diff 3 Files Affected: - (modified) mlir/include/mlir/IR/PatternMatch.h (+25-5) - (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+21-8) - (modified) mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp (+31-22) ``diff diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index f8d22cfb22afd0..838b4947648f5e 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -432,11 +432,22 @@ class RewriterBase : public OpBuilder { /// Note: This notification is not triggered when unlinking an operation. virtual void notifyOperationErased(Operation *op) {} -/// Notify the listener that the pattern failed to match the given -/// operation, and provide a callback to populate a diagnostic with the -/// reason why the failure occurred. This method allows for derived -/// listeners to optionally hook into the reason why a rewrite failed, and -/// display it to users. +/// Notify the listener that the specified pattern is about to be applied +/// at the specified root operation. 
+virtual void notifyPatternBegin(const Pattern , Operation *op) {} + +/// Notify the listener that a pattern application finished with the +/// specified status. "success" indicates that the pattern was applied +/// successfully. "failure" indicates that the pattern could not be +/// applied. The pattern may have communicated the reason for the failure +/// with `notifyMatchFailure`. +virtual void notifyPatternEnd(const Pattern , + LogicalResult status) {} + +/// Notify the listener that the pattern failed to match, and provide a +/// callback to populate a diagnostic with the reason why the failure +/// occurred. This method allows for derived listeners to optionally hook +/// into the reason why a rewrite failed, and display it to users. virtual void notifyMatchFailure(Location loc, function_ref reasonCallback) {} @@ -478,6 +489,15 @@ class RewriterBase : public OpBuilder { if (auto *rewriteListener = dyn_cast(listener)) rewriteListener->notifyOperationErased(op); } +void notifyPatternBegin(const Pattern , Operation *op) override { + if (auto *rewriteListener = dyn_cast(listener)) +rewriteListener->notifyPatternBegin(pattern, op); +} +void notifyPatternEnd(const Pattern , + LogicalResult status) override { + if (auto *rewriteListener = dyn_cast(listener)) +rewriteListener->notifyPatternEnd(pattern, status); +} void notifyMatchFailure( Location loc, function_ref reasonCallback) override { diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index a5145246bc30c4..587fbe209b58af 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1863,7 +1863,8 @@ class OperationLegalizer { using LegalizationAction = ConversionTarget::LegalizationAction; OperationLegalizer(const ConversionTarget , - const FrozenRewritePatternSet ); + const FrozenRewritePatternSet , + const ConversionConfig ); /// Returns true if the given operation is known to be illegal on 
the target. bool isIllegal(Operation *op) const; @@ -1955,12 +1956,16 @@ class OperationLegalizer { /// The pattern applicator to use for conversions. PatternApplicator applicator; + + /// Dialect conversion configuration. + const ConversionConfig }; } // namespace OperationLegalizer::OperationLegalizer(const ConversionTarget , - const FrozenRewritePatternSet ) -: target(targetInfo), applicator(patterns) { + const FrozenRewritePatternSet , + const ConversionConfig ) +: target(targetInfo), applicator(patterns), config(config) { // The set of patterns that can be applied to illegal operations to
[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/84131 This commit adds two new notifications to `RewriterBase::Listener`: * `notifyPatternBegin`: Called when a pattern application begins during a greedy pattern rewrite or dialect conversion. * `notifyPatternEnd`: Called when a pattern application finishes during a greedy pattern rewrite or dialect conversion. The listener infrastructure already provides a `notifyMatchFailure` callback that notifies about the reason for a pattern match failure. The two new notifications provide additional information about pattern applications. This change is in preparation for improving the handle update mechanism in the `apply_conversion_patterns` transform op. >From 3ecec094d43a4eb9405ad057c32e4579f1fbe680 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Wed, 6 Mar 2024 08:00:35 + Subject: [PATCH] [mlir][IR] Add listener notifications for pattern begin/end --- mlir/include/mlir/IR/PatternMatch.h | 30 +-- .../Transforms/Utils/DialectConversion.cpp| 29 +++--- .../Utils/GreedyPatternRewriteDriver.cpp | 53 +++ 3 files changed, 77 insertions(+), 35 deletions(-) diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index f8d22cfb22afd0..838b4947648f5e 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -432,11 +432,22 @@ class RewriterBase : public OpBuilder { /// Note: This notification is not triggered when unlinking an operation. virtual void notifyOperationErased(Operation *op) {} -/// Notify the listener that the pattern failed to match the given -/// operation, and provide a callback to populate a diagnostic with the -/// reason why the failure occurred. This method allows for derived -/// listeners to optionally hook into the reason why a rewrite failed, and -/// display it to users. +/// Notify the listener that the specified pattern is about to be applied +/// at the specified root operation. 
+virtual void notifyPatternBegin(const Pattern , Operation *op) {} + +/// Notify the listener that a pattern application finished with the +/// specified status. "success" indicates that the pattern was applied +/// successfully. "failure" indicates that the pattern could not be +/// applied. The pattern may have communicated the reason for the failure +/// with `notifyMatchFailure`. +virtual void notifyPatternEnd(const Pattern , + LogicalResult status) {} + +/// Notify the listener that the pattern failed to match, and provide a +/// callback to populate a diagnostic with the reason why the failure +/// occurred. This method allows for derived listeners to optionally hook +/// into the reason why a rewrite failed, and display it to users. virtual void notifyMatchFailure(Location loc, function_ref reasonCallback) {} @@ -478,6 +489,15 @@ class RewriterBase : public OpBuilder { if (auto *rewriteListener = dyn_cast(listener)) rewriteListener->notifyOperationErased(op); } +void notifyPatternBegin(const Pattern , Operation *op) override { + if (auto *rewriteListener = dyn_cast(listener)) +rewriteListener->notifyPatternBegin(pattern, op); +} +void notifyPatternEnd(const Pattern , + LogicalResult status) override { + if (auto *rewriteListener = dyn_cast(listener)) +rewriteListener->notifyPatternEnd(pattern, status); +} void notifyMatchFailure( Location loc, function_ref reasonCallback) override { diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index a5145246bc30c4..587fbe209b58af 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1863,7 +1863,8 @@ class OperationLegalizer { using LegalizationAction = ConversionTarget::LegalizationAction; OperationLegalizer(const ConversionTarget , - const FrozenRewritePatternSet ); + const FrozenRewritePatternSet , + const ConversionConfig ); /// Returns true if the given operation is known to be illegal on 
the target. bool isIllegal(Operation *op) const; @@ -1955,12 +1956,16 @@ class OperationLegalizer { /// The pattern applicator to use for conversions. PatternApplicator applicator; + + /// Dialect conversion configuration. + const ConversionConfig }; } // namespace OperationLegalizer::OperationLegalizer(const ConversionTarget , - const FrozenRewritePatternSet ) -: target(targetInfo), applicator(patterns) { + const FrozenRewritePatternSet , + const ConversionConfig ) +