https://github.com/mtrofin updated 
https://github.com/llvm/llvm-project/pull/154841

>From 24507f71fa438e7e799560aa1d6f370a3a8a7c43 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtro...@google.com>
Date: Thu, 21 Aug 2025 13:54:49 -0700
Subject: [PATCH] [SimplifyCFG] Set branch weights when merging conditional
 store to address

---
 llvm/include/llvm/IR/ProfDataUtils.h          | 27 +++++++++++++++++++
 llvm/lib/Transforms/Utils/SimplifyCFG.cpp     | 16 +++++++++++
 .../SimplifyCFG/merge-cond-stores.ll          | 14 +++++++---
 3 files changed, 53 insertions(+), 4 deletions(-)

diff --git a/llvm/include/llvm/IR/ProfDataUtils.h 
b/llvm/include/llvm/IR/ProfDataUtils.h
index 404875285beae..ebf8559cd3d91 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -15,6 +15,7 @@
 #ifndef LLVM_IR_PROFDATAUTILS_H
 #define LLVM_IR_PROFDATAUTILS_H
 
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/IR/Metadata.h"
@@ -186,5 +187,31 @@ LLVM_ABI bool hasExplicitlyUnknownBranchWeights(const 
Instruction &I);
 /// Scaling the profile data attached to 'I' using the ratio of S/T.
 LLVM_ABI void scaleProfData(Instruction &I, uint64_t S, uint64_t T);
 
+/// Get the branch weights of a branch conditioned on b1 || b2, where b1 and b2
+/// are the conditions of 2 branches whose {true, false} weights are B1 and B2,
+/// respectively. b1 and b2 are treated as independent events.
+inline SmallVector<uint64_t, 2>
+getDisjunctionWeights(const SmallVector<uint32_t, 2> &B1,
+                      const SmallVector<uint32_t, 2> &B2) {
+  // p(b1) = B1[0] / (B1[0] + B1[1]) and p(not b1) = B1[1] / (B1[0] + B1[1]);
+  // likewise for b2 and B2. The new branch is NOT taken with probability
+  //   p(not b1 and not b2) = B1[1] * B2[1] / ((B1[0]+B1[1]) * (B2[0]+B2[1])),
+  // so over the same denominator its "true" numerator is the product of sums
+  // minus B1[1] * B2[1], i.e. B1[0] * (B2[0] + B2[1]) + B1[1] * B2[0].
+  assert(B1.size() == 2);
+  assert(B2.size() == 2);
+  // Compute in 64 bits (uint32_t * uint32_t would wrap). A pair whose sum
+  // needs more than 32 bits is halved (ratio kept, up to rounding) so that the
+  // product of the two sums, and hence every result below, fits in 64 bits.
+  const auto Fit = [](uint64_t &T, uint64_t &F) {
+    const unsigned Shift = ((T + F) >> 32) ? 1 : 0;
+    T >>= Shift;
+    F >>= Shift;
+  };
+  uint64_t T1 = B1[0], F1 = B1[1], T2 = B2[0], F2 = B2[1];
+  Fit(T1, F1);
+  Fit(T2, F2);
+  return {T1 * (T2 + F2) + F1 * T2, F1 * F2};
+}
 } // namespace llvm
 #endif
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp 
b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 270598e2b674b..370b282d1b14d 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -203,6 +203,8 @@ static cl::opt<unsigned> MaxJumpThreadingLiveBlocks(
     cl::desc("Limit number of blocks a define in a threaded block is allowed "
              "to be live in"));
 
+extern cl::opt<bool> ProfcheckDisableMetadataFixes;
+
 STATISTIC(NumBitMaps, "Number of switch instructions turned into bitmaps");
 STATISTIC(NumLinearMaps,
           "Number of switch instructions turned into linear mapping");
