Author: Mehdi Amini
Date: 2025-09-17T10:17:15+02:00
New Revision: b5c72affe48e74abdb0898a89aeefa8edfa81065

URL: https://github.com/llvm/llvm-project/commit/b5c72affe48e74abdb0898a89aeefa8edfa81065
DIFF: https://github.com/llvm/llvm-project/commit/b5c72affe48e74abdb0898a89aeefa8edfa81065.diff

LOG: Revert "[mlir] move if-condition propagation to a standalone pass (#150278)"

This reverts commit 9d11accf95db0ed08bd3181c25dd75fc793d089d.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
    mlir/test/Dialect/SCF/canonicalize.mlir

Removed: 
    mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp
    mlir/test/Dialect/SCF/if-cond-prop.mlir


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 8b891aa374b58..3ac651f53880c 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -41,12 +41,6 @@ def SCFForLoopSpecialization : Pass<"scf-for-loop-specialization"> {
   let constructor = "mlir::createForLoopSpecializationPass()";
 }
 
-def SCFIfConditionPropagation : Pass<"scf-if-condition-propagation"> {
-  let summary = "Replace usages of if condition with true/false constants in "
-                "the conditional regions";
-  let dependentDialects = ["arith::ArithDialect"];
-}
-
 def SCFParallelLoopFusion : Pass<"scf-parallel-loop-fusion"> {
   let summary = "Fuse adjacent parallel loops";
   let constructor = "mlir::createParallelLoopFusionPass()";

diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index ae55eaded0554..a9da6c2c8320a 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -2453,6 +2453,65 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
   }
 };
 
+/// Allow the true region of an if to assume the condition is true
+/// and vice versa. For example:
+///
+///   scf.if %cmp {
+///      print(%cmp)
+///   }
+///
+///  becomes
+///
+///   scf.if %cmp {
+///      print(true)
+///   }
+///
+struct ConditionPropagation : public OpRewritePattern<IfOp> {
+  using OpRewritePattern<IfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IfOp op,
+                                PatternRewriter &rewriter) const override {
+    // Early exit if the condition is constant since replacing a constant
+    // in the body with another constant isn't a simplification.
+    if (matchPattern(op.getCondition(), m_Constant()))
+      return failure();
+
+    bool changed = false;
+    mlir::Type i1Ty = rewriter.getI1Type();
+
+    // These variables serve to prevent creating duplicate constants
+    // and hold constant true or false values.
+    Value constantTrue = nullptr;
+    Value constantFalse = nullptr;
+
+    for (OpOperand &use :
+         llvm::make_early_inc_range(op.getCondition().getUses())) {
+      if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
+        changed = true;
+
+        if (!constantTrue)
+          constantTrue = rewriter.create<arith::ConstantOp>(
+              op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
+
+        rewriter.modifyOpInPlace(use.getOwner(),
+                                 [&]() { use.set(constantTrue); });
+      } else if (op.getElseRegion().isAncestor(
+                     use.getOwner()->getParentRegion())) {
+        changed = true;
+
+        if (!constantFalse)
+          constantFalse = rewriter.create<arith::ConstantOp>(
+              op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
+
+        rewriter.modifyOpInPlace(use.getOwner(),
+                                 [&]() { use.set(constantFalse); });
+      }
+    }
+
+    return success(changed);
+  }
+};
+
 /// Remove any statements from an if that are equivalent to the condition
 /// or its negation. For example:
 ///
@@ -2835,8 +2894,9 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> {
 
 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
-  results.add<CombineIfs, CombineNestedIfs, ConvertTrivialIfToSelect,
-              RemoveEmptyElseBranch, RemoveStaticCondition, RemoveUnusedResults,
+  results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
+              ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
+              RemoveStaticCondition, RemoveUnusedResults,
               ReplaceIfYieldWithConditionOrValue>(context);
 }
 

