Author: Ming Yan Date: 2025-11-28T17:30:38+08:00 New Revision: e6304ea7e524609c33329a8bcb5c618832f0d381
URL: https://github.com/llvm/llvm-project/commit/e6304ea7e524609c33329a8bcb5c618832f0d381 DIFF: https://github.com/llvm/llvm-project/commit/e6304ea7e524609c33329a8bcb5c618832f0d381.diff LOG: Revert "[MLIR][SCF] Sink scf.if from scf.while before region into after regio…" This reverts commit 25d027b8ab3acd65b58fce278f4173b431326934. Added: Modified: mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp mlir/test/Dialect/SCF/uplift-while.mlir Removed: ################################################################################ diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp index 9f242f9e62b8e..ec1044aaa42ac 100644 --- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp @@ -19,83 +19,6 @@ using namespace mlir; namespace { -/// Move an scf.if op that is directly before the scf.condition op in the while -/// before region, and whose condition matches the condition of the -/// scf.condition op, down into the while after region. -/// -/// scf.while (%init) : (...) -> ... { -/// %cond = ... -/// %res = scf.if %cond -> (...) { -/// use1(%init) -/// %then_val = ... -/// ... // then block -/// scf.yield %then_val -/// } else { -/// scf.yield %init -/// } -/// scf.condition(%cond) %res -/// } do { -/// ^bb0(%arg): -/// use2(%arg) -/// ... -/// -/// becomes -/// scf.while (%init) : (...) -> ... { -/// %cond = ... -/// scf.condition(%cond) %init -/// } do { -/// ^bb0(%arg): : -/// use1(%arg) -/// ... // if then block -/// %then_val = ... -/// use2(%then_val) -/// ... 
-struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> { - using OpRewritePattern<scf::WhileOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(scf::WhileOp op, - PatternRewriter &rewriter) const override { - // Check that the first opeation produces one result and that result must - // have exactly two uses (these two uses come from the `scf.if` and - // `scf.condition` operations). - Operation &condOp = op.getBeforeBody()->front(); - if (condOp.getNumResults() != 1 || !condOp.getResult(0).hasNUses(2)) - return failure(); - - Value condVal = condOp.getResult(0); - auto ifOp = dyn_cast<scf::IfOp>(condOp.getNextNode()); - if (!ifOp || ifOp.getCondition() != condVal) - return failure(); - - auto term = dyn_cast<scf::ConditionOp>(ifOp->getNextNode()); - if (!term || term.getCondition() != condVal) - return failure(); - - // Check that if results and else yield operands match the scf.condition op - // arguments and while before arguments respectively. - if (!llvm::equal(ifOp->getResults(), term.getArgs()) || - !llvm::equal(ifOp.elseYield()->getOperands(), op.getBeforeArguments())) - return failure(); - - // Update uses and move the if op into the after region. 
- rewriter.replaceAllUsesWith(op.getAfterArguments(), - ifOp.thenYield()->getOperands()); - rewriter.replaceUsesWithIf(op.getBeforeArguments(), op.getAfterArguments(), - [&](OpOperand &use) { - return ifOp.getThenRegion().isAncestor( - use.getOwner()->getParentRegion()); - }); - rewriter.modifyOpInPlace( - term, [&]() { term.getArgsMutable().assign(op.getBeforeArguments()); }); - - rewriter.eraseOp(ifOp.thenYield()); - rewriter.inlineBlockBefore(ifOp.thenBlock(), op.getAfterBody(), - op.getAfterBody()->begin()); - rewriter.eraseOp(ifOp); - return success(); - } -}; - struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> { using OpRewritePattern::OpRewritePattern; @@ -344,5 +267,5 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter, } void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns) { - patterns.add<UpliftWhileOp, WhileMoveIfDown>(patterns.getContext()); + patterns.add<UpliftWhileOp>(patterns.getContext()); } diff --git a/mlir/test/Dialect/SCF/uplift-while.mlir b/mlir/test/Dialect/SCF/uplift-while.mlir index 736112824c515..cbe2ce5076ad2 100644 --- a/mlir/test/Dialect/SCF/uplift-while.mlir +++ b/mlir/test/Dialect/SCF/uplift-while.mlir @@ -185,34 +185,3 @@ func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> (i32, f32) // CHECK: %[[T2:.*]] = "test.test2"(%[[ARG2]]) : (f32) -> f32 // CHECK: scf.yield %[[T1]], %[[T2]] : i32, f32 // CHECK: return %[[RES]]#0, %[[RES]]#1 : i32, f32 - -// ----- - -func.func @uplift_while(%low: index, %upper: index, %val : i32) -> i32 { - %c1 = arith.constant 1 : index - %1:2 = scf.while (%iv = %low, %iter = %val) : (index, i32) -> (index, i32) { - %2 = arith.cmpi slt, %iv, %upper : index - %3:2 = scf.if %2 -> (index, i32) { - %4 = "test.test"(%iter) : (i32) -> i32 - %5 = arith.addi %iv, %c1 : index - scf.yield %5, %4 : index, i32 - } else { - scf.yield %iv, %iter : index, i32 - } - scf.condition(%2) %3#0, %3#1 : index, i32 - } do { - ^bb0(%arg0: index, %arg1: 
i32): - scf.yield %arg0, %arg1 : index, i32 - } - return %1#1 : i32 -} - -// CHECK-LABEL: func.func @uplift_while( -// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: i32) -> i32 { -// CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : index -// CHECK: %[[FOR_0:.*]] = scf.for %[[VAL_0:.*]] = %[[ARG0]] to %[[ARG1]] step %[[CONSTANT_0]] iter_args(%[[VAL_1:.*]] = %[[ARG2]]) -> (i32) { -// CHECK: %[[VAL_2:.*]] = "test.test"(%[[VAL_1]]) : (i32) -> i32 -// CHECK: scf.yield %[[VAL_2]] : i32 -// CHECK: } -// CHECK: return %[[FOR_0]] : i32 -// CHECK: } _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
