https://github.com/mtrofin updated 
https://github.com/llvm/llvm-project/pull/167380

>From 1ca5b000f57eada55d9e5ae38806d1e7e574f0ec Mon Sep 17 00:00:00 2001
From: Mircea Trofin <[email protected]>
Date: Fri, 7 Nov 2025 15:21:38 -0800
Subject: [PATCH] [LTT][profcheck] Set branch weights for complex
 llvm.type.test lowering

---
 llvm/lib/Transforms/IPO/LowerTypeTests.cpp    | 49 ++++++++++++++-----
 llvm/test/Other/new-pm-O0-defaults.ll         |  1 +
 .../test/Transforms/LowerTypeTests/section.ll | 23 ++++++++-
 3 files changed, 60 insertions(+), 13 deletions(-)

diff --git a/llvm/lib/Transforms/IPO/LowerTypeTests.cpp b/llvm/lib/Transforms/IPO/LowerTypeTests.cpp
index 94663ff928a0b..31b5487ce6ec6 100644
--- a/llvm/lib/Transforms/IPO/LowerTypeTests.cpp
+++ b/llvm/lib/Transforms/IPO/LowerTypeTests.cpp
@@ -25,6 +25,7 @@
 #include "llvm/ADT/Statistic.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/TinyPtrVector.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/PostDominators.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
@@ -54,6 +55,7 @@
 #include "llvm/IR/ModuleSummaryIndexYAML.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/ReplaceConstant.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Use.h"
@@ -95,6 +97,7 @@ STATISTIC(NumByteArraysCreated, "Number of byte arrays created");
 STATISTIC(NumTypeTestCallsLowered, "Number of type test calls lowered");
 STATISTIC(NumTypeIdDisjointSets, "Number of disjoint sets of type identifiers");
 
+namespace llvm {
 static cl::opt<bool> AvoidReuse(
     "lowertypetests-avoid-reuse",
     cl::desc("Try to avoid reuse of byte array addresses using aliases"),
@@ -131,6 +134,9 @@ static cl::opt<DropTestKind>
                                           "Drop all type test sequences")),
                     cl::Hidden, cl::init(DropTestKind::None));
 
