https://github.com/Andres-Salamanca updated https://github.com/llvm/llvm-project/pull/140649
>From c1403f148a58e259cc296310dc21b8c5611f2e82 Mon Sep 17 00:00:00 2001 From: Andres Salamanca <andrealebarbari...@gmail.com> Date: Mon, 19 May 2025 18:53:15 -0500 Subject: [PATCH 1/3] Implement CIR switch case simplify with appropriate tests --- clang/include/clang/CIR/MissingFeatures.h | 1 - clang/lib/CIR/CodeGen/CIRGenStmt.cpp | 6 - .../CIR/Dialect/Transforms/CIRSimplify.cpp | 106 +++++++++- clang/test/CIR/Transforms/switch-fold.cir | 196 ++++++++++++++++++ 4 files changed, 300 insertions(+), 9 deletions(-) create mode 100644 clang/test/CIR/Transforms/switch-fold.cir diff --git a/clang/include/clang/CIR/MissingFeatures.h b/clang/include/clang/CIR/MissingFeatures.h index 484822c351746..9f3e5d007d66c 100644 --- a/clang/include/clang/CIR/MissingFeatures.h +++ b/clang/include/clang/CIR/MissingFeatures.h @@ -114,7 +114,6 @@ struct MissingFeatures { static bool opUnaryPromotionType() { return false; } // SwitchOp handling - static bool foldCascadingCases() { return false; } static bool foldRangeCase() { return false; } // Clang early optimizations or things defered to LLVM lowering. diff --git a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp index cc96e65e4ce1d..7f1ecbda414bd 100644 --- a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp @@ -531,12 +531,6 @@ mlir::LogicalResult CIRGenFunction::emitCaseStmt(const CaseStmt &s, value = builder.getArrayAttr({cir::IntAttr::get(condType, intVal), cir::IntAttr::get(condType, endVal)}); kind = cir::CaseOpKind::Range; - - // We don't currently fold case range statements with other case statements. - // TODO(cir): Add this capability. Folding these cases is going to be - // implemented in CIRSimplify when it is upstreamed. 
- assert(!cir::MissingFeatures::foldRangeCase()); - assert(!cir::MissingFeatures::foldCascadingCases()); } else { value = builder.getArrayAttr({cir::IntAttr::get(condType, intVal)}); kind = cir::CaseOpKind::Equal; diff --git a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp index b969569b0081c..58300cc219602 100644 --- a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp +++ b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp @@ -159,6 +159,107 @@ struct SimplifySelect : public OpRewritePattern<SelectOp> { } }; +/// Simplify `cir.switch` operations by folding cascading cases +/// into a single `cir.case` with the `anyof` kind. +/// +/// This pattern identifies cascading cases within a `cir.switch` operation. +/// Cascading cases are defined as consecutive `cir.case` operations of kind +/// `equal`, each containing a single `cir.yield` operation in their body. +/// +/// The pattern merges these cascading cases into a single `cir.case` operation +/// with kind `anyof`, aggregating all the case values. +/// +/// The merging process continues until a `cir.case` with a different body +/// (e.g., containing `cir.break` or compound stmt) is encountered, which +/// breaks the chain. 
+/// +/// Example: +/// +/// Before: +/// cir.case equal, [#cir.int<0> : !s32i] { +/// cir.yield +/// } +/// cir.case equal, [#cir.int<1> : !s32i] { +/// cir.yield +/// } +/// cir.case equal, [#cir.int<2> : !s32i] { +/// cir.break +/// } +/// +/// After applying SimplifySwitch: +/// cir.case anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : +/// !s32i] { +/// cir.break +/// } +struct SimplifySwitch : public OpRewritePattern<SwitchOp> { + using OpRewritePattern<SwitchOp>::OpRewritePattern; + LogicalResult matchAndRewrite(SwitchOp op, + PatternRewriter &rewriter) const override { + + LogicalResult changed = mlir::failure(); + llvm::SmallVector<CaseOp, 8> cases; + SmallVector<CaseOp, 4> cascadingCases; + SmallVector<mlir::Attribute, 4> cascadingCaseValues; + + op.collectCases(cases); + if (cases.empty()) + return mlir::failure(); + + auto flushMergedOps = [&]() { + for (CaseOp &c : cascadingCases) { + rewriter.eraseOp(c); + } + cascadingCases.clear(); + cascadingCaseValues.clear(); + }; + + auto mergeCascadingInto = [&](CaseOp &target) { + rewriter.modifyOpInPlace(target, [&]() { + target.setValueAttr(rewriter.getArrayAttr(cascadingCaseValues)); + target.setKind(CaseOpKind::Anyof); + }); + changed = mlir::success(); + }; + + for (CaseOp c : cases) { + cir::CaseOpKind kind = c.getKind(); + if (kind == cir::CaseOpKind::Equal && + isa<YieldOp>(c.getCaseRegion().front().front())) { + // If the case contains only a YieldOp, collect it for cascading merge + cascadingCases.push_back(c); + cascadingCaseValues.push_back(c.getValue()[0]); + + } else if (kind == cir::CaseOpKind::Equal && !cascadingCases.empty()) { + // merge previously collected cascading cases + cascadingCaseValues.push_back(c.getValue()[0]); + mergeCascadingInto(c); + flushMergedOps(); + } else if (kind != cir::CaseOpKind::Equal && cascadingCases.size() > 1) { + // If a Default, Anyof or Range case is found and there are previous + // cascading cases, merge all of them into the last cascading 
case. + CaseOp lastCascadingCase = cascadingCases.back(); + mergeCascadingInto(lastCascadingCase); + cascadingCases.pop_back(); + flushMergedOps(); + } else { + cascadingCases.clear(); + cascadingCaseValues.clear(); + } + } + + // Edge case: all cases are simple cascading cases + if (cascadingCases.size() == cases.size()) { + CaseOp lastCascadingCase = cascadingCases.back(); + mergeCascadingInto(lastCascadingCase); + cascadingCases.pop_back(); + flushMergedOps(); + } + // We don't currently fold case range statements with other case statements. + assert(!cir::MissingFeatures::foldRangeCase()); + return changed; + } +}; + //===----------------------------------------------------------------------===// // CIRSimplifyPass //===----------------------------------------------------------------------===// @@ -173,7 +274,8 @@ void populateMergeCleanupPatterns(RewritePatternSet &patterns) { // clang-format off patterns.add< SimplifyTernary, - SimplifySelect + SimplifySelect, + SimplifySwitch >(patterns.getContext()); // clang-format on } @@ -186,7 +288,7 @@ void CIRSimplifyPass::runOnOperation() { // Collect operations to apply patterns. 
llvm::SmallVector<Operation *, 16> ops; getOperation()->walk([&](Operation *op) { - if (isa<TernaryOp, SelectOp>(op)) + if (isa<TernaryOp, SelectOp, SwitchOp>(op)) ops.push_back(op); }); diff --git a/clang/test/CIR/Transforms/switch-fold.cir b/clang/test/CIR/Transforms/switch-fold.cir new file mode 100644 index 0000000000000..3c2fe8a9cbf25 --- /dev/null +++ b/clang/test/CIR/Transforms/switch-fold.cir @@ -0,0 +1,196 @@ +// RUN: cir-opt -cir-canonicalize -cir-simplify -o %t.cir %s +// RUN: FileCheck --input-file=%t.cir %s + +!s32i = !cir.int<s, 32> + +module { + cir.func @foldCascade(%arg0: !s32i) { + %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64} + cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i> + cir.scope { + %1 = cir.load %0 : !cir.ptr<!s32i>, !s32i + cir.switch (%1 : !s32i) { + cir.case(equal, [#cir.int<1> : !s32i]) { + cir.yield + } + cir.case(equal, [#cir.int<2> : !s32i]) { + cir.yield + } + cir.case(equal, [#cir.int<3> : !s32i]) { + %2 = cir.const #cir.int<2> : !s32i + cir.store %2, %0 : !s32i, !cir.ptr<!s32i> + cir.break + } + cir.yield + } + } + cir.return + } + //CHECK: cir.func @foldCascade + //CHECK: cir.switch (%[[COND:.*]] : !s32i) { + //CHECK-NEXT: cir.case(anyof, [#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i]) { + //CHECK-NEXT: %[[TWO:.*]] = cir.const #cir.int<2> : !s32i + //CHECK-NEXT: cir.store %[[TWO]], %[[ARG0:.*]] : !s32i, !cir.ptr<!s32i> + //CHECK-NEXT: cir.break + //CHECK-NEXT: } + //CHECK-NEXT: cir.yield + //CHECK-NEXT: } + + cir.func @foldCascade2(%arg0: !s32i) { + %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64} + cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i> + cir.scope { + %1 = cir.load %0 : !cir.ptr<!s32i>, !s32i + cir.switch (%1 : !s32i) { + cir.case(equal, [#cir.int<0> : !s32i]) { + cir.yield + } + cir.case(equal, [#cir.int<1> : !s32i]) { + cir.yield + } + cir.case(equal, [#cir.int<2> : !s32i]) { + cir.break + } + cir.case(equal, [#cir.int<3> : !s32i]) { + cir.yield + } + 
cir.case(equal, [#cir.int<4> : !s32i]) { + cir.yield + } + cir.case(equal, [#cir.int<5> : !s32i]) { + cir.break + } + cir.yield + } + } + cir.return + } + //CHECK: @foldCascade2 + //CHECK: cir.switch (%[[COND2:.*]] : !s32i) { + //CHECK: cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i]) { + //CHECK: cir.break + //CHECK: } + //CHECK: cir.case(anyof, [#cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) { + //CHECK: cir.break + //CHECK: } + //CHECK: cir.yield + //CHECK: } + cir.func @foldCascade3(%arg0: !s32i ) { + %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64} + cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i> + cir.scope { + %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["x"] {alignment = 4 : i64} + %2 = cir.load %0 : !cir.ptr<!s32i>, !s32i + cir.switch (%2 : !s32i) { + cir.case(equal, [#cir.int<0> : !s32i]) { + cir.yield + } + cir.case(equal, [#cir.int<1> : !s32i]) { + cir.yield + } + cir.case(equal, [#cir.int<2> : !s32i]) { + cir.yield + } + cir.case(equal, [#cir.int<3> : !s32i]) { + cir.yield + } + cir.case(equal, [#cir.int<4> : !s32i]) { + cir.yield + } + cir.case(equal, [#cir.int<5> : !s32i]) { + cir.break + } + cir.yield + } + } + cir.return + } + //CHECK: cir.func @foldCascade3 + //CHECK: cir.switch (%[[COND3:.*]] : !s32i) { + //CHECK: cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) { + //CHECK: cir.break + //CHECK: } + //CHECK: cir.yield + //CHECK: } + cir.func @foldCascadeWithDefault(%arg0: !s32i ) { + %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64} + cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i> + cir.scope { + %1 = cir.load %0 : !cir.ptr<!s32i>, !s32i + cir.switch (%1 : !s32i) { + cir.case(equal, [#cir.int<3> : !s32i]) { + cir.break + } + cir.case(equal, [#cir.int<4> : !s32i]) { + cir.yield + } + cir.case(equal, [#cir.int<5> : !s32i]) { + cir.yield + } + cir.case(default, []) { 
+ cir.yield + } + cir.case(equal, [#cir.int<6> : !s32i]) { + cir.yield + } + cir.case(equal, [#cir.int<7> : !s32i]) { + cir.break + } + cir.yield + } + } + cir.return + } + //CHECK: cir.func @foldCascadeWithDefault + //CHECK: cir.switch (%[[COND:.*]] : !s32i) { + //CHECK: cir.case(equal, [#cir.int<3> : !s32i]) { + //CHECK: cir.break + //CHECK: } + //CHECK: cir.case(anyof, [#cir.int<4> : !s32i, #cir.int<5> : !s32i]) { + //CHECK: cir.yield + //CHECK: } + //CHECK: cir.case(default, []) { + //CHECK: cir.yield + //CHECK: } + //CHECK: cir.case(anyof, [#cir.int<6> : !s32i, #cir.int<7> : !s32i]) { + //CHECK: cir.break + //CHECK: } + //CHECK: cir.yield + //CHECK: } + cir.func @foldAllCascade(%arg0: !s32i ) { + %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64} + cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i> + cir.scope { + %1 = cir.load %0 : !cir.ptr<!s32i>, !s32i + cir.switch (%1 : !s32i) { + cir.case(equal, [#cir.int<0> : !s32i]) { + cir.yield + } + cir.case(equal, [#cir.int<1> : !s32i]) { + cir.yield + } + cir.case(equal, [#cir.int<2> : !s32i]) { + cir.yield + } + cir.case(equal, [#cir.int<3> : !s32i]) { + cir.yield + } + cir.case(equal, [#cir.int<4> : !s32i]) { + cir.yield + } + cir.case(equal, [#cir.int<5> : !s32i]) { + cir.yield + } + cir.yield + } + } + cir.return + } + //CHECK: cir.func @foldAllCascade + //CHECK: cir.switch (%[[COND:.*]] : !s32i) { + //CHECK: cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) { + //CHECK: cir.yield + //CHECK: } + //CHECK: cir.yield + //CHECK: } +} >From 0afcfd6b0f5cf414f846d2012ac4508e968199d8 Mon Sep 17 00:00:00 2001 From: Andres Salamanca <andrealebarbari...@gmail.com> Date: Wed, 21 May 2025 11:27:36 -0500 Subject: [PATCH 2/3] Apply reviews --- clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp | 8 ++++---- clang/test/CIR/Transforms/switch-fold.cir | 12 ++++++------ 2 files changed, 10 insertions(+), 10 deletions(-) 
diff --git a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp index 58300cc219602..af064c0800fbe 100644 --- a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp +++ b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp @@ -197,7 +197,7 @@ struct SimplifySwitch : public OpRewritePattern<SwitchOp> { PatternRewriter &rewriter) const override { LogicalResult changed = mlir::failure(); - llvm::SmallVector<CaseOp, 8> cases; + SmallVector<CaseOp, 8> cases; SmallVector<CaseOp, 4> cascadingCases; SmallVector<mlir::Attribute, 4> cascadingCaseValues; @@ -228,7 +228,6 @@ struct SimplifySwitch : public OpRewritePattern<SwitchOp> { // If the case contains only a YieldOp, collect it for cascading merge cascadingCases.push_back(c); cascadingCaseValues.push_back(c.getValue()[0]); - } else if (kind == cir::CaseOpKind::Equal && !cascadingCases.empty()) { // merge previously collected cascading cases cascadingCaseValues.push_back(c.getValue()[0]); @@ -237,6 +236,8 @@ struct SimplifySwitch : public OpRewritePattern<SwitchOp> { } else if (kind != cir::CaseOpKind::Equal && cascadingCases.size() > 1) { // If a Default, Anyof or Range case is found and there are previous // cascading cases, merge all of them into the last cascading case. + // We don't currently fold case range statements with other case statements. + assert(!cir::MissingFeatures::foldRangeCase()); CaseOp lastCascadingCase = cascadingCases.back(); mergeCascadingInto(lastCascadingCase); cascadingCases.pop_back(); @@ -254,8 +255,7 @@ struct SimplifySwitch : public OpRewritePattern<SwitchOp> { cascadingCases.pop_back(); flushMergedOps(); } - // We don't currently fold case range statements with other case statements. 
- assert(!cir::MissingFeatures::foldRangeCase()); + return changed; } }; diff --git a/clang/test/CIR/Transforms/switch-fold.cir b/clang/test/CIR/Transforms/switch-fold.cir index 3c2fe8a9cbf25..62a94f4fde2c3 100644 --- a/clang/test/CIR/Transforms/switch-fold.cir +++ b/clang/test/CIR/Transforms/switch-fold.cir @@ -45,16 +45,16 @@ module { cir.case(equal, [#cir.int<0> : !s32i]) { cir.yield } - cir.case(equal, [#cir.int<1> : !s32i]) { + cir.case(equal, [#cir.int<2> : !s32i]) { cir.yield } - cir.case(equal, [#cir.int<2> : !s32i]) { + cir.case(equal, [#cir.int<4> : !s32i]) { cir.break } - cir.case(equal, [#cir.int<3> : !s32i]) { + cir.case(equal, [#cir.int<1> : !s32i]) { cir.yield } - cir.case(equal, [#cir.int<4> : !s32i]) { + cir.case(equal, [#cir.int<3> : !s32i]) { cir.yield } cir.case(equal, [#cir.int<5> : !s32i]) { @@ -67,10 +67,10 @@ module { } //CHECK: @foldCascade2 //CHECK: cir.switch (%[[COND2:.*]] : !s32i) { - //CHECK: cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i]) { + //CHECK: cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<2> : !s32i, #cir.int<4> : !s32i]) { //CHECK: cir.break //CHECK: } - //CHECK: cir.case(anyof, [#cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) { + //CHECK: cir.case(anyof, [#cir.int<1> : !s32i, #cir.int<3> : !s32i, #cir.int<5> : !s32i]) { //CHECK: cir.break //CHECK: } //CHECK: cir.yield >From e46e3abe21df1214b535fbe1ba7d1037b49b05e0 Mon Sep 17 00:00:00 2001 From: Andres Salamanca <andrealebarbari...@gmail.com> Date: Wed, 21 May 2025 11:31:09 -0500 Subject: [PATCH 3/3] Fix formatting --- clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp index af064c0800fbe..40716f2467563 100644 --- a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp +++ b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp @@ -236,7 +236,8 @@ struct SimplifySwitch : 
public OpRewritePattern<SwitchOp> { } else if (kind != cir::CaseOpKind::Equal && cascadingCases.size() > 1) { // If a Default, Anyof or Range case is found and there are previous // cascading cases, merge all of them into the last cascading case. - // We don't currently fold case range statements with other case statements. + // We don't currently fold case range statements with other case + // statements. assert(!cir::MissingFeatures::foldRangeCase()); CaseOp lastCascadingCase = cascadingCases.back(); mergeCascadingInto(lastCascadingCase); _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits