Author: Mehdi Amini Date: 2025-09-17T10:17:15+02:00 New Revision: b5c72affe48e74abdb0898a89aeefa8edfa81065
URL: https://github.com/llvm/llvm-project/commit/b5c72affe48e74abdb0898a89aeefa8edfa81065 DIFF: https://github.com/llvm/llvm-project/commit/b5c72affe48e74abdb0898a89aeefa8edfa81065.diff LOG: Revert "[mlir] move if-condition propagation to a standalone pass (#150278)" This reverts commit 9d11accf95db0ed08bd3181c25dd75fc793d089d. Added: Modified: mlir/include/mlir/Dialect/SCF/Transforms/Passes.td mlir/lib/Dialect/SCF/IR/SCF.cpp mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt mlir/test/Dialect/SCF/canonicalize.mlir Removed: mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp mlir/test/Dialect/SCF/if-cond-prop.mlir ################################################################################ diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td index 8b891aa374b58..3ac651f53880c 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td @@ -41,12 +41,6 @@ def SCFForLoopSpecialization : Pass<"scf-for-loop-specialization"> { let constructor = "mlir::createForLoopSpecializationPass()"; } -def SCFIfConditionPropagation : Pass<"scf-if-condition-propagation"> { - let summary = "Replace usages of if condition with true/false constants in " - "the conditional regions"; - let dependentDialects = ["arith::ArithDialect"]; -} - def SCFParallelLoopFusion : Pass<"scf-parallel-loop-fusion"> { let summary = "Fuse adjacent parallel loops"; let constructor = "mlir::createParallelLoopFusionPass()"; diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index ae55eaded0554..a9da6c2c8320a 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -2453,6 +2453,65 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> { } }; +/// Allow the true region of an if to assume the condition is true +/// and vice versa. 
For example: +/// +/// scf.if %cmp { +/// print(%cmp) +/// } +/// +/// becomes +/// +/// scf.if %cmp { +/// print(true) +/// } +/// +struct ConditionPropagation : public OpRewritePattern<IfOp> { + using OpRewritePattern<IfOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(IfOp op, + PatternRewriter &rewriter) const override { + // Early exit if the condition is constant since replacing a constant + // in the body with another constant isn't a simplification. + if (matchPattern(op.getCondition(), m_Constant())) + return failure(); + + bool changed = false; + mlir::Type i1Ty = rewriter.getI1Type(); + + // These variables serve to prevent creating duplicate constants + // and hold constant true or false values. + Value constantTrue = nullptr; + Value constantFalse = nullptr; + + for (OpOperand &use : + llvm::make_early_inc_range(op.getCondition().getUses())) { + if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) { + changed = true; + + if (!constantTrue) + constantTrue = rewriter.create<arith::ConstantOp>( + op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1)); + + rewriter.modifyOpInPlace(use.getOwner(), + [&]() { use.set(constantTrue); }); + } else if (op.getElseRegion().isAncestor( + use.getOwner()->getParentRegion())) { + changed = true; + + if (!constantFalse) + constantFalse = rewriter.create<arith::ConstantOp>( + op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0)); + + rewriter.modifyOpInPlace(use.getOwner(), + [&]() { use.set(constantFalse); }); + } + } + + return success(changed); + } +}; + /// Remove any statements from an if that are equivalent to the condition /// or its negation. 
For example: /// @@ -2835,8 +2894,9 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> { void IfOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<CombineIfs, CombineNestedIfs, ConvertTrivialIfToSelect, - RemoveEmptyElseBranch, RemoveStaticCondition, RemoveUnusedResults, + results.add<CombineIfs, CombineNestedIfs, ConditionPropagation, + ConvertTrivialIfToSelect, RemoveEmptyElseBranch, + RemoveStaticCondition, RemoveUnusedResults, ReplaceIfYieldWithConditionOrValue>(context); } diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt index a07d9d4953d19..a9ffa9dc208a0 100644 --- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt @@ -4,7 +4,6 @@ add_mlir_dialect_library(MLIRSCFTransforms ForallToFor.cpp ForallToParallel.cpp ForToWhile.cpp - IfConditionPropagation.cpp LoopCanonicalization.cpp LoopPipelining.cpp LoopRangeFolding.cpp diff --git a/mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp b/mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp deleted file mode 100644 index bdc51296ef9f2..0000000000000 --- a/mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp +++ /dev/null @@ -1,98 +0,0 @@ -//===- IfConditionPropagation.cpp -----------------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file contains a pass for constant propagation of the condition of an -// `scf.if` into its then and else regions as true and false respectively. 
-// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/SCF/Transforms/Passes.h" - -using namespace mlir; - -namespace mlir { -#define GEN_PASS_DEF_SCFIFCONDITIONPROPAGATION -#include "mlir/Dialect/SCF/Transforms/Passes.h.inc" -} // namespace mlir - -/// Traverses the IR recursively (on region tree) and updates the uses of a -/// value also as the condition of an `scf.if` to either `true` or `false` -/// constants in the `then` and `else regions. This is done as a single -/// post-order sweep over the IR (without `walk`) for efficiency reasons. While -/// traversing, the function maintains the set of visited regions to quickly -/// identify whether the value belong to a region that is known to be nested in -/// the `then` or `else` branch of a specific loop. -static void propagateIfConditionsImpl(Operation *root, - llvm::SmallPtrSet<Region *, 8> &visited) { - if (auto scfIf = dyn_cast<scf::IfOp>(root)) { - llvm::SmallPtrSet<Region *, 8> thenChildren, elseChildren; - // Visit the "then" region, collect children. - for (Block &block : scfIf.getThenRegion()) { - for (Operation &op : block) { - propagateIfConditionsImpl(&op, thenChildren); - } - } - - // Visit the "else" region, collect children. - for (Block &block : scfIf.getElseRegion()) { - for (Operation &op : block) { - propagateIfConditionsImpl(&op, elseChildren); - } - } - - // Update uses to point to constants instead. 
- OpBuilder builder(scfIf); - Value trueValue = arith::ConstantIntOp::create(builder, scfIf.getLoc(), - /*value=*/true, /*width=*/1); - Value falseValue = - arith::ConstantIntOp::create(builder, scfIf.getLoc(), - /*value=*/false, /*width=*/1); - - for (OpOperand &use : scfIf.getCondition().getUses()) { - if (thenChildren.contains(use.getOwner()->getParentRegion())) - use.set(trueValue); - else if (elseChildren.contains(use.getOwner()->getParentRegion())) - use.set(falseValue); - } - if (trueValue.getUses().empty()) - trueValue.getDefiningOp()->erase(); - if (falseValue.getUses().empty()) - falseValue.getDefiningOp()->erase(); - - // Append the two lists of children and return them. - visited.insert_range(thenChildren); - visited.insert_range(elseChildren); - return; - } - - for (Region &region : root->getRegions()) { - for (Block &block : region) { - for (Operation &op : block) { - propagateIfConditionsImpl(&op, visited); - } - } - } -} - -/// Traverses the IR recursively (on region tree) and updates the uses of a -/// value also as the condition of an `scf.if` to either `true` or `false` -/// constants in the `then` and `else regions -static void propagateIfConditions(Operation *root) { - llvm::SmallPtrSet<Region *, 8> visited; - propagateIfConditionsImpl(root, visited); -} - -namespace { -/// Pass entrypoint. 
-struct SCFIfConditionPropagationPass - : impl::SCFIfConditionPropagationBase<SCFIfConditionPropagationPass> { - void runOnOperation() override { propagateIfConditions(getOperation()); } -}; -} // namespace diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index 5e89f74075252..2bec63672e783 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -867,6 +867,41 @@ func.func @matmul_on_tensors(%t0: tensor<32x1024xf32>) -> tensor<?x?xf32> { // ----- +// CHECK-LABEL: @cond_prop +func.func @cond_prop(%arg0 : i1) -> index { + %res = scf.if %arg0 -> index { + %res1 = scf.if %arg0 -> index { + %v1 = "test.get_some_value1"() : () -> index + scf.yield %v1 : index + } else { + %v2 = "test.get_some_value2"() : () -> index + scf.yield %v2 : index + } + scf.yield %res1 : index + } else { + %res2 = scf.if %arg0 -> index { + %v3 = "test.get_some_value3"() : () -> index + scf.yield %v3 : index + } else { + %v4 = "test.get_some_value4"() : () -> index + scf.yield %v4 : index + } + scf.yield %res2 : index + } + return %res : index +} +// CHECK-NEXT: %[[if:.+]] = scf.if %arg0 -> (index) { +// CHECK-NEXT: %[[c1:.+]] = "test.get_some_value1"() : () -> index +// CHECK-NEXT: scf.yield %[[c1]] : index +// CHECK-NEXT: } else { +// CHECK-NEXT: %[[c4:.+]] = "test.get_some_value4"() : () -> index +// CHECK-NEXT: scf.yield %[[c4]] : index +// CHECK-NEXT: } +// CHECK-NEXT: return %[[if]] : index +// CHECK-NEXT:} + +// ----- + // CHECK-LABEL: @replace_if_with_cond1 func.func @replace_if_with_cond1(%arg0 : i1) -> (i32, i1) { %true = arith.constant true diff --git a/mlir/test/Dialect/SCF/if-cond-prop.mlir b/mlir/test/Dialect/SCF/if-cond-prop.mlir deleted file mode 100644 index 99d113f672014..0000000000000 --- a/mlir/test/Dialect/SCF/if-cond-prop.mlir +++ /dev/null @@ -1,34 +0,0 @@ -// RUN: mlir-opt %s --scf-if-condition-propagation --allow-unregistered-dialect | FileCheck %s - -// CHECK-LABEL: @cond_prop 
-func.func @cond_prop(%arg0 : i1) -> index { - %res = scf.if %arg0 -> index { - %res1 = scf.if %arg0 -> index { - %v1 = "test.get_some_value1"() : () -> index - scf.yield %v1 : index - } else { - %v2 = "test.get_some_value2"() : () -> index - scf.yield %v2 : index - } - scf.yield %res1 : index - } else { - %res2 = scf.if %arg0 -> index { - %v3 = "test.get_some_value3"() : () -> index - scf.yield %v3 : index - } else { - %v4 = "test.get_some_value4"() : () -> index - scf.yield %v4 : index - } - scf.yield %res2 : index - } - return %res : index -} -// CHECK: %[[if:.+]] = scf.if %arg0 -> (index) { -// CHECK: %[[c1:.+]] = "test.get_some_value1"() : () -> index -// CHECK: scf.yield %[[c1]] : index -// CHECK: } else { -// CHECK: %[[c4:.+]] = "test.get_some_value4"() : () -> index -// CHECK: scf.yield %[[c4]] : index -// CHECK: } -// CHECK: return %[[if]] : index -// CHECK:} _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits