This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new cae1af62f9 [LLVM][RUNTIME] Add optional LLVM ORCJIT runtime executor
(#15964)
cae1af62f9 is described below
commit cae1af62f98efc8ec8a54b986619552fea85154e
Author: Balint Cristian <[email protected]>
AuthorDate: Mon Mar 11 17:25:16 2024 +0200
[LLVM][RUNTIME] Add optional LLVM ORCJIT runtime executor (#15964)
---
src/target/llvm/llvm_instance.cc | 21 ++-
src/target/llvm/llvm_instance.h | 6 +
src/target/llvm/llvm_module.cc | 197 ++++++++++++++++++---
src/target/target_kind.cc | 2 +
.../runtime/test_runtime_module_based_interface.py | 21 ++-
tests/python/runtime/test_runtime_module_load.py | 13 +-
tests/python/target/test_target_target.py | 7 +
7 files changed, 231 insertions(+), 36 deletions(-)
diff --git a/src/target/llvm/llvm_instance.cc b/src/target/llvm/llvm_instance.cc
index 08ba34cc73..a1359b7850 100644
--- a/src/target/llvm/llvm_instance.cc
+++ b/src/target/llvm/llvm_instance.cc
@@ -256,8 +256,23 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance,
const Target& target) {
}
}
- // Target options
+ // LLVM JIT engine options
+ if (const Optional<String>& v = target->GetAttr<String>("jit")) {
+ String value = v.value();
+ if ((value == "mcjit") || (value == "orcjit")) {
+ jit_engine_ = value;
+ } else {
+ LOG(FATAL) << "invalid jit option " << value << " (can be `mcjit` or
`orcjit`).";
+ }
+ }
+ // RISCV code model
+ auto arch = llvm::Triple(triple_).getArch();
+ if (arch == llvm::Triple::riscv32 || arch == llvm::Triple::riscv64) {
+ code_model_ = llvm::CodeModel::Medium;
+ }
+
+ // Target options
#if TVM_LLVM_VERSION < 50
target_options_.LessPreciseFPMADOption = true;
#endif
@@ -525,6 +540,10 @@ std::string LLVMTargetInfo::str() const {
os << quote << Join(",", opts) << quote;
}
+ if (jit_engine_ != "mcjit") {
+ os << " -jit=" << jit_engine_;
+ }
+
return os.str();
}
diff --git a/src/target/llvm/llvm_instance.h b/src/target/llvm/llvm_instance.h
index 030a7db721..f3948b7a01 100644
--- a/src/target/llvm/llvm_instance.h
+++ b/src/target/llvm/llvm_instance.h
@@ -212,6 +212,11 @@ class LLVMTargetInfo {
* \return `llvm::FastMathFlags` for this target
*/
llvm::FastMathFlags GetFastMathFlags() const { return fast_math_flags_; }
+ /*!
+ * \brief Get the LLVM JIT engine type
+ * \return the type name of the JIT engine (default "mcjit" or "orcjit")
+ */
+ const std::string GetJITEngine() const { return jit_engine_; } // NOTE(review): the top-level const on this by-value return blocks moving the copy; returning const std::string& would avoid it.
/*!
* \brief Get the LLVM optimization level
* \return optimization level for this target
@@ -324,6 +329,7 @@ class LLVMTargetInfo {
llvm::Reloc::Model reloc_model_ = llvm::Reloc::PIC_;
llvm::CodeModel::Model code_model_ = llvm::CodeModel::Small;
std::shared_ptr<llvm::TargetMachine> target_machine_;
+ std::string jit_engine_ = "mcjit";
};
/*!
diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc
index 59cd6a76b0..c332314a3e 100644
--- a/src/target/llvm/llvm_module.cc
+++ b/src/target/llvm/llvm_module.cc
@@ -30,7 +30,10 @@
#include <llvm/ADT/StringRef.h>
#include <llvm/Bitcode/BitcodeWriter.h>
#include <llvm/ExecutionEngine/ExecutionEngine.h>
-#include <llvm/ExecutionEngine/MCJIT.h> // Force linking of MCJIT
+#include <llvm/ExecutionEngine/MCJIT.h>
+#include <llvm/ExecutionEngine/Orc/LLJIT.h>
+#include <llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h>
+#include <llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h>
#include <llvm/IR/DataLayout.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/Intrinsics.h>
@@ -113,8 +116,11 @@ class LLVMModuleNode final : public runtime::ModuleNode {
bool ImplementsFunction(const String& name, bool query_imports) final;
+ void SetJITEngine(const std::string& jit_engine) { jit_engine_ = jit_engine;
} // Stores the engine name GetFunction uses to choose InitMCJIT or InitORCJIT; the value is not validated here.
+
private:
- void LazyInitJIT();
+ void InitMCJIT();
+ void InitORCJIT();
bool IsCompatibleWithHost(const llvm::TargetMachine* tm) const;
void* GetGlobalAddr(const std::string& name, const LLVMTarget& llvm_target)
const;
void* GetFunctionAddr(const std::string& name, const LLVMTarget&
llvm_target) const;
@@ -123,8 +129,9 @@ class LLVMModuleNode final : public runtime::ModuleNode {
std::unique_ptr<LLVMInstance> llvm_instance_;
// JIT lock
std::mutex mutex_;
- // execution engine
- llvm::ExecutionEngine* ee_{nullptr};
+ // jit execution engines
+ llvm::ExecutionEngine* mcjit_ee_{nullptr};
+ std::unique_ptr<llvm::orc::LLJIT> orcjit_ee_{nullptr};
// The raw pointer to the module.
llvm::Module* module_{nullptr};
// The unique_ptr owning the module. This becomes empty once JIT has been
initialized
@@ -132,12 +139,21 @@ class LLVMModuleNode final : public runtime::ModuleNode {
std::unique_ptr<llvm::Module> module_owning_ptr_;
/* \brief names of the external functions declared in this module */
Array<String> function_names_;
+ std::string jit_engine_;
};
LLVMModuleNode::~LLVMModuleNode() { // Runs static destructors for whichever JIT engine was created, then releases the engines and the owned module.
- if (ee_ != nullptr) {
- ee_->runStaticConstructorsDestructors(true);
- delete ee_;
+ if (mcjit_ee_ != nullptr) {
+ mcjit_ee_->runStaticConstructorsDestructors(true); // true = run destructors (InitMCJIT passes false to run constructors)
+ delete mcjit_ee_;
+ }
+ if (orcjit_ee_ != nullptr) {
+ auto dtors = llvm::orc::getDestructors(*module_); // collected from the original module_; on the ORC path module_owning_ptr_ still owns it (only a clone was added to the JIT)
+ auto dtorRunner =
std::make_unique<llvm::orc::CtorDtorRunner>(orcjit_ee_->getMainJITDylib()); // NOTE(review): a stack CtorDtorRunner, as InitORCJIT uses, would avoid this heap allocation
+ dtorRunner->add(dtors);
+ auto err = dtorRunner->run();
+ ICHECK(!err) << llvm::toString(std::move(err)); // NOTE(review): destructors are implicitly noexcept; if ICHECK throws here (as TVM fatal logging does) the process terminates -- consider logging the error instead
+ orcjit_ee_.reset();
}
module_owning_ptr_.reset(); // released after the engines because module_ is still read above
}
@@ -166,7 +182,9 @@ PackedFunc LLVMModuleNode::GetFunction(const String& name,
const ObjectPtr<Objec
std::string target_string = LLVMTarget::GetTargetMetadata(*module_);
return PackedFunc([target_string](TVMArgs args, TVMRetValue* rv) { *rv =
target_string; });
}
- if (ee_ == nullptr) LazyInitJIT();
+ ICHECK(jit_engine_.size()) << "JIT engine type is missing";
+ if ((jit_engine_ == "mcjit") && (mcjit_ee_ == nullptr)) InitMCJIT();
+ if ((jit_engine_ == "orcjit") && (orcjit_ee_ == nullptr)) InitORCJIT();
std::lock_guard<std::mutex> lock(mutex_);
@@ -353,6 +371,7 @@ void LLVMModuleNode::Init(const IRModule& mod, const
Target& target) {
module_owning_ptr_ = cg->Finish();
module_ = module_owning_ptr_.get();
+ jit_engine_ = llvm_target->GetJITEngine();
llvm_target->SetTargetMetadata(module_);
module_->addModuleFlag(llvm::Module::Override, "Debug Info Version",
llvm::DEBUG_METADATA_VERSION);
@@ -384,13 +403,16 @@ bool LLVMModuleNode::ImplementsFunction(const String&
name, bool query_imports)
return std::find(function_names_.begin(), function_names_.end(), name) !=
function_names_.end();
}
-void LLVMModuleNode::LazyInitJIT() {
+void LLVMModuleNode::InitMCJIT() {
std::lock_guard<std::mutex> lock(mutex_);
- if (ee_) {
+ if (mcjit_ee_) {
return;
}
+ // MCJIT builder
With<LLVMTarget> llvm_target(*llvm_instance_,
LLVMTarget::GetTargetMetadata(*module_));
llvm::EngineBuilder builder(std::move(module_owning_ptr_));
+
+ // set options
builder.setEngineKind(llvm::EngineKind::JIT);
#if TVM_LLVM_VERSION <= 170
builder.setOptLevel(llvm::CodeGenOpt::Aggressive);
@@ -400,18 +422,31 @@ void LLVMModuleNode::LazyInitJIT() {
builder.setMCPU(llvm_target->GetCPU());
builder.setMAttrs(llvm_target->GetTargetFeatures());
builder.setTargetOptions(llvm_target->GetTargetOptions());
+
+ // create the taget machine
auto tm = std::unique_ptr<llvm::TargetMachine>(builder.selectTarget());
if (!IsCompatibleWithHost(tm.get())) {
LOG(FATAL) << "Cannot run module, architecture mismatch";
}
+
+ // data layout
llvm::DataLayout layout(tm->createDataLayout());
ICHECK(layout == module_->getDataLayout())
<< "Data layout mismatch between module("
<< module_->getDataLayout().getStringRepresentation() << ")"
<< " and ExecutionEngine (" << layout.getStringRepresentation() << ")";
- ee_ = builder.create(tm.release());
- ICHECK(ee_ != nullptr) << "Failed to initialize jit engine for " <<
module_->getTargetTriple();
- ee_->runStaticConstructorsDestructors(false);
+
+ // create MCJIT
+ mcjit_ee_ = builder.create(tm.release());
+ ICHECK(mcjit_ee_ != nullptr) << "Failed to initialize LLVM MCJIT engine for "
+ << module_->getTargetTriple();
+
+ VLOG(2) << "LLVM MCJIT execute " << module_->getModuleIdentifier() << " for
triple `"
+ << llvm_target->GetTargetTriple() << "`"
+ << " on cpu `" << llvm_target->GetCPU() << "`";
+
+ // run ctors
+ mcjit_ee_->runStaticConstructorsDestructors(false);
if (void** ctx_addr =
reinterpret_cast<void**>(GetGlobalAddr(runtime::symbol::tvm_module_ctx,
*llvm_target))) {
@@ -424,7 +459,104 @@ void LLVMModuleNode::LazyInitJIT() {
// lead to a runtime crash.
// Do name lookup on a symbol that doesn't exist. This will force MCJIT to
finalize
// all loaded objects, which will resolve symbols in JITed code.
-
ee_->getFunctionAddress("__some_name_that_hopefully_doesnt_exist__b49f8aaade5877eaba7583b91");
+ mcjit_ee_->getFunctionAddress(
+ "__some_name_that_hopefully_doesnt_exist__b49f8aaade5877eaba7583b91");
+}
+
+void LLVMModuleNode::InitORCJIT() { // Builds the LLJIT engine for module_ (no-op if already built), runs the module's static constructors and binds TVM context symbols.
+ std::lock_guard<std::mutex> lock(mutex_);
+ if (orcjit_ee_) {
+ return;
+ }
+ // ORCJIT builder: the target machine is configured from the target metadata embedded in module_
+ With<LLVMTarget> llvm_target(*llvm_instance_,
LLVMTarget::GetTargetMetadata(*module_));
+ llvm::orc::JITTargetMachineBuilder
tm_builder(llvm::Triple(llvm_target->GetTargetTriple()));
+
+ // set options
+ tm_builder.setCPU(llvm_target->GetCPU());
+ tm_builder.setFeatures(llvm_target->GetTargetFeatureString());
+ tm_builder.setOptions(llvm_target->GetTargetOptions());
+#if TVM_LLVM_VERSION <= 170
+ tm_builder.setCodeGenOptLevel(llvm::CodeGenOpt::Aggressive);
+#else
+ tm_builder.setCodeGenOptLevel(llvm::CodeGenOptLevel::Aggressive);
+#endif
+
+ // create the target machine
+ std::unique_ptr<llvm::TargetMachine> tm =
llvm::cantFail(tm_builder.createTargetMachine());
+ if (!IsCompatibleWithHost(tm.get())) {
+ LOG(FATAL) << "Cannot run module, architecture mismatch";
+ }
+
+ // data layout
+ String module_name = module_->getModuleIdentifier(); // NOTE(review): module_name is never used below; safe to drop.
+ llvm::DataLayout layout(tm->createDataLayout());
+ ICHECK(layout == module_->getDataLayout())
+ << "Data layout mismatch between module("
+ << module_->getDataLayout().getStringRepresentation() << ")"
+ << " and ExecutionEngine (" << layout.getStringRepresentation() << ")";
+
+ // compiler: moves tm (captured by reference) into the compile layer, so tm must not be used after LLJITBuilder::create(); assumes the creator is invoked only once
+ const auto compilerBuilder = [&](const llvm::orc::JITTargetMachineBuilder&)
+ ->
llvm::Expected<std::unique_ptr<llvm::orc::IRCompileLayer::IRCompiler>> {
+ return std::make_unique<llvm::orc::TMOwningSimpleCompiler>(std::move(tm));
+ };
+
+#if TVM_LLVM_VERSION >= 130
+ // linker: JITLink-based ObjectLinkingLayer for the session (only wired up for LLVM >= 13)
+ const auto linkerBuilder = [&](llvm::orc::ExecutionSession& session, const
llvm::Triple&) {
+ return std::make_unique<llvm::orc::ObjectLinkingLayer>(session);
+ };
+#endif
+
+ // create LLJIT (cantFail aborts on builder errors)
+ orcjit_ee_ = llvm::cantFail(llvm::orc::LLJITBuilder()
+#if TVM_LLVM_VERSION >= 110
+ .setDataLayout(layout)
+#endif
+ .setCompileFunctionCreator(compilerBuilder)
+#if TVM_LLVM_VERSION >= 130
+ .setObjectLinkingLayerCreator(linkerBuilder)
+#endif
+ .create());
+
+ ICHECK(orcjit_ee_ != nullptr) << "Failed to initialize LLVM ORCJIT engine
for "
+ << module_->getTargetTriple(); // NOTE(review): effectively unreachable, cantFail above already aborts on failure
+
+ // store ctors: collected from module_ now, executed only after the IR has been added (see "run ctors")
+ auto ctors = llvm::orc::getConstructors(*module_);
+ llvm::orc::CtorDtorRunner ctorRunner(orcjit_ee_->getMainJITDylib());
+ ctorRunner.add(ctors);
+
+ // resolve system symbols (like pthread, dl, m, etc.)
+ auto gen =
+
llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(layout.getGlobalPrefix());
+ ICHECK(gen) << llvm::toString(gen.takeError()) << "\n";
+ orcjit_ee_->getMainJITDylib().addGenerator(std::move(gen.get()));
+
+ // transfer module to a clone. NOTE(review): std::move inside *(...) transfers nothing -- module_owning_ptr_ keeps owning the original, which module_ and the destructor still read. CloneModule places the copy in the original module's LLVMContext, not in uctx, so the ThreadSafeModule below pairs the module with an unrelated (empty) context; a bitcode round-trip into uctx would keep them consistent -- confirm.
+ auto uctx = std::make_unique<llvm::LLVMContext>();
+ auto umod = llvm::CloneModule(*(std::move(module_owning_ptr_)));
+
+ // add the llvm module to run
+ llvm::orc::ThreadSafeModule tsm(std::move(umod), std::move(uctx));
+ auto err = orcjit_ee_->addIRModule(std::move(tsm));
+ ICHECK(!err) << llvm::toString(std::move(err));
+
+ VLOG(2) << "LLVM ORCJIT execute " << module_->getModuleIdentifier() << " for
triple `"
+ << llvm_target->GetTargetTriple() << "`"
+ << " on cpu `" << llvm_target->GetCPU() << "`";
+
+ // run ctors
+ err = ctorRunner.run();
+ ICHECK(!err) << llvm::toString(std::move(err));
+
+ if (void** ctx_addr =
+
reinterpret_cast<void**>(GetGlobalAddr(runtime::symbol::tvm_module_ctx,
*llvm_target))) {
+ *ctx_addr = this;
+ }
+ runtime::InitContextFunctions(
+ [this, &llvm_target](const char* name) { return GetGlobalAddr(name,
*llvm_target); });
}
bool LLVMModuleNode::IsCompatibleWithHost(const llvm::TargetMachine* tm) const
{
@@ -442,20 +574,40 @@ bool LLVMModuleNode::IsCompatibleWithHost(const
llvm::TargetMachine* tm) const {
void* LLVMModuleNode::GetGlobalAddr(const std::string& name, const LLVMTarget&
llvm_target) const { // Returns the JIT address of global variable `name`, or nullptr when module_ does not define it.
// first verifies if GV exists.
if (module_->getGlobalVariable(name) != nullptr) {
- return reinterpret_cast<void*>(ee_->getGlobalValueAddress(name));
- } else {
- return nullptr;
+ if (jit_engine_ == "mcjit") {
+ return reinterpret_cast<void*>(mcjit_ee_->getGlobalValueAddress(name)); // assumes InitMCJIT has already created mcjit_ee_
+ } else if (jit_engine_ == "orcjit") { // NOTE(review): same ORC lookup as GetFunctionAddr; a shared helper would remove the duplication
+#if TVM_LLVM_VERSION >= 150
+ auto addr = llvm::cantFail(orcjit_ee_->lookup(name)).getValue(); // cantFail aborts if the symbol cannot be resolved
+#else
+ auto addr = llvm::cantFail(orcjit_ee_->lookup(name)).getAddress();
+#endif
+ return reinterpret_cast<void*>(addr);
+ } else {
+ LOG(FATAL) << "Either `mcjit` or `orcjit` are not initialized."; // NOTE(review): reached when jit_engine_ is neither "mcjit" nor "orcjit", not when an engine is uninitialized
+ }
}
+ return nullptr;
}
void* LLVMModuleNode::GetFunctionAddr(const std::string& name,
const LLVMTarget& llvm_target) const { // Returns the JIT address of function `name`, or nullptr when module_ does not define it.
// first verifies if the function exists.
if (module_->getFunction(name) != nullptr) {
- return reinterpret_cast<void*>(ee_->getFunctionAddress(name));
- } else {
- return nullptr;
+ if (jit_engine_ == "mcjit") {
+ return reinterpret_cast<void*>(mcjit_ee_->getFunctionAddress(name)); // assumes InitMCJIT has already created mcjit_ee_
+ } else if (jit_engine_ == "orcjit") { // NOTE(review): same ORC lookup as GetGlobalAddr; a shared helper would remove the duplication
+#if TVM_LLVM_VERSION >= 150
+ auto addr = llvm::cantFail(orcjit_ee_->lookup(name)).getValue(); // cantFail aborts if the symbol cannot be resolved
+#else
+ auto addr = llvm::cantFail(orcjit_ee_->lookup(name)).getAddress();
+#endif
+ return reinterpret_cast<void*>(addr);
+ } else {
+ LOG(FATAL) << "Either `mcjit` or `orcjit` are not initialized."; // NOTE(review): reached when jit_engine_ is neither "mcjit" nor "orcjit", not when an engine is uninitialized
+ }
}
+ return nullptr;
}
TVM_REGISTER_GLOBAL("target.build.llvm")
@@ -476,6 +628,7 @@ TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate")
module->setTargetTriple(llvm_target->GetTargetTriple());
module->setDataLayout(llvm_target->GetOrCreateTargetMachine()->createDataLayout());
n->Init(std::move(module), std::move(llvm_instance));
+ n->SetJITEngine(llvm_target->GetJITEngine());
return runtime::Module(n);
});
@@ -595,6 +748,7 @@
TVM_REGISTER_GLOBAL("target.llvm_version_major").set_body_typed([]() -> int {
TVM_REGISTER_GLOBAL("runtime.module.loadfile_ll")
.set_body_typed([](std::string filename, std::string fmt) ->
runtime::Module {
auto n = make_object<LLVMModuleNode>();
+ n->SetJITEngine("mcjit");
n->LoadIR(filename);
return runtime::Module(n);
});
@@ -616,6 +770,7 @@ TVM_REGISTER_GLOBAL("codegen.codegen_blob")
std::unique_ptr<llvm::Module> blob =
CodeGenBlob(data, system_lib, llvm_target.get(), c_symbol_prefix);
n->Init(std::move(blob), std::move(llvm_instance));
+ n->SetJITEngine(llvm_target->GetJITEngine());
return runtime::Module(n);
});
@@ -645,6 +800,7 @@ runtime::Module
CreateLLVMCppMetadataModule(runtime::metadata::Metadata metadata
auto n = make_object<LLVMModuleNode>();
n->Init(std::move(mod), std::move(llvm_instance));
+ n->SetJITEngine(llvm_target->GetJITEngine());
auto meta_mod = MetadataModuleCreate(metadata);
meta_mod->Import(runtime::Module(n));
@@ -691,6 +847,7 @@ runtime::Module CreateLLVMCrtMetadataModule(const
Array<runtime::Module>& module
auto n = make_object<LLVMModuleNode>();
n->Init(std::move(mod), std::move(llvm_instance));
+ n->SetJITEngine(llvm_target->GetJITEngine());
for (auto m : modules) {
n->Import(m);
}
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index aa4499ec96..28c7e06629 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -291,6 +291,8 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU)
.add_attr_option<Integer>("opt-level")
// LLVM command line flags, see below
.add_attr_option<Array<String>>("cl-opt")
+ // LLVM JIT engine mcjit/orcjit
+ .add_attr_option<String>("jit")
.set_default_keys({"cpu"})
// Force the external codegen kind attribute to be registered, even if no
external
// codegen targets are enabled by the TVM build.
diff --git a/tests/python/runtime/test_runtime_module_based_interface.py
b/tests/python/runtime/test_runtime_module_based_interface.py
index 6e62e3f215..55edbdaccb 100644
--- a/tests/python/runtime/test_runtime_module_based_interface.py
+++ b/tests/python/runtime/test_runtime_module_based_interface.py
@@ -23,6 +23,7 @@ from tvm.contrib import graph_executor
from tvm.contrib.debugger import debug_executor
from tvm.contrib.cuda_graph import cuda_graph_executor
import tvm.testing
+import pytest
def input_shape(mod):
@@ -48,10 +49,11 @@ def verify(data):
@tvm.testing.requires_llvm
-def test_legacy_compatibility():
[email protected]("target", ["llvm", "llvm -jit=orcjit"])
+def test_legacy_compatibility(target):
mod, params = relay.testing.synthetic.get_workload()
with relay.build_config(opt_level=3):
- graph, lib, graph_params = relay.build_module.build(mod, "llvm",
params=params)
+ graph, lib, graph_params = relay.build_module.build(mod, target,
params=params)
data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32")
dev = tvm.cpu()
module = graph_executor.create(graph, lib, dev)
@@ -63,10 +65,11 @@ def test_legacy_compatibility():
@tvm.testing.requires_llvm
-def test_cpu():
[email protected]("target", ["llvm", "llvm -jit=orcjit"])
+def test_cpu(target):
mod, params = relay.testing.synthetic.get_workload()
with relay.build_config(opt_level=3):
- complied_graph_lib = relay.build_module.build(mod, "llvm",
params=params)
+ complied_graph_lib = relay.build_module.build(mod, target,
params=params)
data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32")
# raw api
dev = tvm.cpu()
@@ -105,10 +108,11 @@ def test_cpu_get_graph_json():
@tvm.testing.requires_llvm
-def test_cpu_get_graph_params_run():
[email protected]("target", ["llvm", "llvm -jit=orcjit"])
+def test_cpu_get_graph_params_run(target):
mod, params = relay.testing.synthetic.get_workload()
with tvm.transform.PassContext(opt_level=3):
- complied_graph_lib = relay.build_module.build(mod, "llvm",
params=params)
+ complied_graph_lib = relay.build_module.build(mod, target,
params=params)
data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32")
dev = tvm.cpu()
from tvm.contrib import utils
@@ -584,10 +588,11 @@ def test_remove_package_params():
@tvm.testing.requires_llvm
-def test_debug_graph_executor():
[email protected]("target", ["llvm", "llvm -jit=orcjit"])
+def test_debug_graph_executor(target):
mod, params = relay.testing.synthetic.get_workload()
with relay.build_config(opt_level=3):
- complied_graph_lib = relay.build_module.build(mod, "llvm",
params=params)
+ complied_graph_lib = relay.build_module.build(mod, target,
params=params)
data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32")
# raw api
diff --git a/tests/python/runtime/test_runtime_module_load.py
b/tests/python/runtime/test_runtime_module_load.py
index ecaa7067a5..3789a1d090 100644
--- a/tests/python/runtime/test_runtime_module_load.py
+++ b/tests/python/runtime/test_runtime_module_load.py
@@ -22,6 +22,7 @@ import numpy as np
import subprocess
import tvm.testing
from tvm.relay.backend import Runtime
+import pytest
runtime_py = """
import os
@@ -42,9 +43,9 @@ print("Finish runtime checking...")
"""
-def test_dso_module_load():
- if not tvm.testing.device_enabled("llvm"):
- return
[email protected]_llvm
[email protected]("target", ["llvm", "llvm -jit=orcjit"])
+def test_dso_module_load(target):
dtype = "int64"
temp = utils.tempdir()
@@ -63,7 +64,7 @@ def test_dso_module_load():
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "main")
)
- m = tvm.driver.build(mod, target="llvm")
+ m = tvm.driver.build(mod, target=target)
for name in names:
m.save(name)
@@ -167,6 +168,7 @@ def test_device_module_dump():
check_stackvm(device)
[email protected]_llvm
def test_combine_module_llvm():
"""Test combine multiple module into one shared lib."""
# graph
@@ -178,9 +180,6 @@ def test_combine_module_llvm():
def check_llvm():
dev = tvm.cpu(0)
- if not tvm.testing.device_enabled("llvm"):
- print("Skip because llvm is not enabled")
- return
temp = utils.tempdir()
fadd1 = tvm.build(s, [A, B], "llvm", name="myadd1")
fadd2 = tvm.build(s, [A, B], "llvm", name="myadd2")
diff --git a/tests/python/target/test_target_target.py
b/tests/python/target/test_target_target.py
index d5e8d06025..83bd864970 100644
--- a/tests/python/target/test_target_target.py
+++ b/tests/python/target/test_target_target.py
@@ -171,6 +171,13 @@ def test_target_llvm_options():
)
+def test_target_llvm_jit_options():  # the `jit` attr registered on the llvm target kind keeps both accepted engine names
+ target = tvm.target.Target("llvm -jit=mcjit")
+ assert target.attrs["jit"] == "mcjit"
+ target = tvm.target.Target("llvm -jit=orcjit")
+ assert target.attrs["jit"] == "orcjit"
+
+
def test_target_create():
targets = [cuda(), rocm(), mali(), intel_graphics(), arm_cpu("rk3399"),
vta(), bifrost()]
for tgt in targets: