llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

Remove a workaround from the implementation of `replaceAllUsesWith` in the 
no-rollback dialect conversion driver. The workaround (silently skipping 
replacements that would cause a dominance violation) was needed by 
`restoreByValRefArgumentType` in the `func-to-llvm` lowering because the driver 
did not support `replaceAllUsesExcept`. The no-rollback driver now supports this 
API, so `restoreByValRefArgumentType` uses `replaceAllUsesExcept` in that mode 
and the workaround can be dropped from that driver. The workaround remains in 
place for the rollback driver.

Depends on #169606.


---
Full diff: https://github.com/llvm/llvm-project/pull/169609.diff


4 Files Affected:

- (modified) mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp (+10-2) 
- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+12-21) 
- (modified) mlir/test/Transforms/test-convert-func-op.mlir (+2-1) 
- (modified) mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp (+10-1) 


``````````diff
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp 
b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 2220f61ed8a07..ddd94f5d03042 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -283,8 +283,16 @@ static void restoreByValRefArgumentType(
     Type resTy = typeConverter.convertType(
         cast<TypeAttr>(byValRefAttr->getValue()).getValue());
 
-    Value valueArg = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg);
-    rewriter.replaceAllUsesWith(arg, valueArg);
+    auto loadOp = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg);
+    if (!rewriter.getConfig().allowPatternRollback) {
+      rewriter.replaceAllUsesExcept(arg, loadOp, loadOp);
+    } else {
+      // replaceAllUsesExcept is not supported in rollback mode. The rollback
+      // mode implementation has a workaround: certain replacements that would
+      // cause a dominance violation are skipped.
+      // TODO: Remove workaround.
+      rewriter.replaceAllUsesWith(arg, loadOp);
+    }
   }
 }
 
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp 
b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index c9f1596c07cbe..ccc5b7cb6f229 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1205,17 +1205,14 @@ void BlockTypeConversionRewrite::rollback() {
   getNewBlock()->replaceAllUsesWith(getOrigBlock());
 }
 
-/// Replace all uses of `from` with `repl`.
-static void
-performReplaceValue(RewriterBase &rewriter, Value from, Value repl,
-                    function_ref<bool(OpOperand &)> functor = nullptr) {
+void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
+  Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter);
+  if (!repl)
+    return;
+
   if (isa<BlockArgument>(repl)) {
     // `repl` is a block argument. Directly replace all uses.
-    if (functor) {
-      rewriter.replaceUsesWithIf(from, repl, functor);
-    } else {
-      rewriter.replaceAllUsesWith(from, repl);
-    }
+    rewriter.replaceAllUsesWith(value, repl);
     return;
   }
 
@@ -1244,23 +1241,14 @@ performReplaceValue(RewriterBase &rewriter, Value from, 
Value repl,
   // `ConversionPatternRewriter` API with the normal `RewriterBase` API.
   Operation *replOp = repl.getDefiningOp();
   Block *replBlock = replOp->getBlock();
-  rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) {
+  rewriter.replaceUsesWithIf(value, repl, [&](OpOperand &operand) {
     Operation *user = operand.getOwner();
     bool result =
         user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
-    if (functor)
-      result &= functor(operand);
     return result;
   });
 }
 
-void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
-  Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter);
-  if (!repl)
-    return;
-  performReplaceValue(rewriter, value, repl);
-}
-
 void ReplaceValueRewrite::rollback() {
   rewriterImpl.mapping.erase({value});
 #ifndef NDEBUG
@@ -2000,8 +1988,11 @@ void ConversionPatternRewriterImpl::replaceValueUses(
     Value repl = repls.front();
     if (!repl)
       return;
-
-    performReplaceValue(r, from, repl, functor);
+    if (functor) {
+      r.replaceUsesWithIf(from, repl, functor);
+    } else {
+      r.replaceAllUsesWith(from, repl);
+    }
     return;
   }
 
diff --git a/mlir/test/Transforms/test-convert-func-op.mlir 
b/mlir/test/Transforms/test-convert-func-op.mlir
index 180f16a32991b..14c15ecbe77f0 100644
--- a/mlir/test/Transforms/test-convert-func-op.mlir
+++ b/mlir/test/Transforms/test-convert-func-op.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -test-convert-func-op --split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-convert-func-op="allow-pattern-rollback=1" 
--split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-convert-func-op="allow-pattern-rollback=0" 
--split-input-file | FileCheck %s
 
 // CHECK-LABEL: llvm.func @add
 func.func @add(%arg0: i32, %arg1: i32) -> i32 attributes { 
llvm.emit_c_interface } {
diff --git a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp 
b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
index 75168dde93130..897b11b65b6f2 100644
--- a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
+++ b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
@@ -68,6 +68,9 @@ struct TestConvertFuncOp
     : public PassWrapper<TestConvertFuncOp, OperationPass<ModuleOp>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertFuncOp)
 
+  TestConvertFuncOp() = default;
+  TestConvertFuncOp(const TestConvertFuncOp &other) : PassWrapper(other) {}
+
   void getDependentDialects(DialectRegistry &registry) const final {
     registry.insert<LLVM::LLVMDialect>();
   }
@@ -92,10 +95,16 @@ struct TestConvertFuncOp
     patterns.add<ReturnOpConversion>(typeConverter);
 
     LLVMConversionTarget target(getContext());
+    ConversionConfig config;
+    config.allowPatternRollback = allowPatternRollback;
     if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
+                                      std::move(patterns), config)))
       signalPassFailure();
   }
+
+  Option<bool> allowPatternRollback{*this, "allow-pattern-rollback",
+                                    llvm::cl::desc("Allow pattern rollback"),
+                                    llvm::cl::init(true)};
 };
 
 } // namespace

``````````

</details>


https://github.com/llvm/llvm-project/pull/169609
_______________________________________________
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to