This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new cae1af62f9 [LLVM][RUNTIME] Add optional LLVM ORCJIT runtime executor (#15964)
cae1af62f9 is described below

commit cae1af62f98efc8ec8a54b986619552fea85154e
Author: Balint Cristian <[email protected]>
AuthorDate: Mon Mar 11 17:25:16 2024 +0200

    [LLVM][RUNTIME] Add optional LLVM ORCJIT runtime executor (#15964)
---
 src/target/llvm/llvm_instance.cc                   |  21 ++-
 src/target/llvm/llvm_instance.h                    |   6 +
 src/target/llvm/llvm_module.cc                     | 197 ++++++++++++++++++---
 src/target/target_kind.cc                          |   2 +
 .../runtime/test_runtime_module_based_interface.py |  21 ++-
 tests/python/runtime/test_runtime_module_load.py   |  13 +-
 tests/python/target/test_target_target.py          |   7 +
 7 files changed, 231 insertions(+), 36 deletions(-)

diff --git a/src/target/llvm/llvm_instance.cc b/src/target/llvm/llvm_instance.cc
index 08ba34cc73..a1359b7850 100644
--- a/src/target/llvm/llvm_instance.cc
+++ b/src/target/llvm/llvm_instance.cc
@@ -256,8 +256,23 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, 
const Target& target) {
     }
   }
 
-  // Target options
+  // LLVM JIT engine options
+  if (const Optional<String>& v = target->GetAttr<String>("jit")) {
+    String value = v.value();
+    if ((value == "mcjit") || (value == "orcjit")) {
+      jit_engine_ = value;
+    } else {
+      LOG(FATAL) << "invalid jit option " << value << " (can be `mcjit` or 
`orcjit`).";
+    }
+  }
 
+  // RISCV code model
+  auto arch = llvm::Triple(triple_).getArch();
+  if (arch == llvm::Triple::riscv32 || arch == llvm::Triple::riscv64) {
+    code_model_ = llvm::CodeModel::Medium;
+  }
+
+  // Target options
 #if TVM_LLVM_VERSION < 50
   target_options_.LessPreciseFPMADOption = true;
 #endif
@@ -525,6 +540,10 @@ std::string LLVMTargetInfo::str() const {
     os << quote << Join(",", opts) << quote;
   }
 
+  if (jit_engine_ != "mcjit") {
+    os << " -jit=" << jit_engine_;
+  }
+
   return os.str();
 }
 
