================
@@ -592,72 +633,276 @@ size_t YAMLProfileReader::matchWithCallGraph(BinaryContext &BC) {
   return MatchedWithCallGraph;
 }
 
-size_t YAMLProfileReader::InlineTreeNodeMapTy::matchInlineTrees(
-    const MCPseudoProbeDecoder &Decoder,
-    const std::vector<yaml::bolt::InlineTreeNode> &DecodedInlineTree,
-    const MCDecodedPseudoProbeInlineTree *Root) {
-  // Match inline tree nodes by GUID, checksum, parent, and call site.
-  for (const auto &[InlineTreeNodeId, InlineTreeNode] :
-       llvm::enumerate(DecodedInlineTree)) {
-    uint64_t GUID = InlineTreeNode.GUID;
-    uint64_t Hash = InlineTreeNode.Hash;
-    uint32_t ParentId = InlineTreeNode.ParentIndexDelta;
-    uint32_t CallSiteProbe = InlineTreeNode.CallSiteProbe;
-    const MCDecodedPseudoProbeInlineTree *Cur = nullptr;
-    if (!InlineTreeNodeId) {
-      Cur = Root;
-    } else if (const MCDecodedPseudoProbeInlineTree *Parent =
-                   getInlineTreeNode(ParentId)) {
-      for (const MCDecodedPseudoProbeInlineTree &Child :
-           Parent->getChildren()) {
-        if (Child.Guid == GUID) {
-          if (std::get<1>(Child.getInlineSite()) == CallSiteProbe)
-            Cur = &Child;
-          break;
-        }
+const MCDecodedPseudoProbeInlineTree *
+YAMLProfileReader::lookupTopLevelNode(const BinaryFunction &BF) {
+  const BinaryContext &BC = BF.getBinaryContext();
+  const MCPseudoProbeDecoder *Decoder = BC.getPseudoProbeDecoder();
+  assert(Decoder &&
+         "If pseudo probes are in use, pseudo probe decoder should exist");
+  uint64_t Addr = BF.getAddress();
+  uint64_t Size = BF.getSize();
+  auto Probes = Decoder->getAddress2ProbesMap().find(Addr, Addr + Size);
+  if (Probes.empty())
+    return nullptr;
+  const MCDecodedPseudoProbe &Probe = *Probes.begin();
+  const MCDecodedPseudoProbeInlineTree *Root = Probe.getInlineTreeNode();
+  while (Root->hasInlineSite())
+    Root = (const MCDecodedPseudoProbeInlineTree *)Root->Parent;
+  return Root;
+}
+
+size_t YAMLProfileReader::matchInlineTreesImpl(
+    BinaryFunction &BF, yaml::bolt::BinaryFunctionProfile &YamlBF,
+    const MCDecodedPseudoProbeInlineTree &Root, uint32_t RootIdx,
+    ArrayRef<yaml::bolt::InlineTreeNode> ProfileInlineTree,
+    MutableArrayRef<const MCDecodedPseudoProbeInlineTree *> Map, float Scale) {
+  using namespace yaml::bolt;
+  BinaryContext &BC = BF.getBinaryContext();
+  const MCPseudoProbeDecoder &Decoder = *BC.getPseudoProbeDecoder();
+  const InlineTreeNode &FuncNode = ProfileInlineTree[RootIdx];
+
+  using ChildMapTy =
+      std::unordered_map<InlineSite, const MCDecodedPseudoProbeInlineTree *,
+                         InlineSiteHash>;
+  using CallSiteInfoTy =
+      std::unordered_map<InlineSite, const CallSiteInfo *, InlineSiteHash>;
+  // Mapping from a parent node id to a map InlineSite -> Child node.
+  DenseMap<uint32_t, ChildMapTy> ParentToChildren;
+  // Collect calls in the profile: map from a parent node id to a map
+  // InlineSite -> CallSiteInfo ptr.
+  DenseMap<uint32_t, CallSiteInfoTy> ParentToCSI;
+  for (const BinaryBasicBlockProfile &YamlBB : YamlBF.Blocks) {
+    // Collect callees for inlined profile matching, indexed by InlineSite.
+    for (const CallSiteInfo &CSI : YamlBB.CallSites) {
+      ProbeMatchingStats.TotalCallCount += CSI.Count;
+      ++ProbeMatchingStats.TotalCallSites;
+      if (CSI.Probe == 0) {
+        LLVM_DEBUG(dbgs() << "no probe for " << CSI.DestId << " " << CSI.Count
+                          << '\n');
+        ++ProbeMatchingStats.MissingCallProbe;
+        ProbeMatchingStats.MissingCallCount += CSI.Count;
+        continue;
+      }
+      const BinaryFunctionProfile *Callee = IdToYamLBF.lookup(CSI.DestId);
+      if (!Callee) {
+        LLVM_DEBUG(dbgs() << "no callee for " << CSI.DestId << " " << CSI.Count
+                          << '\n');
+        ++ProbeMatchingStats.MissingCallee;
+        ProbeMatchingStats.MissingCallCount += CSI.Count;
+        continue;
+      }
+      // Get callee GUID
+      if (Callee->InlineTree.empty()) {
+        LLVM_DEBUG(dbgs() << "no inline tree for " << Callee->Name << '\n');
+        ++ProbeMatchingStats.MissingInlineTree;
+        ProbeMatchingStats.MissingCallCount += CSI.Count;
+        continue;
+      }
+      uint64_t CalleeGUID = Callee->InlineTree.front().GUID;
+      ParentToCSI[CSI.InlineTreeNode][InlineSite(CalleeGUID, CSI.Probe)] = &CSI;
+    }
+  }
+  LLVM_DEBUG({
+    for (auto &[ParentId, InlineSiteCSI] : ParentToCSI) {
+      for (auto &[InlineSite, CSI] : InlineSiteCSI) {
+        auto [CalleeGUID, CallSite] = InlineSite;
+        errs() << ParentId << "@" << CallSite << "->"
+               << Twine::utohexstr(CalleeGUID) << ": " << CSI->Count << ", "
+               << Twine::utohexstr(CSI->Offset) << '\n';
+      }
+    }
+  });
+
+  assert(!Root.isRoot());
+  LLVM_DEBUG(dbgs() << "matchInlineTreesImpl for " << BF << "@"
+                    << Twine::utohexstr(Root.Guid) << " and " << YamlBF.Name
+                    << "@" << Twine::utohexstr(FuncNode.GUID) << '\n');
+  ++ProbeMatchingStats.AttemptedNodes;
+  ++ProbeMatchingStats.AttemptedRoots;
+
+  // Match profile function with a lead node (top-level function or inlinee)
+  if (Root.Guid != FuncNode.GUID) {
+    LLVM_DEBUG(dbgs() << "Mismatching root GUID\n");
+    ++ProbeMatchingStats.MismatchingRootGUID;
+    return 0;
+  }
+  {
+    uint64_t BinaryHash = Decoder.getFuncDescForGUID(Root.Guid)->FuncHash;
+    uint64_t ProfileHash = FuncNode.Hash;
----------------
maksfb wrote:

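nit: constify (presumably the `BinaryHash` and `ProfileHash` locals quoted directly above, i.e. declare them `const uint64_t`).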

https://github.com/llvm/llvm-project/pull/100446
_______________________________________________
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to