@@ -4431,6 +4433,20 @@ static bool mergeConditionalStoreToAddress(
   auto *T = SplitBlockAndInsertIfThen(CombinedPred, InsertPt,
                                       /*Unreachable=*/false,
                                       /*BranchWeights=*/nullptr, DTU);
+  if (hasBranchWeightMD(*PBranch) && hasBranchWeightMD(*QBranch) &&
+      !ProfcheckDisableMetadataFixes) {
+    SmallVector<uint32_t, 2> PWeights, QWeights;
+    extractBranchWeights(*PBranch, PWeights);
+    extractBranchWeights(*QBranch, QWeights);
+    if (InvertPCond)
+      std::swap(PWeights[0], PWeights[1]);
+    if (InvertQCond)
+      std::swap(QWeights[0], QWeights[1]);
+    auto CombinedWeights = getDisjunctionWeights(PWeights, QWeights);
+    setBranchWeights(PostBB->getTerminator(), CombinedWeights[0],
+                     CombinedWeights[1],
+                     /*IsExpected=*/false);
+  }
 
   QB.SetInsertPoint(T);
   StoreInst *SI = cast<StoreInst>(QB.CreateStore(QPHI, Address));
diff --git a/llvm/test/Transforms/SimplifyCFG/merge-cond-stores.ll 
b/llvm/test/Transforms/SimplifyCFG/merge-cond-stores.ll
index b5c4b8aa51db4..ee723463d4b06 100644
--- a/llvm/test/Transforms/SimplifyCFG/merge-cond-stores.ll
+++ b/llvm/test/Transforms/SimplifyCFG/merge-cond-stores.ll
@@ -1,4 +1,4 @@
-; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py 
UTC_ARGS: --check-globals
 ; RUN: opt -passes=simplifycfg,instcombine 
-simplifycfg-require-and-preserve-domtree=1 < %s 
-simplifycfg-merge-cond-stores=true 
-simplifycfg-merge-cond-stores-aggressively=false -phi-node-folding-threshold=2 
-S | FileCheck %s
 
 ; This test should succeed and end up if-converted.
@@ -43,7 +43,7 @@ define void @test_simple_commuted(ptr %p, i32 %a, i32 %b) {
 ; CHECK-NEXT:    [[X1_NOT:%.*]] = icmp eq i32 [[A:%.*]], 0
 ; CHECK-NEXT:    [[X2:%.*]] = icmp eq i32 [[B:%.*]], 0
 ; CHECK-NEXT:    [[TMP0:%.*]] = or i1 [[X1_NOT]], [[X2]]
-; CHECK-NEXT:    br i1 [[TMP0]], label [[TMP1:%.*]], label [[TMP2:%.*]]
+; CHECK-NEXT:    br i1 [[TMP0]], label [[TMP1:%.*]], label [[TMP2:%.*]], !prof 
[[PROF0:![0-9]+]]
 ; CHECK:       1:
 ; CHECK-NEXT:    [[SPEC_SELECT:%.*]] = zext i1 [[X2]] to i32
 ; CHECK-NEXT:    store i32 [[SPEC_SELECT]], ptr [[P:%.*]], align 4
@@ -53,7 +53,7 @@ define void @test_simple_commuted(ptr %p, i32 %a, i32 %b) {
 ;
 entry:
   %x1 = icmp eq i32 %a, 0
-  br i1 %x1, label %yes1, label %fallthrough
+  br i1 %x1, label %yes1, label %fallthrough, !prof !0
 
 yes1:
   store i32 0, ptr %p
@@ -61,7 +61,7 @@ yes1:
 
 fallthrough:
   %x2 = icmp eq i32 %b, 0
-  br i1 %x2, label %yes2, label %end
+  br i1 %x2, label %yes2, label %end, !prof !1
 
 yes2:
   store i32 1, ptr %p
@@ -406,3 +406,9 @@ yes2:
 end:
   ret void
 }
+
+!0 = !{!"branch_weights", i32 7, i32 13}
+!1 = !{!"branch_weights", i32 3, i32 11}
+;.
+; CHECK: [[PROF0]] = !{!"branch_weights", i32 259, i32 21}
+;.

_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to