This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7ae7ea8361 [LLVM][RUNTIME] Make ORCJIT LLVM executor the default one (#17538)
7ae7ea8361 is described below

commit 7ae7ea836169d3cf28b05c7d0dd2cb6a2045508e
Author: Balint Cristian <[email protected]>
AuthorDate: Tue Nov 26 15:46:17 2024 +0200

    [LLVM][RUNTIME] Make ORCJIT LLVM executor the default one (#17538)
---
 src/target/llvm/llvm_instance.cc                         |  4 ++--
 src/target/llvm/llvm_instance.h                          |  4 ++--
 src/target/llvm/llvm_module.cc                           | 16 +++++++++++++---
 .../runtime/test_runtime_module_based_interface.py       |  8 ++++----
 tests/python/runtime/test_runtime_module_load.py         |  2 +-
 5 files changed, 22 insertions(+), 12 deletions(-)

diff --git a/src/target/llvm/llvm_instance.cc b/src/target/llvm/llvm_instance.cc
index 0406dcf951..e2c5e28592 100644
--- a/src/target/llvm/llvm_instance.cc
+++ b/src/target/llvm/llvm_instance.cc
@@ -269,7 +269,7 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target)
     if ((value == "mcjit") || (value == "orcjit")) {
       jit_engine_ = value;
     } else {
-      LOG(FATAL) << "invalid jit option " << value << " (can be `mcjit` or `orcjit`).";
+      LOG(FATAL) << "invalid jit option " << value << " (can be `orcjit` or `mcjit`).";
     }
   }
 
@@ -530,7 +530,7 @@ std::string LLVMTargetInfo::str() const {
     os << quote << Join(",", opts) << quote;
   }
 
