https://github.com/mtrofin created https://github.com/llvm/llvm-project/pull/154841
None

>From f4441cbe5e38f6abc76604a8049f6e36fb4881a7 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtro...@google.com>
Date: Thu, 21 Aug 2025 13:54:49 -0700
Subject: [PATCH] [SimplifyCFG] Set branch weights when merging conditional store to address

---
 llvm/include/llvm/IR/ProfDataUtils.h      | 35 +++++++++++++
 llvm/lib/Transforms/Utils/SimplifyCFG.cpp | 61 +++++++++++++++++------
 2 files changed, 80 insertions(+), 16 deletions(-)

diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index 404875285beae..c9284c1bc8dde 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -15,6 +15,7 @@
 #ifndef LLVM_IR_PROFDATAUTILS_H
 #define LLVM_IR_PROFDATAUTILS_H
 
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/IR/Metadata.h"
@@ -186,5 +187,39 @@ LLVM_ABI bool hasExplicitlyUnknownBranchWeights(const Instruction &I);
 /// Scaling the profile data attached to 'I' using the ratio of S/T.
 LLVM_ABI void scaleProfData(Instruction &I, uint64_t S, uint64_t T);
 
+/// Get the branch weights of a branch conditioned on b1 || b2, where b1 and b2
+/// are the conditions of 2 branches for which we have the branch weights B1
+/// and B2, respectively. In B1, B2 and the result, index 0 holds the weight of
+/// the condition being true and index 1 the weight of it being false. The
+/// result assumes b1 and b2 are independent. It is not normalized to fit in
+/// 32 bits, so callers may need to scale it down before emitting metadata.
+inline SmallVector<uint64_t, 2>
+getDisjunctionWeights(const SmallVector<uint32_t, 2> &B1,
+                      const SmallVector<uint32_t, 2> &B2) {
+  // The probability of the new branch not being taken is:
+  //   not P = p((not b1) and (not b2)) = p(not b1) * p(not b2) =
+  //         = B1[1] * B2[1] / ((B1[0] + B1[1]) * (B2[0] + B2[1]))
+  // and P = 1 - (not P). Both share the denominator, which we don't need, so
+  // the numerator of P is (B1[0] + B1[1]) * (B2[0] + B2[1]) - B1[1] * B2[1],
+  // i.e. B1[0] * (B2[0] + B2[1]) + B1[1] * B2[0].
+  assert(B1.size() == 2);
+  assert(B2.size() == 2);
+  // Compute in 64 bits: a product of two 32-bit weights needs up to 64 bits.
+  // If a branch's weights sum past 32 bits, halve them (roughly preserving
+  // their ratio) so the weights below, which add up to the product of the two
+  // sums, cannot overflow.
+  uint64_t T1 = B1[0], F1 = B1[1], T2 = B2[0], F2 = B2[1];
+  if (T1 + F1 > UINT32_MAX) {
+    T1 >>= 1;
+    F1 >>= 1;
+  }
+  if (T2 + F2 > UINT32_MAX) {
+    T2 >>= 1;
+    F2 >>= 1;
+  }
+  uint64_t FalseWeight = F1 * F2;
+  uint64_t TrueWeight = T1 * (T2 + F2) + F1 * T2;
+  return {TrueWeight, FalseWeight};
+}
 } // namespace llvm
 #endif
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 4847add386dc4..e26a189564d13 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -1182,7 +1182,7 @@ static void cloneInstructionsIntoPredecessorBlockAndUpdateSSAUses(
     // only given the branch precondition.
     // Similarly strip attributes on call parameters that may cause UB in
     // location the call is moved to.
-    NewBonusInst->dropUBImplyingAttrsAndMetadata();
+    NewBonusInst->dropUBImplyingAttrsAndMetadata({LLVMContext::MD_prof});
 
     NewBonusInst->insertInto(PredBlock, PTI->getIterator());
     auto Range = NewBonusInst->cloneDebugInfoFrom(&BonusInst);
@@ -1808,7 +1808,8 @@ static void hoistConditionalLoadsStores(
     // !annotation: Not impact semantics. Keep it.
     if (const MDNode *Ranges = I->getMetadata(LLVMContext::MD_range))
       MaskedLoadStore->addRangeRetAttr(getConstantRangeFromMetadata(*Ranges));
-    I->dropUBImplyingAttrsAndUnknownMetadata({LLVMContext::MD_annotation});
+    I->dropUBImplyingAttrsAndUnknownMetadata(
+        {LLVMContext::MD_annotation, LLVMContext::MD_prof});
     // FIXME: DIAssignID is not supported for masked store yet.
     // (Verifier::visitDIAssignIDMetadata)
     at::deleteAssignmentMarkers(I);
@@ -3366,7 +3367,7 @@ bool SimplifyCFGOpt::speculativelyExecuteBB(BranchInst *BI,
     if (!SpeculatedStoreValue || &I != SpeculatedStore) {
       I.setDebugLoc(DebugLoc::getDropped());
     }
-    I.dropUBImplyingAttrsAndMetadata();
+    I.dropUBImplyingAttrsAndMetadata({LLVMContext::MD_prof});
 
     // Drop ephemeral values.
     if (EphTracker.contains(&I)) {
@@ -4404,10 +4405,12 @@ static bool mergeConditionalStoreToAddress(
 
   // OK, we're going to sink the stores to PostBB. The store has to be
   // conditional though, so first create the predicate.
-  Value *PCond = cast<BranchInst>(PFB->getSinglePredecessor()->getTerminator())
-                     ->getCondition();
-  Value *QCond = cast<BranchInst>(QFB->getSinglePredecessor()->getTerminator())
-                     ->getCondition();
+  BranchInst *const PBranch =
+      cast<BranchInst>(PFB->getSinglePredecessor()->getTerminator());
+  BranchInst *const QBranch =
+      cast<BranchInst>(QFB->getSinglePredecessor()->getTerminator());
+  Value *const PCond = PBranch->getCondition();
+  Value *const QCond = QBranch->getCondition();
 
   Value *PPHI = ensureValueAvailableInSuccessor(PStore->getValueOperand(),
                                                 PStore->getParent());
@@ -4418,19 +4421,45 @@ static bool mergeConditionalStoreToAddress(
   IRBuilder<> QB(PostBB, PostBBFirst);
   QB.SetCurrentDebugLocation(PostBBFirst->getStableDebugLoc());
 
-  Value *PPred = PStore->getParent() == PTB ? PCond : QB.CreateNot(PCond);
-  Value *QPred = QStore->getParent() == QTB ? QCond : QB.CreateNot(QCond);
+  // From here on, InvertXCond is true iff the store in X executes when XCond
+  // is false, i.e. that store's predicate is (not XCond).
+  InvertPCond ^= (PStore->getParent() != PTB);
+  InvertQCond ^= (QStore->getParent() != QTB);
+  Value *const PPred = InvertPCond ? QB.CreateNot(PCond) : PCond;
+  Value *const QPred = InvertQCond ? QB.CreateNot(QCond) : QCond;
 
-  if (InvertPCond)
-    PPred = QB.CreateNot(PPred);
-  if (InvertQCond)
-    QPred = QB.CreateNot(QPred);
   Value *CombinedPred = QB.CreateOr(PPred, QPred);
 
+  // Derive the profile of the new branch on CombinedPred from the profiles of
+  // the two original branches, each oriented so that index 0 is the weight of
+  // its store predicate (PPred / QPred) being true.
+  MDNode *CombinedWeightsMD = nullptr;
+  SmallVector<uint32_t, 2> PWeights, QWeights;
+  if (extractBranchWeights(*PBranch, PWeights) &&
+      extractBranchWeights(*QBranch, QWeights)) {
+    if (InvertPCond)
+      std::swap(PWeights[0], PWeights[1]);
+    if (InvertQCond)
+      std::swap(QWeights[0], QWeights[1]);
+    SmallVector<uint64_t, 2> Combined =
+        getDisjunctionWeights(PWeights, QWeights);
+    // Products of 32-bit weights may not fit in 32 bits; halve both weights
+    // (preserving their ratio) until they fit in branch_weights metadata.
+    while (std::max(Combined[0], Combined[1]) > UINT32_MAX) {
+      Combined[0] >>= 1;
+      Combined[1] >>= 1;
+    }
+    if (Combined[0] || Combined[1])
+      CombinedWeightsMD = MDBuilder(PostBB->getContext())
+                              .createBranchWeights(Combined[0], Combined[1]);
+  }
+
   BasicBlock::iterator InsertPt = QB.GetInsertPoint();
-  auto *T = SplitBlockAndInsertIfThen(CombinedPred, InsertPt,
-                                      /*Unreachable=*/false,
-                                      /*BranchWeights=*/nullptr, DTU);
+  // The weights belong on the conditional branch that now terminates PostBB,
+  // not on T, which is the unconditional terminator of the new "then" block.
+  auto *T = SplitBlockAndInsertIfThen(
+      CombinedPred, InsertPt, /*Unreachable=*/false,
+      /*BranchWeights=*/CombinedWeightsMD, DTU);
 
   QB.SetInsertPoint(T);
   StoreInst *SI = cast<StoreInst>(QB.CreateStore(QPHI, Address));

_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits