llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clangir

Author: None (Andres-Salamanca)

<details>
<summary>Changes</summary>

This PR introduces a new **CIRSimplify pattern for `switch` cases**, which folds
multiple **cascading `Equal` cases** (cases whose body contains only a `YieldOp`)
into a single `CaseOp` of kind `anyof`.

This logic is based on the suggestion from this discussion:
https://github.com/llvm/llvm-project/pull/138003#discussion_r2070564458

---
Full diff: https://github.com/llvm/llvm-project/pull/140649.diff


4 Files Affected:

- (modified) clang/include/clang/CIR/MissingFeatures.h (-1) 
- (modified) clang/lib/CIR/CodeGen/CIRGenStmt.cpp (-6) 
- (modified) clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp (+104-2) 
- (added) clang/test/CIR/Transforms/switch-fold.cir (+196) 


``````````diff
diff --git a/clang/include/clang/CIR/MissingFeatures.h b/clang/include/clang/CIR/MissingFeatures.h
index 484822c351746..9f3e5d007d66c 100644
--- a/clang/include/clang/CIR/MissingFeatures.h
+++ b/clang/include/clang/CIR/MissingFeatures.h
@@ -114,7 +114,6 @@ struct MissingFeatures {
   static bool opUnaryPromotionType() { return false; }
 
   // SwitchOp handling
-  static bool foldCascadingCases() { return false; }
   static bool foldRangeCase() { return false; }
 
   // Clang early optimizations or things defered to LLVM lowering.
diff --git a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp
index cc96e65e4ce1d..7f1ecbda414bd 100644
--- a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp
@@ -531,12 +531,6 @@ mlir::LogicalResult CIRGenFunction::emitCaseStmt(const CaseStmt &s,
     value = builder.getArrayAttr({cir::IntAttr::get(condType, intVal),
                                   cir::IntAttr::get(condType, endVal)});
     kind = cir::CaseOpKind::Range;
-
-    // We don't currently fold case range statements with other case statements.
-    // TODO(cir): Add this capability. Folding these cases is going to be
-    // implemented in CIRSimplify when it is upstreamed.
-    assert(!cir::MissingFeatures::foldRangeCase());
-    assert(!cir::MissingFeatures::foldCascadingCases());
   } else {
     value = builder.getArrayAttr({cir::IntAttr::get(condType, intVal)});
     kind = cir::CaseOpKind::Equal;
diff --git a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
index b969569b0081c..58300cc219602 100644
--- a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
@@ -159,6 +159,107 @@ struct SimplifySelect : public OpRewritePattern<SelectOp> {
   }
 };
 
+/// Simplify `cir.switch` operations by folding cascading cases
+/// into a single `cir.case` with the `anyof` kind.
+///
+/// Cascading cases are consecutive `cir.case` operations of kind `equal` that
+/// are direct children of the switch body and whose only operation is a
+/// `cir.yield` (i.e. they fall through to the next case). Cases nested in
+/// other operations are never folded; any non-case operation ends a chain.
+///
+/// A run of cascading cases is folded, aggregating all the case values:
+///  - into the following `equal` case with a real body (e.g. `cir.break`);
+///  - otherwise, when followed by a non-`equal` case or by the end of the
+///    switch, into the last case of the run (if the run has 2+ cases), which
+///    keeps its `cir.yield` body so control still falls through.
+///
+/// Example:
+/// Before:
+///   cir.case(equal, [#cir.int<0> : !s32i]) {
+///     cir.yield
+///   }
+///   cir.case(equal, [#cir.int<1> : !s32i]) {
+///     cir.break
+///   }
+///
+/// After applying SimplifySwitch:
+///   cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i]) {
+///     cir.break
+///   }
+struct SimplifySwitch : public OpRewritePattern<SwitchOp> {
+  using OpRewritePattern<SwitchOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(SwitchOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!op.getBody().hasOneBlock())
+      return mlir::failure();
+
+    LogicalResult changed = mlir::failure();
+    SmallVector<CaseOp, 4> cascadingCases;
+    SmallVector<mlir::Attribute, 4> cascadingCaseValues;
+
+    auto resetChain = [&]() {
+      cascadingCases.clear();
+      cascadingCaseValues.clear();
+    };
+
+    // Erase every collected case; the merge target is never among them.
+    auto flushMergedOps = [&]() {
+      for (CaseOp c : cascadingCases)
+        rewriter.eraseOp(c);
+      resetChain();
+    };
+
+    auto mergeCascadingInto = [&](CaseOp target) {
+      rewriter.modifyOpInPlace(target, [&]() {
+        target.setValueAttr(rewriter.getArrayAttr(cascadingCaseValues));
+        target.setKind(CaseOpKind::Anyof);
+      });
+      changed = mlir::success();
+    };
+
+    // Fold a run of cascading cases into its last case, which keeps its
+    // `cir.yield` body so control still falls through to what follows.
+    auto foldIntoLastCascadingCase = [&]() {
+      CaseOp last = cascadingCases.pop_back_val();
+      mergeCascadingInto(last);
+      flushMergedOps();
+    };
+
+    // Only direct children are visited; erasing already-visited ops is safe.
+    for (Operation &child : op.getBody().front().without_terminator()) {
+      auto c = dyn_cast<CaseOp>(&child);
+      if (!c) {
+        // Any other op between cases breaks the fallthrough chain.
+        resetChain();
+        continue;
+      }
+      cir::CaseOpKind kind = c.getKind();
+      if (kind == cir::CaseOpKind::Equal &&
+          isa<YieldOp>(c.getCaseRegion().front().front())) {
+        // The case contains only a YieldOp: collect it for cascading merge.
+        cascadingCases.push_back(c);
+        cascadingCaseValues.push_back(c.getValue()[0]);
+      } else if (kind == cir::CaseOpKind::Equal && !cascadingCases.empty()) {
+        // An equal case with a real body ends the run; fold the run into it.
+        cascadingCaseValues.push_back(c.getValue()[0]);
+        mergeCascadingInto(c);
+        flushMergedOps();
+      } else if (kind != cir::CaseOpKind::Equal && cascadingCases.size() > 1) {
+        // Default/Anyof/Range cases are never merge targets.
+        foldIntoLastCascadingCase();
+      } else {
+        resetChain();
+      }
+    }
+    // A trailing run (possibly every case) falls through to the switch end.
+    if (cascadingCases.size() > 1)
+      foldIntoLastCascadingCase();
+    // We don't currently fold case range statements with other cases.
+    assert(!cir::MissingFeatures::foldRangeCase());
+    return changed;
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // CIRSimplifyPass
 //===----------------------------------------------------------------------===//
@@ -173,7 +274,8 @@ void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
   // clang-format off
   patterns.add<
     SimplifyTernary,
-    SimplifySelect
+    SimplifySelect,
+    SimplifySwitch
   >(patterns.getContext());
   // clang-format on
 }
@@ -186,7 +288,7 @@ void CIRSimplifyPass::runOnOperation() {
   // Collect operations to apply patterns.
   llvm::SmallVector<Operation *, 16> ops;
   getOperation()->walk([&](Operation *op) {
-    if (isa<TernaryOp, SelectOp>(op))
+    if (isa<TernaryOp, SelectOp, SwitchOp>(op))
       ops.push_back(op);
   });
 
diff --git a/clang/test/CIR/Transforms/switch-fold.cir b/clang/test/CIR/Transforms/switch-fold.cir
new file mode 100644
index 0000000000000..3c2fe8a9cbf25
--- /dev/null
+++ b/clang/test/CIR/Transforms/switch-fold.cir
@@ -0,0 +1,196 @@
+// RUN: cir-opt -cir-canonicalize -cir-simplify -o %t.cir %s
+// RUN: FileCheck --input-file=%t.cir %s
+
+!s32i = !cir.int<s, 32>
+
+module {
+  cir.func @foldCascade(%arg0: !s32i) {  // cases 1 and 2 fall through into case 3
+    %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
+    cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+    cir.scope {
+      %1 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+      cir.switch (%1 : !s32i) {
+        cir.case(equal, [#cir.int<1> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<2> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<3> : !s32i]) {
+          %2 = cir.const #cir.int<2> : !s32i
+          cir.store %2, %0 : !s32i, !cir.ptr<!s32i>
+          cir.break
+        }
+        cir.yield
+      }
+    }
+    cir.return
+  }
+  //CHECK: cir.func @foldCascade
+  //CHECK:   cir.switch (%[[COND:.*]] : !s32i) {
+  //CHECK-NEXT:     cir.case(anyof, [#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i]) {
+  //CHECK-NEXT:       %[[TWO:.*]] = cir.const #cir.int<2> : !s32i
+  //CHECK-NEXT:       cir.store %[[TWO]], %[[ARG0:.*]] : !s32i, !cir.ptr<!s32i>
+  //CHECK-NEXT:       cir.break
+  //CHECK-NEXT:     }
+  //CHECK-NEXT:     cir.yield
+  //CHECK-NEXT:   }
+
+  cir.func @foldCascade2(%arg0: !s32i) {  // two independent runs, each ending in a break
+    %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
+    cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+    cir.scope {
+      %1 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+      cir.switch (%1 : !s32i) {
+        cir.case(equal, [#cir.int<0> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<1> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<2> : !s32i]) {
+          cir.break
+        }
+        cir.case(equal, [#cir.int<3> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<4> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<5> : !s32i]) {
+          cir.break
+        }
+        cir.yield
+      }
+    }
+    cir.return
+  }
+  //CHECK: @foldCascade2
+  //CHECK:   cir.switch (%[[COND2:.*]] : !s32i) {
+  //CHECK:     cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i]) {
+  //CHECK:       cir.break
+  //CHECK:     }
+  //CHECK:     cir.case(anyof, [#cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
+  //CHECK:       cir.break
+  //CHECK:     }
+  //CHECK:     cir.yield
+  //CHECK:   }
+  cir.func @foldCascade3(%arg0: !s32i) {  // one run of five yield-only cases ending in a break
+    %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
+    cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+    cir.scope {
+      %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["x"] {alignment = 4 : i64}
+      %2 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+      cir.switch (%2 : !s32i) {
+        cir.case(equal, [#cir.int<0> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<1> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<2> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<3> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<4> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<5> : !s32i]) {
+          cir.break
+        }
+        cir.yield
+      }
+    }
+    cir.return
+  }
+  //CHECK: cir.func @foldCascade3
+  //CHECK:   cir.switch (%[[COND3:.*]] : !s32i) {
+  //CHECK:     cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
+  //CHECK:       cir.break
+  //CHECK:     }
+  //CHECK:     cir.yield
+  //CHECK:   }
+  cir.func @foldCascadeWithDefault(%arg0: !s32i ) {  // yield-only runs on both sides of a default case
+    %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
+    cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+    cir.scope {
+      %1 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+      cir.switch (%1 : !s32i) {
+        cir.case(equal, [#cir.int<3> : !s32i]) {  // no preceding run: expected to stay an equal case
+          cir.break
+        }
+        cir.case(equal, [#cir.int<4> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<5> : !s32i]) {
+          cir.yield
+        }
+        cir.case(default, []) {  // non-equal case: the 4/5 run folds into its own last case
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<6> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<7> : !s32i]) {
+          cir.break
+        }
+        cir.yield
+      }
+    }
+    cir.return
+  }
+  //CHECK: cir.func @foldCascadeWithDefault
+  //CHECK:   cir.switch (%[[COND:.*]] : !s32i) {
+  //CHECK:      cir.case(equal, [#cir.int<3> : !s32i]) {
+  //CHECK:        cir.break
+  //CHECK:      }
+  //CHECK:      cir.case(anyof, [#cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
+  //CHECK:        cir.yield
+  //CHECK:      }
+  //CHECK:      cir.case(default, []) {
+  //CHECK:        cir.yield
+  //CHECK:      }
+  //CHECK:      cir.case(anyof, [#cir.int<6> : !s32i, #cir.int<7> : !s32i]) {
+  //CHECK:        cir.break
+  //CHECK:      }
+  //CHECK:      cir.yield
+  //CHECK:   }
+  cir.func @foldAllCascade(%arg0: !s32i) {  // every case is yield-only: one trailing run
+    %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
+    cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+    cir.scope {
+      %1 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+      cir.switch (%1 : !s32i) {
+        cir.case(equal, [#cir.int<0> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<1> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<2> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<3> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<4> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<5> : !s32i]) {
+          cir.yield
+        }
+        cir.yield
+      }
+    }
+    cir.return
+  }
+  //CHECK: cir.func @foldAllCascade
+  //CHECK:   cir.switch (%[[COND:.*]] : !s32i) {
+  //CHECK:     cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
+  //CHECK:       cir.yield
+  //CHECK:     }
+  //CHECK:     cir.yield
+  //CHECK:   }
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/140649
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to