diff --git a/src/target/llvm/llvm_instance.h b/src/target/llvm/llvm_instance.h
index 030a7db721..f3948b7a01 100644
--- a/src/target/llvm/llvm_instance.h
+++ b/src/target/llvm/llvm_instance.h
@@ -212,6 +212,11 @@ class LLVMTargetInfo {
    * \return `llvm::FastMathFlags` for this target
    */
   llvm::FastMathFlags GetFastMathFlags() const { return fast_math_flags_; }
+  /*!
+   * \brief Get the LLVM JIT engine type
+   * \return the type name of the JIT engine (default "mcjit" or "orcjit")
+   */
+  const std::string GetJITEngine() const { return jit_engine_; }
   /*!
    * \brief Get the LLVM optimization level
    * \return optimization level for this target
@@ -324,6 +329,7 @@ class LLVMTargetInfo {
   llvm::Reloc::Model reloc_model_ = llvm::Reloc::PIC_;
   llvm::CodeModel::Model code_model_ = llvm::CodeModel::Small;
   std::shared_ptr<llvm::TargetMachine> target_machine_;
+  std::string jit_engine_ = "mcjit";
 };
 
 /*!
diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc
index 59cd6a76b0..c332314a3e 100644
--- a/src/target/llvm/llvm_module.cc
+++ b/src/target/llvm/llvm_module.cc
@@ -30,7 +30,10 @@
 #include <llvm/ADT/StringRef.h>
 #include <llvm/Bitcode/BitcodeWriter.h>
 #include <llvm/ExecutionEngine/ExecutionEngine.h>
-#include <llvm/ExecutionEngine/MCJIT.h>  // Force linking of MCJIT
+#include <llvm/ExecutionEngine/MCJIT.h>
+#include <llvm/ExecutionEngine/Orc/LLJIT.h>
+#include <llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h>
+#include <llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h>
 #include <llvm/IR/DataLayout.h>
 #include <llvm/IR/Function.h>
 #include <llvm/IR/Intrinsics.h>
@@ -113,8 +116,11 @@ class LLVMModuleNode final : public runtime::ModuleNode {
 
   bool ImplementsFunction(const String& name, bool query_imports) final;
 
+  void SetJITEngine(const std::string& jit_engine) { jit_engine_ = jit_engine; 
}
+
  private:
-  void LazyInitJIT();
+  void InitMCJIT();
+  void InitORCJIT();
   bool IsCompatibleWithHost(const llvm::TargetMachine* tm) const;
   void* GetGlobalAddr(const std::string& name, const LLVMTarget& llvm_target) 
const;
   void* GetFunctionAddr(const std::string& name, const LLVMTarget& 
llvm_target) const;
@@ -123,8 +129,9 @@ class LLVMModuleNode final : public runtime::ModuleNode {
   std::unique_ptr<LLVMInstance> llvm_instance_;
   // JIT lock
   std::mutex mutex_;
-  // execution engine
-  llvm::ExecutionEngine* ee_{nullptr};
+  // jit execution engines
+  llvm::ExecutionEngine* mcjit_ee_{nullptr};
+  std::unique_ptr<llvm::orc::LLJIT> orcjit_ee_{nullptr};
   // The raw pointer to the module.
   llvm::Module* module_{nullptr};
   // The unique_ptr owning the module. This becomes empty once JIT has been 
initialized
@@ -132,12 +139,21 @@ class LLVMModuleNode final : public runtime::ModuleNode {
   std::unique_ptr<llvm::Module> module_owning_ptr_;
   /* \brief names of the external functions declared in this module */
   Array<String> function_names_;
+  std::string jit_engine_;
 };
 
 LLVMModuleNode::~LLVMModuleNode() {
-  if (ee_ != nullptr) {
-    ee_->runStaticConstructorsDestructors(true);
-    delete ee_;
+  if (mcjit_ee_ != nullptr) {
+    mcjit_ee_->runStaticConstructorsDestructors(true);
+    delete mcjit_ee_;
+  }
+  if (orcjit_ee_ != nullptr) {
+    auto dtors = llvm::orc::getDestructors(*module_);
+    auto dtorRunner = 
std::make_unique<llvm::orc::CtorDtorRunner>(orcjit_ee_->getMainJITDylib());
+    dtorRunner->add(dtors);
+    auto err = dtorRunner->run();
+    ICHECK(!err) << llvm::toString(std::move(err));
+    orcjit_ee_.reset();
   }
   module_owning_ptr_.reset();
 }
@@ -166,7 +182,9 @@ PackedFunc LLVMModuleNode::GetFunction(const String& name, 
const ObjectPtr<Objec
     std::string target_string = LLVMTarget::GetTargetMetadata(*module_);
     return PackedFunc([target_string](TVMArgs args, TVMRetValue* rv) { *rv = 
target_string; });
   }
-  if (ee_ == nullptr) LazyInitJIT();
+  ICHECK(jit_engine_.size()) << "JIT engine type is missing";
+  if ((jit_engine_ == "mcjit") && (mcjit_ee_ == nullptr)) InitMCJIT();
+  if ((jit_engine_ == "orcjit") && (orcjit_ee_ == nullptr)) InitORCJIT();
 
   std::lock_guard<std::mutex> lock(mutex_);
 
@@ -353,6 +371,7 @@ void LLVMModuleNode::Init(const IRModule& mod, const 
Target& target) {
 
   module_owning_ptr_ = cg->Finish();
   module_ = module_owning_ptr_.get();
+  jit_engine_ = llvm_target->GetJITEngine();
   llvm_target->SetTargetMetadata(module_);
   module_->addModuleFlag(llvm::Module::Override, "Debug Info Version",
                          llvm::DEBUG_METADATA_VERSION);
@@ -384,13 +403,16 @@ bool LLVMModuleNode::ImplementsFunction(const String& 
name, bool query_imports)
   return std::find(function_names_.begin(), function_names_.end(), name) != 
function_names_.end();
 }
 
-void LLVMModuleNode::LazyInitJIT() {
+void LLVMModuleNode::InitMCJIT() {
   std::lock_guard<std::mutex> lock(mutex_);
-  if (ee_) {
+  if (mcjit_ee_) {
     return;
   }
+  // MCJIT builder
   With<LLVMTarget> llvm_target(*llvm_instance_, 
LLVMTarget::GetTargetMetadata(*module_));
   llvm::EngineBuilder builder(std::move(module_owning_ptr_));
+
+  // set options
   builder.setEngineKind(llvm::EngineKind::JIT);
 #if TVM_LLVM_VERSION <= 170
   builder.setOptLevel(llvm::CodeGenOpt::Aggressive);
@@ -400,18 +422,31 @@ void LLVMModuleNode::LazyInitJIT() {
   builder.setMCPU(llvm_target->GetCPU());
   builder.setMAttrs(llvm_target->GetTargetFeatures());
   builder.setTargetOptions(llvm_target->GetTargetOptions());
+
+  // create the taget machine
   auto tm = std::unique_ptr<llvm::TargetMachine>(builder.selectTarget());
   if (!IsCompatibleWithHost(tm.get())) {
     LOG(FATAL) << "Cannot run module, architecture mismatch";
   }
+
+  // data layout
   llvm::DataLayout layout(tm->createDataLayout());
   ICHECK(layout == module_->getDataLayout())
       << "Data layout mismatch between module("
       << module_->getDataLayout().getStringRepresentation() << ")"
       << " and ExecutionEngine (" << layout.getStringRepresentation() << ")";
-  ee_ = builder.create(tm.release());
-  ICHECK(ee_ != nullptr) << "Failed to initialize jit engine for " << 
module_->getTargetTriple();
-  ee_->runStaticConstructorsDestructors(false);
+
+  // create MCJIT
+  mcjit_ee_ = builder.create(tm.release());
+  ICHECK(mcjit_ee_ != nullptr) << "Failed to initialize LLVM MCJIT engine for "
+                               << module_->getTargetTriple();
+
+  VLOG(2) << "LLVM MCJIT execute " << module_->getModuleIdentifier() << " for 
triple `"
+          << llvm_target->GetTargetTriple() << "`"
+          << " on cpu `" << llvm_target->GetCPU() << "`";
+
+  // run ctors
+  mcjit_ee_->runStaticConstructorsDestructors(false);
 
   if (void** ctx_addr =
           
reinterpret_cast<void**>(GetGlobalAddr(runtime::symbol::tvm_module_ctx, 
*llvm_target))) {
@@ -424,7 +459,104 @@ void LLVMModuleNode::LazyInitJIT() {
   // lead to a runtime crash.
   // Do name lookup on a symbol that doesn't exist. This will force MCJIT to 
finalize
   // all loaded objects, which will resolve symbols in JITed code.
-  
ee_->getFunctionAddress("__some_name_that_hopefully_doesnt_exist__b49f8aaade5877eaba7583b91");
+  mcjit_ee_->getFunctionAddress(
+      "__some_name_that_hopefully_doesnt_exist__b49f8aaade5877eaba7583b91");
+}
+
+void LLVMModuleNode::InitORCJIT() {
+  std::lock_guard<std::mutex> lock(mutex_);
+  if (orcjit_ee_) {
+    return;
+  }
+  // ORCJIT builder
+  With<LLVMTarget> llvm_target(*llvm_instance_, 
LLVMTarget::GetTargetMetadata(*module_));
+  llvm::orc::JITTargetMachineBuilder 
tm_builder(llvm::Triple(llvm_target->GetTargetTriple()));
+
+  // set options
+  tm_builder.setCPU(llvm_target->GetCPU());
+  tm_builder.setFeatures(llvm_target->GetTargetFeatureString());
+  tm_builder.setOptions(llvm_target->GetTargetOptions());
+#if TVM_LLVM_VERSION <= 170
+  tm_builder.setCodeGenOptLevel(llvm::CodeGenOpt::Aggressive);
+#else
+  tm_builder.setCodeGenOptLevel(llvm::CodeGenOptLevel::Aggressive);
+#endif
+
+  // create the taget machine
+  std::unique_ptr<llvm::TargetMachine> tm = 
llvm::cantFail(tm_builder.createTargetMachine());
+  if (!IsCompatibleWithHost(tm.get())) {
+    LOG(FATAL) << "Cannot run module, architecture mismatch";
+  }
+
+  // data layout
+  String module_name = module_->getModuleIdentifier();
+  llvm::DataLayout layout(tm->createDataLayout());
+  ICHECK(layout == module_->getDataLayout())
+      << "Data layout mismatch between module("
+      << module_->getDataLayout().getStringRepresentation() << ")"
+      << " and ExecutionEngine (" << layout.getStringRepresentation() << ")";
+
+  // compiler
+  const auto compilerBuilder = [&](const llvm::orc::JITTargetMachineBuilder&)
+      -> 
llvm::Expected<std::unique_ptr<llvm::orc::IRCompileLayer::IRCompiler>> {
+    return std::make_unique<llvm::orc::TMOwningSimpleCompiler>(std::move(tm));
+  };
+
+#if TVM_LLVM_VERSION >= 130
+  // linker
+  const auto linkerBuilder = [&](llvm::orc::ExecutionSession& session, const 
llvm::Triple&) {
+    return std::make_unique<llvm::orc::ObjectLinkingLayer>(session);
+  };
+#endif
+
+  // create LLJIT
+  orcjit_ee_ = llvm::cantFail(llvm::orc::LLJITBuilder()
+#if TVM_LLVM_VERSION >= 110
+                                  .setDataLayout(layout)
+#endif
+                                  .setCompileFunctionCreator(compilerBuilder)
+#if TVM_LLVM_VERSION >= 130
+                                  .setObjectLinkingLayerCreator(linkerBuilder)
+#endif
+                                  .create());
+
+  ICHECK(orcjit_ee_ != nullptr) << "Failed to initialize LLVM ORCJIT engine 
for "
+                                << module_->getTargetTriple();
+
+  // store ctors
+  auto ctors = llvm::orc::getConstructors(*module_);
+  llvm::orc::CtorDtorRunner ctorRunner(orcjit_ee_->getMainJITDylib());
+  ctorRunner.add(ctors);
+
+  // resolve system symbols (like pthread, dl, m, etc.)
+  auto gen =
+      
llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(layout.getGlobalPrefix());
+  ICHECK(gen) << llvm::toString(gen.takeError()) << "\n";
+  orcjit_ee_->getMainJITDylib().addGenerator(std::move(gen.get()));
+
+  // transfer module to a clone
+  auto uctx = std::make_unique<llvm::LLVMContext>();
+  auto umod = llvm::CloneModule(*(std::move(module_owning_ptr_)));
+
+  // add the llvm module to run
+  llvm::orc::ThreadSafeModule tsm(std::move(umod), std::move(uctx));
+  auto err = orcjit_ee_->addIRModule(std::move(tsm));
+  ICHECK(!err) << llvm::toString(std::move(err));
+
+  VLOG(2) << "LLVM ORCJIT execute " << module_->getModuleIdentifier() << " for 
triple `"
+          << llvm_target->GetTargetTriple() << "`"
+          << " on cpu `" << llvm_target->GetCPU() << "`";
+
+  // run ctors
+  err = ctorRunner.run();
+  ICHECK(!err) << llvm::toString(std::move(err));
+
+  if (void** ctx_addr =
+          
reinterpret_cast<void**>(GetGlobalAddr(runtime::symbol::tvm_module_ctx, 
*llvm_target))) {
+    *ctx_addr = this;
+  }
+  runtime::InitContextFunctions(
+      [this, &llvm_target](const char* name) { return GetGlobalAddr(name, 
*llvm_target); });
 }
 
 bool LLVMModuleNode::IsCompatibleWithHost(const llvm::TargetMachine* tm) const 
{
@@ -442,20 +574,40 @@ bool LLVMModuleNode::IsCompatibleWithHost(const 
llvm::TargetMachine* tm) const {
 void* LLVMModuleNode::GetGlobalAddr(const std::string& name, const LLVMTarget& 
llvm_target) const {
   // first verifies if GV exists.
   if (module_->getGlobalVariable(name) != nullptr) {
-    return reinterpret_cast<void*>(ee_->getGlobalValueAddress(name));
-  } else {
-    return nullptr;
+    if (jit_engine_ == "mcjit") {
+      return reinterpret_cast<void*>(mcjit_ee_->getGlobalValueAddress(name));
+    } else if (jit_engine_ == "orcjit") {
+#if TVM_LLVM_VERSION >= 150
+      auto addr = llvm::cantFail(orcjit_ee_->lookup(name)).getValue();
+#else
+      auto addr = llvm::cantFail(orcjit_ee_->lookup(name)).getAddress();
+#endif
+      return reinterpret_cast<void*>(addr);
+    } else {
+      LOG(FATAL) << "Either `mcjit` or `orcjit` are not initialized.";
+    }
   }
+  return nullptr;
 }
 
 void* LLVMModuleNode::GetFunctionAddr(const std::string& name,
                                       const LLVMTarget& llvm_target) const {
   // first verifies if GV exists.
   if (module_->getFunction(name) != nullptr) {
-    return reinterpret_cast<void*>(ee_->getFunctionAddress(name));
-  } else {
-    return nullptr;
+    if (jit_engine_ == "mcjit") {
+      return reinterpret_cast<void*>(mcjit_ee_->getFunctionAddress(name));
+    } else if (jit_engine_ == "orcjit") {
+#if TVM_LLVM_VERSION >= 150
+      auto addr = llvm::cantFail(orcjit_ee_->lookup(name)).getValue();
+#else
+      auto addr = llvm::cantFail(orcjit_ee_->lookup(name)).getAddress();
+#endif
+      return reinterpret_cast<void*>(addr);
+    } else {
+      LOG(FATAL) << "Either `mcjit` or `orcjit` are not initialized.";
+    }
   }
+  return nullptr;
 }
 
 TVM_REGISTER_GLOBAL("target.build.llvm")
@@ -476,6 +628,7 @@ TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate")
       module->setTargetTriple(llvm_target->GetTargetTriple());
       
module->setDataLayout(llvm_target->GetOrCreateTargetMachine()->createDataLayout());
       n->Init(std::move(module), std::move(llvm_instance));
+      n->SetJITEngine(llvm_target->GetJITEngine());
       return runtime::Module(n);
     });
 
@@ -595,6 +748,7 @@ 
TVM_REGISTER_GLOBAL("target.llvm_version_major").set_body_typed([]() -> int {
 TVM_REGISTER_GLOBAL("runtime.module.loadfile_ll")
     .set_body_typed([](std::string filename, std::string fmt) -> 
runtime::Module {
       auto n = make_object<LLVMModuleNode>();
+      n->SetJITEngine("mcjit");
       n->LoadIR(filename);
       return runtime::Module(n);
     });
@@ -616,6 +770,7 @@ TVM_REGISTER_GLOBAL("codegen.codegen_blob")
       std::unique_ptr<llvm::Module> blob =
           CodeGenBlob(data, system_lib, llvm_target.get(), c_symbol_prefix);
       n->Init(std::move(blob), std::move(llvm_instance));
+      n->SetJITEngine(llvm_target->GetJITEngine());
       return runtime::Module(n);
     });
 
@@ -645,6 +800,7 @@ runtime::Module 
CreateLLVMCppMetadataModule(runtime::metadata::Metadata metadata
 
   auto n = make_object<LLVMModuleNode>();
   n->Init(std::move(mod), std::move(llvm_instance));
+  n->SetJITEngine(llvm_target->GetJITEngine());
 
   auto meta_mod = MetadataModuleCreate(metadata);
   meta_mod->Import(runtime::Module(n));
@@ -691,6 +847,7 @@ runtime::Module CreateLLVMCrtMetadataModule(const 
Array<runtime::Module>& module
 
   auto n = make_object<LLVMModuleNode>();
   n->Init(std::move(mod), std::move(llvm_instance));
+  n->SetJITEngine(llvm_target->GetJITEngine());
   for (auto m : modules) {
     n->Import(m);
   }
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index aa4499ec96..28c7e06629 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -291,6 +291,8 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU)
     .add_attr_option<Integer>("opt-level")
     // LLVM command line flags, see below
     .add_attr_option<Array<String>>("cl-opt")
+    // LLVM JIT engine mcjit/orcjit
+    .add_attr_option<String>("jit")
     .set_default_keys({"cpu"})
     // Force the external codegen kind attribute to be registered, even if no 
external
     // codegen targets are enabled by the TVM build.
diff --git a/tests/python/runtime/test_runtime_module_based_interface.py 
b/tests/python/runtime/test_runtime_module_based_interface.py
index 6e62e3f215..55edbdaccb 100644
--- a/tests/python/runtime/test_runtime_module_based_interface.py
+++ b/tests/python/runtime/test_runtime_module_based_interface.py
@@ -23,6 +23,7 @@ from tvm.contrib import graph_executor
 from tvm.contrib.debugger import debug_executor
 from tvm.contrib.cuda_graph import cuda_graph_executor
 import tvm.testing
+import pytest
 
 
 def input_shape(mod):
@@ -48,10 +49,11 @@ def verify(data):
 
 
 @tvm.testing.requires_llvm
-def test_legacy_compatibility():
+@pytest.mark.parametrize("target", ["llvm", "llvm -jit=orcjit"])
+def test_legacy_compatibility(target):
     mod, params = relay.testing.synthetic.get_workload()
     with relay.build_config(opt_level=3):
-        graph, lib, graph_params = relay.build_module.build(mod, "llvm", 
params=params)
+        graph, lib, graph_params = relay.build_module.build(mod, target, 
params=params)
     data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32")
     dev = tvm.cpu()
     module = graph_executor.create(graph, lib, dev)
@@ -63,10 +65,11 @@ def test_legacy_compatibility():
 
 
 @tvm.testing.requires_llvm
-def test_cpu():
+@pytest.mark.parametrize("target", ["llvm", "llvm -jit=orcjit"])
+def test_cpu(target):
     mod, params = relay.testing.synthetic.get_workload()
     with relay.build_config(opt_level=3):
-        complied_graph_lib = relay.build_module.build(mod, "llvm", 
params=params)
+        complied_graph_lib = relay.build_module.build(mod, target, 
params=params)
     data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32")
     # raw api
     dev = tvm.cpu()
@@ -105,10 +108,11 @@ def test_cpu_get_graph_json():
 
 
 @tvm.testing.requires_llvm
-def test_cpu_get_graph_params_run():
+@pytest.mark.parametrize("target", ["llvm", "llvm -jit=orcjit"])
+def test_cpu_get_graph_params_run(target):
     mod, params = relay.testing.synthetic.get_workload()
     with tvm.transform.PassContext(opt_level=3):
-        complied_graph_lib = relay.build_module.build(mod, "llvm", 
params=params)
+        complied_graph_lib = relay.build_module.build(mod, target, 
params=params)
     data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32")
     dev = tvm.cpu()
     from tvm.contrib import utils
@@ -584,10 +588,11 @@ def test_remove_package_params():
 
 
 @tvm.testing.requires_llvm
-def test_debug_graph_executor():
+@pytest.mark.parametrize("target", ["llvm", "llvm -jit=orcjit"])
+def test_debug_graph_executor(target):
     mod, params = relay.testing.synthetic.get_workload()
     with relay.build_config(opt_level=3):
-        complied_graph_lib = relay.build_module.build(mod, "llvm", 
params=params)
+        complied_graph_lib = relay.build_module.build(mod, target, 
params=params)
     data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32")
 
     # raw api
diff --git a/tests/python/runtime/test_runtime_module_load.py 
b/tests/python/runtime/test_runtime_module_load.py
index ecaa7067a5..3789a1d090 100644
--- a/tests/python/runtime/test_runtime_module_load.py
+++ b/tests/python/runtime/test_runtime_module_load.py
@@ -22,6 +22,7 @@ import numpy as np
 import subprocess
 import tvm.testing
 from tvm.relay.backend import Runtime
+import pytest
 
 runtime_py = """
 import os
@@ -42,9 +43,9 @@ print("Finish runtime checking...")
 """
 
 
-def test_dso_module_load():
-    if not tvm.testing.device_enabled("llvm"):
-        return
+@tvm.testing.requires_llvm
+@pytest.mark.parametrize("target", ["llvm", "llvm -jit=orcjit"])
+def test_dso_module_load(target):
     dtype = "int64"
     temp = utils.tempdir()
 
@@ -63,7 +64,7 @@ def test_dso_module_load():
         mod = tvm.IRModule.from_expr(
             tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "main")
         )
-        m = tvm.driver.build(mod, target="llvm")
+        m = tvm.driver.build(mod, target=target)
         for name in names:
             m.save(name)
 
@@ -167,6 +168,7 @@ def test_device_module_dump():
         check_stackvm(device)
 
 
+@tvm.testing.requires_llvm
 def test_combine_module_llvm():
     """Test combine multiple module into one shared lib."""
     # graph
@@ -178,9 +180,6 @@ def test_combine_module_llvm():
 
     def check_llvm():
         dev = tvm.cpu(0)
-        if not tvm.testing.device_enabled("llvm"):
-            print("Skip because llvm is not enabled")
-            return
         temp = utils.tempdir()
         fadd1 = tvm.build(s, [A, B], "llvm", name="myadd1")
         fadd2 = tvm.build(s, [A, B], "llvm", name="myadd2")
diff --git a/tests/python/target/test_target_target.py 
b/tests/python/target/test_target_target.py
index d5e8d06025..83bd864970 100644
--- a/tests/python/target/test_target_target.py
+++ b/tests/python/target/test_target_target.py
@@ -171,6 +171,13 @@ def test_target_llvm_options():
     )
 
 
+def test_target_llvm_jit_options():
+    target = tvm.target.Target("llvm -jit=mcjit")
+    assert target.attrs["jit"] == "mcjit"
+    target = tvm.target.Target("llvm -jit=orcjit")
+    assert target.attrs["jit"] == "orcjit"
+
+
 def test_target_create():
     targets = [cuda(), rocm(), mali(), intel_graphics(), arm_cpu("rk3399"), 
vta(), bifrost()]
     for tgt in targets:

Reply via email to