diff  --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index a07d9d4953d19..a9ffa9dc208a0 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -4,7 +4,6 @@ add_mlir_dialect_library(MLIRSCFTransforms
   ForallToFor.cpp
   ForallToParallel.cpp
   ForToWhile.cpp
-  IfConditionPropagation.cpp
   LoopCanonicalization.cpp
   LoopPipelining.cpp
   LoopRangeFolding.cpp

diff  --git a/mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp b/mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp
deleted file mode 100644
index bdc51296ef9f2..0000000000000
--- a/mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp
+++ /dev/null
@@ -1,98 +0,0 @@
-//===- IfConditionPropagation.cpp -----------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file contains a pass for constant propagation of the condition of an
-// `scf.if` into its then and else regions as true and false respectively.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/SCF/Transforms/Passes.h"
-
-using namespace mlir;
-
-namespace mlir {
-#define GEN_PASS_DEF_SCFIFCONDITIONPROPAGATION
-#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
-} // namespace mlir
-
-/// Traverses the IR recursively (on region tree) and updates the uses of a
-/// value also as the condition of an `scf.if` to either `true` or `false`
-/// constants in the `then` and `else regions. This is done as a single
-/// post-order sweep over the IR (without `walk`) for efficiency reasons. While
-/// traversing, the function maintains the set of visited regions to quickly
-/// identify whether the value belong to a region that is known to be nested in
-/// the `then` or `else` branch of a specific loop.
-static void propagateIfConditionsImpl(Operation *root,
-                                      llvm::SmallPtrSet<Region *, 8> &visited) {
-  if (auto scfIf = dyn_cast<scf::IfOp>(root)) {
-    llvm::SmallPtrSet<Region *, 8> thenChildren, elseChildren;
-    // Visit the "then" region, collect children.
-    for (Block &block : scfIf.getThenRegion()) {
-      for (Operation &op : block) {
-        propagateIfConditionsImpl(&op, thenChildren);
-      }
-    }
-
-    // Visit the "else" region, collect children.
-    for (Block &block : scfIf.getElseRegion()) {
-      for (Operation &op : block) {
-        propagateIfConditionsImpl(&op, elseChildren);
-      }
-    }
-
-    // Update uses to point to constants instead.
-    OpBuilder builder(scfIf);
-    Value trueValue = arith::ConstantIntOp::create(builder, scfIf.getLoc(),
-                                                   /*value=*/true, /*width=*/1);
-    Value falseValue =
-        arith::ConstantIntOp::create(builder, scfIf.getLoc(),
-                                     /*value=*/false, /*width=*/1);
-
-    for (OpOperand &use : scfIf.getCondition().getUses()) {
-      if (thenChildren.contains(use.getOwner()->getParentRegion()))
-        use.set(trueValue);
-      else if (elseChildren.contains(use.getOwner()->getParentRegion()))
-        use.set(falseValue);
-    }
-    if (trueValue.getUses().empty())
-      trueValue.getDefiningOp()->erase();
-    if (falseValue.getUses().empty())
-      falseValue.getDefiningOp()->erase();
-
-    // Append the two lists of children and return them.
-    visited.insert_range(thenChildren);
-    visited.insert_range(elseChildren);
-    return;
-  }
-
-  for (Region &region : root->getRegions()) {
-    for (Block &block : region) {
-      for (Operation &op : block) {
-        propagateIfConditionsImpl(&op, visited);
-      }
-    }
-  }
-}
-
-/// Traverses the IR recursively (on region tree) and updates the uses of a
-/// value also as the condition of an `scf.if` to either `true` or `false`
-/// constants in the `then` and `else regions
-static void propagateIfConditions(Operation *root) {
-  llvm::SmallPtrSet<Region *, 8> visited;
-  propagateIfConditionsImpl(root, visited);
-}
-
-namespace {
-/// Pass entrypoint.
-struct SCFIfConditionPropagationPass
-    : impl::SCFIfConditionPropagationBase<SCFIfConditionPropagationPass> {
-  void runOnOperation() override { propagateIfConditions(getOperation()); }
-};
-} // namespace

diff  --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 5e89f74075252..2bec63672e783 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -867,6 +867,41 @@ func.func @matmul_on_tensors(%t0: tensor<32x1024xf32>) -> tensor<?x?xf32> {
 
 // -----
 
+// CHECK-LABEL: @cond_prop
+func.func @cond_prop(%arg0 : i1) -> index {
+  %res = scf.if %arg0 -> index {
+    %res1 = scf.if %arg0 -> index {
+      %v1 = "test.get_some_value1"() : () -> index
+      scf.yield %v1 : index
+    } else {
+      %v2 = "test.get_some_value2"() : () -> index
+      scf.yield %v2 : index
+    }
+    scf.yield %res1 : index
+  } else {
+    %res2 = scf.if %arg0 -> index {
+      %v3 = "test.get_some_value3"() : () -> index
+      scf.yield %v3 : index
+    } else {
+      %v4 = "test.get_some_value4"() : () -> index
+      scf.yield %v4 : index
+    }
+    scf.yield %res2 : index
+  }
+  return %res : index
+}
+// CHECK-NEXT:  %[[if:.+]] = scf.if %arg0 -> (index) {
+// CHECK-NEXT:    %[[c1:.+]] = "test.get_some_value1"() : () -> index
+// CHECK-NEXT:    scf.yield %[[c1]] : index
+// CHECK-NEXT:  } else {
+// CHECK-NEXT:    %[[c4:.+]] = "test.get_some_value4"() : () -> index
+// CHECK-NEXT:    scf.yield %[[c4]] : index
+// CHECK-NEXT:  }
+// CHECK-NEXT:  return %[[if]] : index
+// CHECK-NEXT:}
+
+// -----
+
 // CHECK-LABEL: @replace_if_with_cond1
 func.func @replace_if_with_cond1(%arg0 : i1) -> (i32, i1) {
   %true = arith.constant true

diff  --git a/mlir/test/Dialect/SCF/if-cond-prop.mlir b/mlir/test/Dialect/SCF/if-cond-prop.mlir
deleted file mode 100644
index 99d113f672014..0000000000000
--- a/mlir/test/Dialect/SCF/if-cond-prop.mlir
+++ /dev/null
@@ -1,34 +0,0 @@
-// RUN: mlir-opt %s --scf-if-condition-propagation --allow-unregistered-dialect | FileCheck %s
-
-// CHECK-LABEL: @cond_prop
-func.func @cond_prop(%arg0 : i1) -> index {
-  %res = scf.if %arg0 -> index {
-    %res1 = scf.if %arg0 -> index {
-      %v1 = "test.get_some_value1"() : () -> index
-      scf.yield %v1 : index
-    } else {
-      %v2 = "test.get_some_value2"() : () -> index
-      scf.yield %v2 : index
-    }
-    scf.yield %res1 : index
-  } else {
-    %res2 = scf.if %arg0 -> index {
-      %v3 = "test.get_some_value3"() : () -> index
-      scf.yield %v3 : index
-    } else {
-      %v4 = "test.get_some_value4"() : () -> index
-      scf.yield %v4 : index
-    }
-    scf.yield %res2 : index
-  }
-  return %res : index
-}
-// CHECK:  %[[if:.+]] = scf.if %arg0 -> (index) {
-// CHECK:    %[[c1:.+]] = "test.get_some_value1"() : () -> index
-// CHECK:    scf.yield %[[c1]] : index
-// CHECK:  } else {
-// CHECK:    %[[c4:.+]] = "test.get_some_value4"() : () -> index
-// CHECK:    scf.yield %[[c4]] : index
-// CHECK:  }
-// CHECK:  return %[[if]] : index
-// CHECK:}


        
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to