-  if (jit_engine_ != "mcjit") {
+  if (jit_engine_ != "orcjit") {
     os << " -jit=" << jit_engine_;
   }
 
diff --git a/src/target/llvm/llvm_instance.h b/src/target/llvm/llvm_instance.h
index add2af6002..5cea99403a 100644
--- a/src/target/llvm/llvm_instance.h
+++ b/src/target/llvm/llvm_instance.h
@@ -232,7 +232,7 @@ class LLVMTargetInfo {
   llvm::FastMathFlags GetFastMathFlags() const { return fast_math_flags_; }
   /*!
    * \brief Get the LLVM JIT engine type
-   * \return the type name of the JIT engine (default "mcjit" or "orcjit")
+   * \return the type name of the JIT engine (default "orcjit" or "mcjit")
    */
   const std::string GetJITEngine() const { return jit_engine_; }
   /*!
@@ -348,7 +348,7 @@ class LLVMTargetInfo {
   llvm::Reloc::Model reloc_model_ = llvm::Reloc::PIC_;
   llvm::CodeModel::Model code_model_ = llvm::CodeModel::Small;
   std::shared_ptr<llvm::TargetMachine> target_machine_;
-  std::string jit_engine_ = "mcjit";
+  std::string jit_engine_ = "orcjit";
 };
 
 /*!
diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc
index 34bbb6a0c6..98dbe139f1 100644
--- a/src/target/llvm/llvm_module.cc
+++ b/src/target/llvm/llvm_module.cc
@@ -34,6 +34,7 @@
 #include <llvm/ExecutionEngine/Orc/LLJIT.h>
 #include <llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h>
 #include <llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h>
+#include <llvm/ExecutionEngine/SectionMemoryManager.h>
 #include <llvm/IR/DataLayout.h>
 #include <llvm/IR/Function.h>
 #include <llvm/IR/Intrinsics.h>
@@ -512,8 +513,17 @@ void LLVMModuleNode::InitORCJIT() {
 
 #if TVM_LLVM_VERSION >= 130
   // linker
-  const auto linkerBuilder = [&](llvm::orc::ExecutionSession& session, const llvm::Triple&) {
-    return std::make_unique<llvm::orc::ObjectLinkingLayer>(session);
+  const auto linkerBuilder =
+      [&](llvm::orc::ExecutionSession& session,
+          const llvm::Triple& triple) -> std::unique_ptr<llvm::orc::ObjectLayer> {
+    auto GetMemMgr = []() { return std::make_unique<llvm::SectionMemoryManager>(); };
+    auto ObjLinkingLayer =
+        std::make_unique<llvm::orc::RTDyldObjectLinkingLayer>(session, std::move(GetMemMgr));
+    if (triple.isOSBinFormatCOFF()) {
+      ObjLinkingLayer->setOverrideObjectFlagsWithResponsibilityFlags(true);
+      ObjLinkingLayer->setAutoClaimResponsibilityForObjectSymbols(true);
+    }
+    return ObjLinkingLayer;
   };
 #endif
 
@@ -755,7 +765,7 @@ TVM_REGISTER_GLOBAL("target.llvm_version_major").set_body_typed([]() -> int {
 TVM_REGISTER_GLOBAL("runtime.module.loadfile_ll")
     .set_body_typed([](std::string filename, std::string fmt) -> runtime::Module {
       auto n = make_object<LLVMModuleNode>();
-      n->SetJITEngine("mcjit");
+      n->SetJITEngine("orcjit");
       n->LoadIR(filename);
       return runtime::Module(n);
     });
diff --git a/tests/python/runtime/test_runtime_module_based_interface.py b/tests/python/runtime/test_runtime_module_based_interface.py
index 3f71258768..2c46838b94 100644
--- a/tests/python/runtime/test_runtime_module_based_interface.py
+++ b/tests/python/runtime/test_runtime_module_based_interface.py
@@ -54,7 +54,7 @@ def verify(data):
 
 
 @tvm.testing.requires_llvm
-@pytest.mark.parametrize("target", ["llvm", "llvm -jit=orcjit"])
+@pytest.mark.parametrize("target", ["llvm", "llvm -jit=mcjit"])
 def test_legacy_compatibility(target):
     mod, params = relay.testing.synthetic.get_workload()
     with relay.build_config(opt_level=3):
@@ -70,7 +70,7 @@ def test_legacy_compatibility(target):
 
 
 @tvm.testing.requires_llvm
-@pytest.mark.parametrize("target", ["llvm", "llvm -jit=orcjit"])
+@pytest.mark.parametrize("target", ["llvm", "llvm -jit=mcjit"])
 def test_cpu(target):
     mod, params = relay.testing.synthetic.get_workload()
     with relay.build_config(opt_level=3):
@@ -113,7 +113,7 @@ def test_cpu_get_graph_json():
 
 
 @tvm.testing.requires_llvm
-@pytest.mark.parametrize("target", ["llvm", "llvm -jit=orcjit"])
+@pytest.mark.parametrize("target", ["llvm", "llvm -jit=mcjit"])
 def test_cpu_get_graph_params_run(target):
     mod, params = relay.testing.synthetic.get_workload()
     with tvm.transform.PassContext(opt_level=3):
@@ -592,7 +592,7 @@ def test_remove_package_params():
 
 
 @tvm.testing.requires_llvm
-@pytest.mark.parametrize("target", ["llvm", "llvm -jit=orcjit"])
+@pytest.mark.parametrize("target", ["llvm", "llvm -jit=mcjit"])
 def test_debug_graph_executor(target):
     mod, params = relay.testing.synthetic.get_workload()
     with relay.build_config(opt_level=3):
diff --git a/tests/python/runtime/test_runtime_module_load.py b/tests/python/runtime/test_runtime_module_load.py
index 3789a1d090..87a8ef9f5e 100644
--- a/tests/python/runtime/test_runtime_module_load.py
+++ b/tests/python/runtime/test_runtime_module_load.py
@@ -44,7 +44,7 @@ print("Finish runtime checking...")
 
 
 @tvm.testing.requires_llvm
-@pytest.mark.parametrize("target", ["llvm", "llvm -jit=orcjit"])
+@pytest.mark.parametrize("target", ["llvm", "llvm -jit=mcjit"])
 def test_dso_module_load(target):
     dtype = "int64"
     temp = utils.tempdir()

Reply via email to