+extern cl::opt<bool> ProfcheckDisableMetadataFixes;
+} // namespace llvm
+
 bool BitSetInfo::containsGlobalOffset(uint64_t Offset) const {
   if (Offset < ByteOffset)
     return false;
@@ -423,8 +429,10 @@ struct ScopedSaveAliaseesAndUsed {
 class LowerTypeTestsModule {
   Module &M;
 
-  ModuleSummaryIndex *ExportSummary;
-  const ModuleSummaryIndex *ImportSummary;
+  FunctionAnalysisManager &FAM;
+
+  ModuleSummaryIndex *const ExportSummary;
+  const ModuleSummaryIndex *const ImportSummary;
   // Set when the client has invoked this to simply drop all type test assume
   // sequences.
   DropTestKind DropTypeTests;
@@ -507,9 +515,10 @@ class LowerTypeTestsModule {
   void allocateByteArrays();
   Value *createBitSetTest(IRBuilder<> &B, const TypeIdLowering &TIL,
                           Value *BitOffset);
-  void lowerTypeTestCalls(
-      ArrayRef<Metadata *> TypeIds, Constant *CombinedGlobalAddr,
-      const DenseMap<GlobalTypeMember *, uint64_t> &GlobalLayout);
+  void
+  lowerTypeTestCalls(ArrayRef<Metadata *> TypeIds, Constant *CombinedGlobalAddr,
+                     const DenseMap<GlobalTypeMember *, uint64_t> &GlobalLayout,
+                     uint64_t *TotalCallCount = nullptr);
   Value *lowerTypeTestCall(Metadata *TypeId, CallInst *CI,
                            const TypeIdLowering &TIL);
 
@@ -803,6 +812,8 @@ Value *LowerTypeTestsModule::lowerTypeTestCall(Metadata *TypeId, CallInst *CI,
       }
 
   IRBuilder<> ThenB(SplitBlockAndInsertIfThen(OffsetInRange, CI, false));
+  setExplicitlyUnknownBranchWeightsIfProfiled(*InitialBB->getTerminator(),
+                                              DEBUG_TYPE);
 
   // Now that we know that the offset is in range and aligned, load the
   // appropriate bit from the bitset.
@@ -1181,7 +1192,8 @@ buildBitSets(ArrayRef<Metadata *> TypeIds,
 
 void LowerTypeTestsModule::lowerTypeTestCalls(
     ArrayRef<Metadata *> TypeIds, Constant *CombinedGlobalAddr,
-    const DenseMap<GlobalTypeMember *, uint64_t> &GlobalLayout) {
+    const DenseMap<GlobalTypeMember *, uint64_t> &GlobalLayout,
+    uint64_t *TotalCallCount) {
   // For each type identifier in this disjoint set...
   for (const auto &[TypeId, BSI] : buildBitSets(TypeIds, GlobalLayout)) {
     ByteArrayInfo *BAI = nullptr;
@@ -1227,6 +1239,18 @@ void LowerTypeTestsModule::lowerTypeTestCalls(
       ++NumTypeTestCallsLowered;
       Value *Lowered = lowerTypeTestCall(TypeId, CI, TIL);
       if (Lowered) {
+        if (TotalCallCount) {
+          auto *CIF = CI->getFunction();
+          if (auto EC = CIF->getEntryCount())
+            if (EC->getCount()) {
+              auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*CIF);
+              *TotalCallCount +=
+                  EC->getCount() *
+                  static_cast<double>(
+                      BFI.getBlockFreq(CI->getParent()).getFrequency()) /
+                  BFI.getEntryFreq().getFrequency();
+            }
+        }
         CI->replaceAllUsesWith(Lowered);
         CI->eraseFromParent();
       }
@@ -1702,10 +1726,13 @@ void LowerTypeTestsModule::buildBitSetsFromFunctionsNative(
   ArrayType *JumpTableEntryType = ArrayType::get(Int8Ty, EntrySize);
   ArrayType *JumpTableType =
       ArrayType::get(JumpTableEntryType, Functions.size());
-  auto JumpTable = ConstantExpr::getPointerCast(
+  auto *JumpTable = ConstantExpr::getPointerCast(
       JumpTableFn, PointerType::getUnqual(M.getContext()));
 
-  lowerTypeTestCalls(TypeIds, JumpTable, GlobalLayout);
+  uint64_t Count = 0;
+  lowerTypeTestCalls(TypeIds, JumpTable, GlobalLayout, &Count);
+  if (!ProfcheckDisableMetadataFixes && Count)
+    JumpTableFn->setEntryCount(Count);
 
   // Build aliases pointing to offsets into the jump table, and replace
   // references to the original functions with references to the aliases.
@@ -1870,7 +1897,9 @@ void LowerTypeTestsModule::buildBitSetsFromDisjointSet(
 LowerTypeTestsModule::LowerTypeTestsModule(
     Module &M, ModuleAnalysisManager &AM, ModuleSummaryIndex *ExportSummary,
     const ModuleSummaryIndex *ImportSummary, DropTestKind DropTypeTests)
-    : M(M), ExportSummary(ExportSummary), ImportSummary(ImportSummary),
+    : M(M),
+      FAM(AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager()),
+      ExportSummary(ExportSummary), ImportSummary(ImportSummary),
       DropTypeTests(ClDropTypeTests > DropTypeTests ? ClDropTypeTests
                                                     : DropTypeTests) {
   assert(!(ExportSummary && ImportSummary));
@@ -1879,8 +1908,6 @@ LowerTypeTestsModule::LowerTypeTestsModule(
   if (Arch == Triple::arm)
     CanUseArmJumpTable = true;
   if (Arch == Triple::arm || Arch == Triple::thumb) {
-    auto &FAM =
-        AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
     for (Function &F : M) {
       // Skip declarations since we should not query the TTI for them.
       if (F.isDeclaration())
diff --git a/llvm/test/Other/new-pm-O0-defaults.ll b/llvm/test/Other/new-pm-O0-defaults.ll
index 278a89261691a..4e2b5d4d62a78 100644
--- a/llvm/test/Other/new-pm-O0-defaults.ll
+++ b/llvm/test/Other/new-pm-O0-defaults.ll
@@ -44,6 +44,7 @@
 ; CHECK-PRE-LINK: Running pass: CanonicalizeAliasesPass
 ; CHECK-PRE-LINK-NEXT: Running pass: NameAnonGlobalPass
 ; CHECK-THINLTO: Running pass: LowerTypeTestsPass
+; CHECK-THINLTO: Running analysis: InnerAnalysisManagerProxy<FunctionAnalysisManager, Module> on [module]
 ; CHECK-THINLTO-NEXT: Running pass: CoroConditionalWrapper
 ; CHECK-THINLTO-NEXT: Running pass: EliminateAvailableExternallyPass
 ; CHECK-THINLTO-NEXT: Running pass: GlobalDCEPass
diff --git a/llvm/test/Transforms/LowerTypeTests/section.ll b/llvm/test/Transforms/LowerTypeTests/section.ll
index bd91389c60ef0..1b0efd5bdd01d 100644
--- a/llvm/test/Transforms/LowerTypeTests/section.ll
+++ b/llvm/test/Transforms/LowerTypeTests/section.ll
@@ -13,14 +13,33 @@ entry:
   ret void
 }
 
-define i1 @g() {
+define i1 @g() !prof !1 {
 entry:
   %0 = call i1 @llvm.type.test(ptr @f, metadata !"_ZTSFvE")
   ret i1 %0
 }
 
-; CHECK: define private void @[[JT]]() #{{.*}} align {{.*}} {
+define i1 @h(i1 %c) !prof !2 {
+entry:
+  br i1 %c, label %yes, label %common, !prof !3
+
+yes:
+  %0 = call i1 @llvm.type.test(ptr @f, metadata !"_ZTSFvE")
+  ret i1 %0
+
+common:
+  ret i1 0
+}
+
+; CHECK: define private void @[[JT]]() #{{.*}} align {{.*}} !prof ![[PROF:[0-9]+]] {
 
 declare i1 @llvm.type.test(ptr, metadata) nounwind readnone
 
 !0 = !{i64 0, !"_ZTSFvE"}
+!1 = !{!"function_entry_count", i32 20}
+!2 = !{!"function_entry_count", i32 40}
+!3 = !{!"branch_weights", i32 3, i32 5}
+; The jump table function's entry count is 20 + 40 * (3/8) = 20 + 15 = 35,
+; where 20 is the entry count of @g, 40 is that of @h, and 3/8 is the frequency
+; of the llvm.type.test call in @h relative to @h's entry block.
+; CHECK: ![[PROF]] = !{!"function_entry_count", i64 35}

_______________________________________________
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to