llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clangir

Author: None (Andres-Salamanca)

<details>
<summary>Changes</summary>

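This PR upstreams the GotoSolver pass.  
It walks each function, records the block that contains every `cir.label`, and replaces each `cir.goto` with a `cir.br` to the block of the matching label.  
All `cir.label` operations are erased, so a label that no goto refers to is simply dropped and never lowered.  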


---
Full diff: https://github.com/llvm/llvm-project/pull/154596.diff


8 Files Affected:

- (modified) clang/include/clang/CIR/Dialect/Passes.h (+1) 
- (modified) clang/include/clang/CIR/Dialect/Passes.td (+10) 
- (modified) clang/lib/CIR/Dialect/Transforms/CMakeLists.txt (+1) 
- (added) clang/lib/CIR/Dialect/Transforms/GotoSolver.cpp (+52) 
- (modified) clang/lib/CIR/Lowering/CIRPasses.cpp (+1) 
- (modified) clang/test/CIR/CodeGen/goto.cpp (+95) 
- (modified) clang/test/CIR/CodeGen/label.c (+49-2) 
- (added) clang/test/CIR/Lowering/goto.cir (+52) 


``````````diff
diff --git a/clang/include/clang/CIR/Dialect/Passes.h b/clang/include/clang/CIR/Dialect/Passes.h
index 7a202b1e04ef9..32c3e27d07dfb 100644
--- a/clang/include/clang/CIR/Dialect/Passes.h
+++ b/clang/include/clang/CIR/Dialect/Passes.h
@@ -26,6 +26,7 @@ std::unique_ptr<Pass> createCIRSimplifyPass();
 std::unique_ptr<Pass> createHoistAllocasPass();
 std::unique_ptr<Pass> createLoweringPreparePass();
 std::unique_ptr<Pass> createLoweringPreparePass(clang::ASTContext *astCtx);
+std::unique_ptr<Pass> createGotoSolverPass();
 
 void populateCIRPreLoweringPasses(mlir::OpPassManager &pm);
 
diff --git a/clang/include/clang/CIR/Dialect/Passes.td b/clang/include/clang/CIR/Dialect/Passes.td
index 7d5ec2ffed39d..0f5783945f8ae 100644
--- a/clang/include/clang/CIR/Dialect/Passes.td
+++ b/clang/include/clang/CIR/Dialect/Passes.td
@@ -72,6 +72,16 @@ def CIRFlattenCFG : Pass<"cir-flatten-cfg"> {
   let dependentDialects = ["cir::CIRDialect"];
 }
 
+def GotoSolver : Pass<"cir-goto-solver"> {
+  let summary = "Replaces goto operations with branches";
+  let description = [{
+    This pass replaces each `cir.goto` with a `cir.br` to the block that held
+    the matching `cir.label`. All `cir.label` operations are erased.
+  }];
+  let constructor = "mlir::createGotoSolverPass()";
+  let dependentDialects = ["cir::CIRDialect"];
+}
+
 def LoweringPrepare : Pass<"cir-lowering-prepare"> {
   let summary = "Lower to more fine-grained CIR operations before lowering to "
     "other dialects";
diff --git a/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt b/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
index 18beca7b9a680..df7a1a3e0acb5 100644
--- a/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
+++ b/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_clang_library(MLIRCIRTransforms
   FlattenCFG.cpp
   HoistAllocas.cpp
   LoweringPrepare.cpp
+  GotoSolver.cpp
 
   DEPENDS
   MLIRCIRPassIncGen
diff --git a/clang/lib/CIR/Dialect/Transforms/GotoSolver.cpp b/clang/lib/CIR/Dialect/Transforms/GotoSolver.cpp
new file mode 100644
index 0000000000000..e1c47a1ce16f1
--- /dev/null
+++ b/clang/lib/CIR/Dialect/Transforms/GotoSolver.cpp
@@ -0,0 +1,52 @@
+#include "PassDetail.h"
+#include "clang/CIR/Dialect/IR/CIRDialect.h"
+#include "clang/CIR/Dialect/Passes.h"
+#include "llvm/Support/TimeProfiler.h"
+#include <memory>
+
+using namespace mlir;
+using namespace cir;
+
+namespace {
+struct GotoSolverPass : public GotoSolverBase<GotoSolverPass> {
+  GotoSolverPass() = default;
+  void runOnOperation() override;
+};
+} // namespace
+
+/// Rewrites each `cir.goto` in \p func into a `cir.br` to the block that held
+/// the matching `cir.label`; every `cir.label` is erased, used or not.
+static void process(cir::FuncOp func) {
+  mlir::OpBuilder builder(func.getContext());
+  llvm::StringMap<Block *> labels;
+  llvm::SmallVector<cir::GotoOp, 4> gotos;
+
+  // Default (post-order) walk, so erasing the op being visited is allowed.
+  func.getBody().walk([&](mlir::Operation *op) {
+    if (auto lab = dyn_cast<cir::LabelOp>(op)) {
+      // Keep the label's block as the branch target; the op itself goes away.
+      labels.try_emplace(lab.getLabel(), lab->getBlock());
+      lab.erase();
+    } else if (auto goTo = dyn_cast<cir::GotoOp>(op)) {
+      gotos.push_back(goTo);
+    }
+  });
+
+  for (cir::GotoOp goTo : gotos) {
+    // lookup() avoids default-inserting a null Block* for an unknown label.
+    Block *dest = labels.lookup(goTo.getLabel());
+    assert(dest && "cir.goto without a matching cir.label in this function");
+    builder.setInsertionPoint(goTo);
+    builder.create<cir::BrOp>(goTo.getLoc(), dest);
+    goTo.erase();
+  }
+}
+
+void GotoSolverPass::runOnOperation() {
+  llvm::TimeTraceScope scope("Goto Solver");
+  getOperation()->walk([&](cir::FuncOp op) { process(op); });
+}
+
+std::unique_ptr<Pass> mlir::createGotoSolverPass() {
+  return std::make_unique<GotoSolverPass>();
+}
diff --git a/clang/lib/CIR/Lowering/CIRPasses.cpp b/clang/lib/CIR/Lowering/CIRPasses.cpp
index bb9781be897eb..ccc838717e421 100644
--- a/clang/lib/CIR/Lowering/CIRPasses.cpp
+++ b/clang/lib/CIR/Lowering/CIRPasses.cpp
@@ -45,6 +45,7 @@ namespace mlir {
 void populateCIRPreLoweringPasses(OpPassManager &pm) {
   pm.addPass(createHoistAllocasPass());
   pm.addPass(createCIRFlattenCFGPass());
+  pm.addPass(createGotoSolverPass());
 }
 
 } // namespace mlir
diff --git a/clang/test/CIR/CodeGen/goto.cpp b/clang/test/CIR/CodeGen/goto.cpp
index 13ca65344a150..48cb44ed0f478 100644
--- a/clang/test/CIR/CodeGen/goto.cpp
+++ b/clang/test/CIR/CodeGen/goto.cpp
@@ -1,5 +1,7 @@
 // RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir
 // RUN: FileCheck --input-file=%t.cir %s -check-prefix=CIR
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o %t-cir.ll
+// RUN: FileCheck --input-file=%t-cir.ll %s --check-prefix=LLVM
 // RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -emit-llvm %s -o %t.ll
 // RUN: FileCheck --input-file=%t.ll %s --check-prefix=OGCG
 
@@ -27,6 +29,24 @@ int shouldNotGenBranchRet(int x) {
 // CIR:    cir.store [[MINUS]], [[RETVAL]] : !s32i, !cir.ptr<!s32i>
 // CIR:    cir.br ^bb1
 
+// LLVM: define dso_local i32 @_Z21shouldNotGenBranchReti
+// LLVM:   [[COND:%.*]] = load i32, ptr {{.*}}, align 4
+// LLVM:   [[CMP:%.*]] = icmp sgt i32 [[COND]], 5
+// LLVM:   br i1 [[CMP]], label %[[IFTHEN:.*]], label %[[IFEND:.*]]
+// LLVM: [[IFTHEN]]:
+// LLVM:   br label %[[ERR:.*]]
+// LLVM: [[IFEND]]:
+// LLVM:   br label %[[BB9:.*]]
+// LLVM: [[BB9]]:
+// LLVM:   store i32 0, ptr %[[RETVAL:.*]], align 4
+// LLVM:   br label %[[BBRET:.*]]
+// LLVM: [[BBRET]]:
+// LLVM:   [[RET:%.*]] = load i32, ptr %[[RETVAL]], align 4
+// LLVM:   ret i32 [[RET]]
+// LLVM: [[ERR]]:
+// LLVM:   store i32 -1, ptr %[[RETVAL]], align 4
+// LLVM:   br label %[[BBRET]]
+
 // OGCG: define dso_local noundef i32 @_Z21shouldNotGenBranchReti
 // OGCG: if.then:
 // OGCG:   br label %err
@@ -51,6 +71,17 @@ int shouldGenBranch(int x) {
 // CIR:  ^bb1:
 // CIR:    cir.label "err"
 
+// LLVM: define dso_local i32 @_Z15shouldGenBranchi
+// LLVM:   br i1 [[CMP:%.*]], label %[[IFTHEN:.*]], label %[[IFEND:.*]]
+// LLVM: [[IFTHEN]]:
+// LLVM:   br label %[[ERR:.*]]
+// LLVM: [[IFEND]]:
+// LLVM:   br label %[[BB9:.*]]
+// LLVM: [[BB9]]:
+// LLVM:   br label %[[ERR]]
+// LLVM: [[ERR]]:
+// LLVM:   ret i32 [[RET:%.*]]
+
 // OGCG: define dso_local noundef i32 @_Z15shouldGenBranchi
 // OGCG: if.then:
 // OGCG:   br label %err
@@ -78,6 +109,15 @@ void severalLabelsInARow(int a) {
 // CIR:  ^bb[[#BLK3]]:
 // CIR:    cir.label "end2"
 
+// LLVM: define dso_local void @_Z19severalLabelsInARowi
+// LLVM:   br label %[[END1:.*]]
+// LLVM: [[UNRE:.*]]:                                                ; No predecessors!
+// LLVM:   br label %[[END2:.*]]
+// LLVM: [[END1]]:
+// LLVM:   br label %[[END2]]
+// LLVM: [[END2]]:
+// LLVM:   ret
+
 // OGCG: define dso_local void @_Z19severalLabelsInARowi
 // OGCG:   br label %end1
 // OGCG: end1:
@@ -99,6 +139,13 @@ void severalGotosInARow(int a) {
 // CIR:  ^bb[[#BLK2:]]:
 // CIR:    cir.label "end"
 
+// LLVM: define dso_local void @_Z18severalGotosInARowi
+// LLVM:   br label %[[END:.*]]
+// LLVM: [[UNRE:.*]]:                                                ; No predecessors!
+// LLVM:   br label %[[END]]
+// LLVM: [[END]]:
+// LLVM:   ret void
+
 // OGCG: define dso_local void @_Z18severalGotosInARowi(i32 noundef %a) #0 {
 // OGCG:   br label %end
 // OGCG: end:
@@ -126,6 +173,14 @@ extern "C" void multiple_non_case(int v) {
 // CIR: cir.call @action2()
 // CIR: cir.break
 
+// LLVM: define dso_local void @multiple_non_case
+// LLVM: [[SWDEFAULT:.*]]:
+// LLVM:   call void @action1()
+// LLVM:   br label %[[L2:.*]]
+// LLVM: [[L2]]:
+// LLVM:   call void @action2()
+// LLVM:   br label %[[BREAK:.*]]
+
 // OGCG: define dso_local void @multiple_non_case
 // OGCG: sw.default:
 // OGCG:   call void @action1()
@@ -158,6 +213,26 @@ extern "C" void case_follow_label(int v) {
 // CIR:   cir.call @action2()
 // CIR:   cir.goto "label"
 
+// LLVM: define dso_local void @case_follow_label
+// LLVM:  switch i32 {{.*}}, label %[[SWDEFAULT:.*]] [
+// LLVM:    i32 1, label %[[LABEL:.*]]
+// LLVM:    i32 2, label %[[CASE2:.*]]
+// LLVM:  ]
+// LLVM: [[LABEL]]:
+// LLVM:   br label %[[CASE2]]
+// LLVM: [[CASE2]]:
+// LLVM:   call void @action1()
+// LLVM:   br label %[[BREAK:.*]]
+// LLVM: [[BREAK]]:
+// LLVM:   br label %[[END:.*]]
+// LLVM: [[SWDEFAULT]]:
+// LLVM:   call void @action2()
+// LLVM:   br label %[[LABEL]]
+// LLVM: [[END]]:
+// LLVM:   br label %[[RET:.*]]
+// LLVM: [[RET]]:
+// LLVM:   ret void
+
 // OGCG: define dso_local void @case_follow_label
 // OGCG: sw.bb:
 // OGCG:   br label %label
@@ -197,6 +272,26 @@ extern "C" void default_follow_label(int v) {
 // CIR:   cir.call @action2()
 // CIR:   cir.goto "label"
 
+// LLVM: define dso_local void @default_follow_label
+// LLVM: [[CASE1:.*]]:
+// LLVM:   br label %[[BB8:.*]]
+// LLVM: [[BB8]]:
+// LLVM:   br label %[[CASE2:.*]]
+// LLVM: [[CASE2]]:
+// LLVM:   call void @action1()
+// LLVM:   br label %[[BREAK:.*]]
+// LLVM: [[LABEL:.*]]:
+// LLVM:   br label %[[SWDEFAULT:.*]]
+// LLVM: [[SWDEFAULT]]:
+// LLVM:   call void @action2()
+// LLVM:   br label %[[BB9:.*]]
+// LLVM: [[BB9]]:
+// LLVM:   br label %[[LABEL]]
+// LLVM: [[BREAK]]:
+// LLVM:   br label %[[RET:.*]]
+// LLVM: [[RET]]:
+// LLVM:   ret void
+
 // OGCG: define dso_local void @default_follow_label
 // OGCG: sw.bb:
 // OGCG:   call void @action1()
diff --git a/clang/test/CIR/CodeGen/label.c b/clang/test/CIR/CodeGen/label.c
index 797c44475a621..a050094de678b 100644
--- a/clang/test/CIR/CodeGen/label.c
+++ b/clang/test/CIR/CodeGen/label.c
@@ -1,5 +1,7 @@
 // RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir
 // RUN: FileCheck --input-file=%t.cir %s --check-prefix=CIR
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o %t-cir.ll
+// RUN: FileCheck --input-file=%t-cir.ll %s --check-prefix=LLVM
 // RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -emit-llvm %s -o %t.ll
 // RUN: FileCheck --input-file=%t.ll %s --check-prefix=OGCG
 
@@ -12,8 +14,8 @@ void label() {
 // CIR:    cir.label "labelA"
 // CIR:    cir.return
 
-// Note: We are not lowering to LLVM IR via CIR at this stage because that
-// process depends on the GotoSolver.
+// LLVM:define dso_local void @label
+// LLVM:  ret void
 
 // OGCG: define dso_local void @label
 // OGCG:   br label %labelA
@@ -33,6 +35,11 @@ void multiple_labels() {
 // CIR:    cir.label "labelC"
 // CIR:    cir.return
 
+// LLVM: define dso_local void @multiple_labels()
+// LLVM:   br label %1
+// LLVM: 1:
+// LLVM:   ret void
+
 // OGCG: define dso_local void @multiple_labels
 // OGCG:   br label %labelB
 // OGCG: labelB:
@@ -56,6 +63,22 @@ void label_in_if(int cond) {
 // CIR:      }
 // CIR:    cir.return
 
+// LLVM: define dso_local void @label_in_if
+// LLVM:   br label %3
+// LLVM: 3:
+// LLVM:   [[LOAD:%.*]] = load i32, ptr [[COND:%.*]], align 4
+// LLVM:   [[CMP:%.*]] = icmp ne i32 [[LOAD]], 0
+// LLVM:   br i1 [[CMP]], label %6, label %9
+// LLVM: 6:
+// LLVM:   [[LOAD2:%.*]] = load i32, ptr [[COND]], align 4
+// LLVM:   [[ADD1:%.*]] = add nsw i32 [[LOAD2]], 1
+// LLVM:   store i32 [[ADD1]], ptr [[COND]], align 4
+// LLVM:   br label %9
+// LLVM: 9:
+// LLVM:   br label %10
+// LLVM: 10:
+// LLVM:  ret void
+
 // OGCG: define dso_local void @label_in_if
 // OGCG: if.then:
 // OGCG:   br label %labelD
@@ -80,6 +103,13 @@ void after_return() {
 // CIR:    cir.label "label"
 // CIR:    cir.br ^bb1
 
+// LLVM: define dso_local void @after_return
+// LLVM:   br label %1
+// LLVM: 1:
+// LLVM:   ret void
+// LLVM: 2:
+// LLVM:   br label %1
+
 // OGCG: define dso_local void @after_return
 // OGCG:   br label %label
 // OGCG: label:
@@ -97,6 +127,11 @@ void after_unreachable() {
 // CIR:    cir.label "label"
 // CIR:    cir.return
 
+// LLVM: define dso_local void @after_unreachable
+// LLVM:   unreachable
+// LLVM: 1:
+// LLVM:   ret void
+
 // OGCG: define dso_local void @after_unreachable
 // OGCG:   unreachable
 // OGCG: label:
@@ -111,6 +146,9 @@ void labelWithoutMatch() {
 // CIR:    cir.return
 // CIR:  }
 
+// LLVM: define dso_local void @labelWithoutMatch
+// LLVM:   ret void
+
 // OGCG: define dso_local void @labelWithoutMatch
 // OGCG:   br label %end
 // OGCG: end:
@@ -132,6 +170,15 @@ void foo() {
 // CIR:     cir.label "label"
 // CIR:     %0 = cir.alloca !rec_S, !cir.ptr<!rec_S>, ["agg.tmp0"]
 
+// LLVM:define dso_local void @foo() {
+// LLVM:  [[ALLOC:%.*]] = alloca %struct.S, i64 1, align 1
+// LLVM:  br label %2
+// LLVM:2:
+// LLVM:  [[CALL:%.*]] = call %struct.S @get()
+// LLVM:  store %struct.S [[CALL]], ptr [[ALLOC]], align 1
+// LLVM:  [[LOAD:%.*]] = load %struct.S, ptr [[ALLOC]], align 1
+// LLVM:  call void @bar(%struct.S [[LOAD]])
+
 // OGCG: define dso_local void @foo()
 // OGCG:   %agg.tmp = alloca %struct.S, align 1
 // OGCG:   %undef.agg.tmp = alloca %struct.S, align 1
diff --git a/clang/test/CIR/Lowering/goto.cir b/clang/test/CIR/Lowering/goto.cir
new file mode 100644
index 0000000000000..cd3a57d2e7138
--- /dev/null
+++ b/clang/test/CIR/Lowering/goto.cir
@@ -0,0 +1,52 @@
+// RUN: cir-opt %s --pass-pipeline='builtin.module(cir-to-llvm,canonicalize{region-simplify=disabled})' -o - | FileCheck %s -check-prefix=MLIR
+
+!s32i = !cir.int<s, 32>
+
+module {
+
+  cir.func @gotoFromIf(%arg0: !s32i) -> !s32i {
+    %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["x", init] {alignment = 4 : i64}
+    %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"] {alignment = 4 : i64}
+    cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+    cir.scope {
+      %6 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+      %7 = cir.const #cir.int<5> : !s32i
+      %8 = cir.cmp(gt, %6, %7) : !s32i, !cir.bool
+      cir.if %8 {
+        cir.goto "err"
+      }
+    }
+    %2 = cir.const #cir.int<0> : !s32i
+    cir.store %2, %1 : !s32i, !cir.ptr<!s32i>
+    cir.br ^bb1
+  ^bb1:
+    %3 = cir.load %1 : !cir.ptr<!s32i>, !s32i
+    cir.return %3 : !s32i
+  ^bb2:
+    cir.label "err"
+    %4 = cir.const #cir.int<1> : !s32i
+    %5 = cir.unary(minus, %4) : !s32i, !s32i
+    cir.store %5, %1 : !s32i, !cir.ptr<!s32i>
+    cir.br ^bb1
+  }
+
+// MLIR:  llvm.func @gotoFromIf
+// MLIR:    %[[#One:]] = llvm.mlir.constant(1 : i32) : i32
+// MLIR:    %[[#Zero:]] = llvm.mlir.constant(0 : i32) : i32
+// MLIR:    llvm.cond_br {{.*}}, ^bb[[#COND_YES:]], ^bb[[#COND_NO:]]
+// MLIR:  ^bb[[#COND_YES]]:
+// MLIR:    llvm.br ^bb[[#GOTO_BLK:]]
+// MLIR:   ^bb[[#COND_NO]]:
+// MLIR:    llvm.br ^bb[[#BLK:]]
+// MLIR:  ^bb[[#BLK]]:
+// MLIR:    llvm.store %[[#Zero]], %[[#Ret_val_addr:]] {{.*}}: i32, !llvm.ptr
+// MLIR:    llvm.br ^bb[[#RETURN:]]
+// MLIR:  ^bb[[#RETURN]]:
+// MLIR:    %[[#Ret_val:]] = llvm.load %[[#Ret_val_addr]] {alignment = 4 : i64} : !llvm.ptr -> i32
+// MLIR:    llvm.return %[[#Ret_val]] : i32
+// MLIR:  ^bb[[#GOTO_BLK]]:
+// MLIR:    %[[#Neg_one:]] = llvm.sub %[[#Zero]], %[[#One]]  : i32
+// MLIR:    llvm.store %[[#Neg_one]], %[[#Ret_val_addr]] {{.*}}: i32, !llvm.ptr
+// MLIR:    llvm.br ^bb[[#RETURN]]
+// MLIR: }
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/154596
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to