================
@@ -36,6 +47,146 @@ static std::string normalizeForBundler(const llvm::Triple &T,
                      : T.normalize();
 }
 
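+// Collect the names of undefined __hip_fatbin* and __hip_gpubin_handle*
+// symbols from all input object or archive files. For inputs that carry an
+// ID, the names are instead derived from an MD5 hash of that ID.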
+class HIPUndefinedFatBinSymbols {
+public:
+  // Scans the inputs of \p C as soon as the object is constructed.
+  // DiagID is a custom error diagnostic for failures while collecting,
+  // Quiet records whether -### was given, and Verbose whether -v was given;
+  // with -v every collected symbol name is printed to stderr.
+  HIPUndefinedFatBinSymbols(const Compilation &C)
+      : C(C), DiagID(C.getDriver().getDiags().getCustomDiagID(
+                  DiagnosticsEngine::Error,
+                  "Error collecting HIP undefined fatbin symbols: %0")),
+        Quiet(C.getArgs().hasArg(options::OPT__HASH_HASH_HASH)),
+        Verbose(C.getArgs().hasArg(options::OPT_v)) {
+    populateSymbols();
+    if (Verbose) {
+      // Iterate by const reference: the sets own std::string values and the
+      // loops only read them, so a by-value loop variable is a needless copy.
+      for (const auto &Name : FatBinSymbols)
+        llvm::errs() << "Found undefined HIP fatbin symbol: " << Name << "\n";
+      for (const auto &Name : GPUBinHandleSymbols)
+        llvm::errs() << "Found undefined HIP gpubin handle symbol: " << Name
+                     << "\n";
+    }
+  }
+
+  /// \returns the __hip_fatbin* symbol names gathered from the inputs,
+  /// in sorted order and without duplicates.
+  const std::set<std::string> &getFatBinSymbols() const
+  {
+    return FatBinSymbols;
+  }
+
+  /// \returns the __hip_gpubin_handle* symbol names gathered from the
+  /// inputs, in sorted order and without duplicates.
+  const std::set<std::string> &getGPUBinHandleSymbols() const
+  {
+    return GPUBinHandleSymbols;
+  }
+
+private:
+  // NOTE: members are initialized in declaration order; keep this order in
+  // sync with the constructor's member-initializer list.
+  // Compilation whose action graph is scanned for inputs.
+  const Compilation &C;
+  // Custom error diagnostic ID registered in the constructor.
+  unsigned DiagID;
+  // True when -### was passed on the command line.
+  bool Quiet;
+  // True when -v was passed; the constructor then prints the collected names.
+  bool Verbose;
+  // Collected __hip_fatbin* symbol names (sorted, de-duplicated by std::set).
+  std::set<std::string> FatBinSymbols;
+  // Collected __hip_gpubin_handle* symbol names (sorted, de-duplicated).
+  std::set<std::string> GPUBinHandleSymbols;
+  // Symbol-name prefixes; populateSymbols() appends "_<hash>" to them for
+  // inputs that carry an ID.
+  const std::string FatBinPrefix = "__hip_fatbin";
+  const std::string GPUBinHandlePrefix = "__hip_gpubin_handle";
+
+  // Walks the action graph breadth-first, starting from the compilation's
+  // top-level actions, and visits each action at most once. For every
+  // InputAction reached, symbol names are either derived from the input's ID
+  // (when it has one) or read from the input file via processInput().
+  void populateSymbols() {
+    std::deque<const Action *> WorkList;
+    std::set<const Action *> Visited;
+
+    for (const auto &A : C.getActions())
+      WorkList.push_back(A);
+
+    while (!WorkList.empty()) {
+      const Action *CurrentAction = WorkList.front();
+      WorkList.pop_front();
+
+      // Skip null entries and actions already reached through another path.
+      if (!CurrentAction || !Visited.insert(CurrentAction).second)
+        continue;
+
+      const auto *IA = dyn_cast<InputAction>(CurrentAction);
+      if (!IA) {
+        // Not a leaf: queue this action's inputs for later processing.
+        const auto &Inputs = CurrentAction->getInputs();
+        WorkList.insert(WorkList.end(), Inputs.begin(), Inputs.end());
+        continue;
+      }
+
+      std::string ID = IA->getId().str();
+      if (!ID.empty()) {
+        // Inputs with an ID get names derived from the lower-case hex MD5
+        // hash of that ID; no file needs to be read.
+        ID = llvm::utohexstr(llvm::MD5Hash(ID), /*LowerCase=*/true);
+        FatBinSymbols.insert(FatBinPrefix + "_" + ID);
+        GPUBinHandleSymbols.insert(GPUBinHandlePrefix + "_" + ID);
+        continue;
+      }
+
+      // The input may be a linker option rather than a file, so anything
+      // that cannot be read is skipped.
+      const char *Filename = IA->getInputArg().getValue();
+      auto BufferOrErr = llvm::MemoryBuffer::getFile(Filename);
+      if (!BufferOrErr)
+        continue;
+
+      processInput(BufferOrErr.get()->getMemBufferRef());
+    }
+  }
+
+  // Collects symbols from a single input buffer. The buffer is tried first as
+  // an object file, then as an archive; each archive member is fed back
+  // through processInput(), so archives nested inside archives are scanned
+  // too. Buffers that are neither are silently skipped. Errors raised while
+  // walking an archive's members are handed to errorHandler().
+  void processInput(const llvm::MemoryBufferRef &Buffer) {
+    auto ObjOrErr = llvm::object::ObjectFile::createObjectFile(Buffer);
+    if (ObjOrErr) {
+      processSymbols(**ObjOrErr);
+      return;
+    }
+    llvm::consumeError(ObjOrErr.takeError());
+
+    auto ArchOrErr = llvm::object::Archive::create(Buffer);
+    if (!ArchOrErr) {
+      // Neither an object file nor an archive: nothing to collect.
+      llvm::consumeError(ArchOrErr.takeError());
+      return;
+    }
+
+    // children() reports iteration failures through MemberErr, which must be
+    // checked once the loop has finished.
+    llvm::Error MemberErr = llvm::Error::success();
+    for (const auto &Member : (*ArchOrErr)->children(MemberErr)) {
+      if (auto MemberBuf = Member.getMemoryBufferRef())
+        processInput(*MemberBuf);
+      else
+        errorHandler(MemberBuf.takeError());
+    }
+    if (MemberErr)
+      errorHandler(std::move(MemberErr));
+  }
+
+  void processSymbols(const llvm::object::ObjectFile &Obj) {
+    for (const auto &Symbol : Obj.symbols()) {
+      auto FlagOrErr = Symbol.getFlags();
+      if (!FlagOrErr) {
+        errorHandler(FlagOrErr.takeError());
+        continue;
+      }
+
+      // Filter only undefined symbols
+      if (!(FlagOrErr.get() & llvm::object::SymbolRef::SF_Undefined)) {
----------------
Artem-B wrote:

Style nit: remove the `{}` around the single-statement body.

Applies here and in a handful of other places throughout the patch.

https://github.com/llvm/llvm-project/pull/81700
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to