This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b8fcad8  [LLVM/CG] Sort PrimFuncs when creating LLVM module (#8958)
b8fcad8 is described below

commit b8fcad88b08a78dadcfeaf06e94e48f27eea4187
Author: Krzysztof Parzyszek <[email protected]>
AuthorDate: Thu Sep 9 03:32:58 2021 -0500

    [LLVM/CG] Sort PrimFuncs when creating LLVM module (#8958)
    
    * [LLVM/CG] Sort PrimFuncs when creating LLVM module
    
    PrimFuncs are stored in a map where the order of iteration is not
    deterministic. This can cause a different llvm::Module to be created
    each time, which can defeat debugging tools like -opt-bisect-limit.
    
    Add function CodeGenLLVM::AddFunctionsOrdered that takes a range of
    PrimFuncs or objects convertible to PrimFuncs, and adds them to the
    LLVM module in a deterministic order.
    
    * Empty commit to restart build
    
    * Add testcase
---
 src/target/llvm/codegen_amdgpu.cc                 | 10 +++----
 src/target/llvm/codegen_hexagon.cc                |  5 ++--
 src/target/llvm/codegen_llvm.h                    | 36 +++++++++++++++++++++++
 src/target/llvm/codegen_nvptx.cc                  | 10 +++----
 src/target/llvm/llvm_module.cc                    |  4 +--
 tests/python/unittest/test_target_codegen_llvm.py | 27 +++++++++++++++++
 6 files changed, 76 insertions(+), 16 deletions(-)

diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc
index 7770e42..33a09b1 100644
--- a/src/target/llvm/codegen_amdgpu.cc
+++ b/src/target/llvm/codegen_amdgpu.cc
@@ -230,11 +230,11 @@ runtime::Module BuildAMDGPU(IRModule mod, Target target) {
 
   cg->Init("TVMAMDGPUModule", tm.get(), ctx.get(), false, false, false);
 
-  for (auto kv : mod->functions) {
-    ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "Can only lower IR Module with PrimFuncs";
-    auto f = Downcast<PrimFunc>(kv.second);
-    cg->AddFunction(f);
-  }
+  cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end(), [](auto& kv) {
+    ICHECK(kv.second->template IsInstance<PrimFuncNode>())
+        << "Can only lower IR Module with PrimFuncs";
+    return Downcast<PrimFunc>(kv.second);
+  });
 
   const auto* find_rocm_bitcodes = tvm::runtime::Registry::Get("tvm_callback_rocm_bitcode_path");
   Array<runtime::String> bitcode_files = (*find_rocm_bitcodes)();
diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc
index e9eacc2..2f91807 100644
--- a/src/target/llvm/codegen_hexagon.cc
+++ b/src/target/llvm/codegen_hexagon.cc
@@ -731,9 +731,8 @@ runtime::Module BuildHexagon(IRModule mod, Target target) {
   }
 
   cg->Init("TVMHexagonModule", tm.get(), ctx.get(), false, false, false);
-  for (const PrimFunc& f : funcs) {
-    cg->AddFunction(f);
-  }
+  cg->AddFunctionsOrdered(funcs.begin(), funcs.end());
+
   if (!linked_params.empty()) {
     cg->LinkParameters(linked_params);
   }
diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h
index 52c5b98a..a4f007a 100644
--- a/src/target/llvm/codegen_llvm.h
+++ b/src/target/llvm/codegen_llvm.h
@@ -36,6 +36,7 @@
 #include <tvm/tir/stmt.h>
 #include <tvm/tir/stmt_functor.h>
 
+#include <algorithm>
 #include <memory>
 #include <string>
 #include <unordered_map>
@@ -93,6 +94,25 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
    */
   virtual std::unique_ptr<llvm::Module> Finish();
   /*!
+   * \brief Add functions from the (unordered) range to the current module in a deterministic order.
+   *        The range consists of objects convertible to PrimFunc.
+   * \param begin The beginning of the range.
+   * \param end The end of the range.
+   * \param pfunc Converter function from the range element type to PrimFunc.
+   */
+  template <typename IterType, typename ConvType>
+  void AddFunctionsOrdered(IterType begin, IterType end, ConvType pfunc);
+  /*!
+   * \brief Add functions from the (unordered) range of elements of type PrimFunc to the current
+   *        module in a deterministic order.
+   * \param begin The beginning of the range.
+   * \param end The end of the range.
+   */
+  template <typename IterType>
+  void AddFunctionsOrdered(IterType begin, IterType end) {
+    this->AddFunctionsOrdered(begin, end, [](auto f) { return f; });
+  }
+  /*!
    * \brief Add mod to be linked with the generated module
    * \param mod The module to be linked.
    */
@@ -377,6 +397,22 @@ inline int CodeGenLLVM::GetVectorNumElements(llvm::Value* vec) {
 #endif
 }
 
+template <typename IterType, typename ConvType>
+void CodeGenLLVM::AddFunctionsOrdered(IterType begin, IterType end, ConvType pfunc) {
+  std::vector<PrimFunc> funcs;
+  for (auto it = begin; it != end; ++it) {
+    funcs.push_back(pfunc(*it));
+  }
+  std::sort(funcs.begin(), funcs.end(), [](PrimFunc func_a, PrimFunc func_b) {
+    std::string name_a = func_a->GetAttr<String>(tvm::attr::kGlobalSymbol).value();
+    std::string name_b = func_b->GetAttr<String>(tvm::attr::kGlobalSymbol).value();
+    return name_a < name_b;
+  });
+  for (auto& f : funcs) {
+    AddFunction(f);
+  }
+}
+
 }  // namespace codegen
 }  // namespace tvm
 #endif  // LLVM_VERSION
diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc
index 15543ed..ebe6d6d 100644
--- a/src/target/llvm/codegen_nvptx.cc
+++ b/src/target/llvm/codegen_nvptx.cc
@@ -274,11 +274,11 @@ runtime::Module BuildNVPTX(IRModule mod, Target target) {
 
   cg->Init("TVMPTXModule", tm.get(), ctx.get(), false, false, false);
 
-  for (auto kv : mod->functions) {
-    ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "Can only lower IR Module with PrimFuncs";
-    auto f = Downcast<PrimFunc>(kv.second);
-    cg->AddFunction(f);
-  }
+  cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end(), [](auto& kv) {
+    ICHECK(kv.second->template IsInstance<PrimFuncNode>())
+        << "Can only lower IR Module with PrimFuncs";
+    return Downcast<PrimFunc>(kv.second);
+  });
 
   const auto* flibdevice_path = tvm::runtime::Registry::Get("tvm_callback_libdevice_path");
   if (flibdevice_path != nullptr) {
diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc
index 8bdf6d1..0e4bca4 100644
--- a/src/target/llvm/llvm_module.cc
+++ b/src/target/llvm/llvm_module.cc
@@ -258,9 +258,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
     // makes sense when we start to use multiple modules.
     cg->Init("TVMMod", tm_.get(), ctx_.get(), system_lib, system_lib, target_c_runtime);
 
-    for (const auto& f : funcs) {
-      cg->AddFunction(f);
-    }
+    cg->AddFunctionsOrdered(funcs.begin(), funcs.end());
 
     if (entry_func.length() != 0) {
       cg->AddMainFunction(entry_func);
diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py
index 10cbcd6..e5e93ed 100644
--- a/tests/python/unittest/test_target_codegen_llvm.py
+++ b/tests/python/unittest/test_target_codegen_llvm.py
@@ -818,5 +818,32 @@ def test_llvm_gpu_lower_atomic():
         tvm.testing.assert_allclose(a.numpy(), ref, rtol=1e-5)
 
 
+@tvm.testing.requires_llvm
+def test_llvm_order_functions():
+    """Check that functions in the LLVM module are ordered alphabetically."""
+
+    # Note: the order is alphabetical because that's a predictable ordering. Any predictable
+    # ordering will work fine, but if the ordering changes, this test will need to be updated.
+    def make_call_extern(caller, callee):
+        # Create a function:
+        #   float32 caller(float32 v) { return callee(v); }
+        ib = tvm.tir.ir_builder.create()
+        v = tvm.te.var("v", dtype="float32")
+        t = tvm.tir.call_extern("float32", callee, v)
+        ib.emit(t)
+        return tvm.tir.PrimFunc([v], ib.get()).with_attr("global_symbol", caller)
+
+    # Create some functions in a random order.
+    functions = {
+        "Danny": make_call_extern("Danny", "Dave"),
+        "Sammy": make_call_extern("Sammy", "Eve"),
+        "Kirby": make_call_extern("Kirby", "Fred"),
+    }
+    mod = tvm.IRModule(functions=functions)
+    ir_text = tvm.build(mod, None, target="llvm").get_source("ll")
+    matches = re.findall(r"^define[^@]*@([a-zA-Z_][a-zA-Z0-9_]*)", ir_text, re.MULTILINE)
+    assert matches == sorted(matches)
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))

Reply via email to