This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 594f23d [Core][Build] Move build module transformations and utilities
to C++ (#9103)
594f23d is described below
commit 594f23d976f09d4aa300f84de0c9a7906b71eeee
Author: Michalis Papadimitriou <[email protected]>
AuthorDate: Thu Oct 14 11:09:02 2021 +0300
[Core][Build] Move build module transformations and utilities to C++ (#9103)
* Initial investigation
* More progress!
* More progress / notes
* rewrite build_for_device mostly in c++
* More progress
* Initial split of transformations applied to device and host as post split
action from mixed module
* Combine duplicate passes after splitting mod on aot and vm flows
* Minor cleanup
* Move target mangling to driver_api.cc
* Move more build utilities to cpp driver api
* [Build][WIP] Moving build utilities to C++ from Python
* [Build] Remove comments
* [lint] Pass black
* More formatting
* Move more build functionality into cpp
* Remove comments
* Remove unused defs and imports
* Address PR comments
* More PR comments
* More comments
* More comments
* Add comments on the new split function
* Fix PR comments on clarity
* Test CI
* Fix format
* Refactor build
* Expose split composite passes to Python
* Format files
* Test fix
* Fix for annotating entry funcs on code targeting CPU
* Prevent entry funcs from being annotated when compiling for CPU with C runtime
enabled
* Guard for aot executor entry
* Sphinx format
* Sanity fix
* Sphinx fix
Co-authored-by: electriclilies <[email protected]>
---
include/tvm/driver/driver_api.h | 30 ++++
python/tvm/driver/build_module.py | 125 ++---------------
python/tvm/relay/build_module.py | 6 +-
src/driver/driver_api.cc | 280 ++++++++++++++++++++++++++------------
src/relay/backend/vm/compiler.cc | 2 +-
5 files changed, 238 insertions(+), 205 deletions(-)
diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h
index 418d532..45a9382 100644
--- a/include/tvm/driver/driver_api.h
+++ b/include/tvm/driver/driver_api.h
@@ -30,6 +30,7 @@
#define TVM_DRIVER_DRIVER_API_H_
#include <tvm/ir/module.h>
+#include <tvm/ir/transform.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/support/with.h>
#include <tvm/target/target.h>
@@ -43,6 +44,34 @@
#include <vector>
namespace tvm {
+using tvm::transform::Pass;
+
+/*!
+ * \brief Configures and returns the composite Pass for the fused module (pre-split) that contains
split) that contains
+ * device and host code.
+ * \param mixed_mod The original mixed module.
+ * \param target The device Target; it is bound to every PrimFunc in the module.
+ * \return The composite Pass for the fused module.
+ */
+TVM_DLL transform::Sequential MixedModulePassManager(IRModule mixed_mod,
Target target);
+
+/*!
+ * \brief Configures and returns the composite Pass for the device Target
after the device/host split of the
+ * mixed module.
+ * \param mixed_mod The optimized mixed module.
+ * \param target The device Target.
+ * \return The composite Pass for the device module.
+ */
+TVM_DLL transform::Sequential DeviceModulePassManager(IRModule mixed_mod,
Target target);
+
+/*!
+ * \brief Configures and returns the composite Pass for the host Target after
the device/host split of the mixed
+ * module.
+ * \param mixed_mod The optimized mixed module.
+ * \param target_host The host Target.
+ * \return The composite Pass for the host module.
+ */
+TVM_DLL transform::Sequential HostModulePassManager(IRModule mixed_mod, Target
target_host);
/*!
* \brief Lower an IRModule (optimize with it with the pass list defined in
CreatePassList)
@@ -136,6 +165,7 @@ TVM_DLL runtime::Module build(const Map<Target, IRModule>&
input, const Target&
* \return The built module that contains code for different processors.
*/
TVM_DLL runtime::Module build(const Map<String, IRModule>& input, const
Target& target_host);
+
} // namespace tvm
#endif // TVM_DRIVER_DRIVER_API_H_
diff --git a/python/tvm/driver/build_module.py
b/python/tvm/driver/build_module.py
index a7ebc00..429b3e1 100644
--- a/python/tvm/driver/build_module.py
+++ b/python/tvm/driver/build_module.py
@@ -16,27 +16,23 @@
# under the License.
# pylint: disable=invalid-name
-"""The build utils in python.
-"""
+"""The build utils in python."""
from typing import Union, Optional, List, Mapping
-import warnings
import tvm.tir
from tvm.runtime import Module
from tvm.runtime import ndarray
from tvm.ir import container
-from tvm.ir import CallingConv
from tvm.tir import PrimFunc
from tvm.ir.module import IRModule
-from tvm.ir.transform import PassContext
-from tvm.target import codegen
from tvm.te import tensor
from tvm.te import schedule
from tvm.target import Target
from tvm.tir.buffer import Buffer
from tvm.tir.expr import Var
+from tvm.driver import _ffi_api as _driver_ffi
from . import _ffi_api as ffi
@@ -104,8 +100,8 @@ def lower(
args : Optional[List[Union[tvm.tir.Buffer, tensor.Tensor, Var]]]
The argument lists to the function for TE schedule.
- It should be None if we want to lower TensorIR.
+ It should be None if we want to lower TensorIR.
name : str
The name of the result function.
@@ -132,98 +128,6 @@ def lower(
raise ValueError("Expected input to be an IRModule, PrimFunc or Schedule,
but got, ", type(inp))
-def _build_for_device(input_mod, target, target_host):
- """Build the lowered functions for a device with the given compilation
- target.
-
- Parameters
- ----------
- input_mod : IRModule
- The schedule to be built.
-
- target : str or :any:`tvm.target.Target`
- The target and option of the compilation.
-
- target_host : str or :any:`tvm.target.Target`
- The host compilation target.
-
- Returns
- -------
- fhost : IRModule
- The host IRModule.
-
- mdev : tvm.module
- A module that contains device code.
- """
- target, target_host = Target.check_and_update_host_consist(target,
target_host)
- device_type = ndarray.device(target.kind.name, 0).device_type
-
- mod_mixed = input_mod
- mod_mixed = tvm.tir.transform.Apply(lambda f: f.with_attr("target",
target))(mod_mixed)
-
- opt_mixed = [
- tvm.tir.transform.VerifyMemory(),
- tvm.tir.transform.MergeDynamicSharedMemoryAllocations(),
- ]
- if len(mod_mixed.functions) == 1:
- opt_mixed += [tvm.tir.transform.Apply(lambda f:
f.with_attr("tir.is_entry_func", True))]
-
- if PassContext.current().config.get("tir.detect_global_barrier", False):
- opt_mixed += [tvm.tir.transform.ThreadSync("global")]
- opt_mixed += [
- tvm.tir.transform.ThreadSync("shared"),
- tvm.tir.transform.ThreadSync("warp"),
- tvm.tir.transform.InferFragment(),
- tvm.tir.transform.LowerThreadAllreduce(),
- tvm.tir.transform.MakePackedAPI(),
- tvm.tir.transform.SplitHostDevice(),
- ]
- mod_mixed = tvm.transform.Sequential(opt_mixed)(mod_mixed)
-
- # device optimizations
- opt_device = tvm.transform.Sequential(
- [
- tvm.tir.transform.Filter(
- lambda f: "calling_conv" in f.attrs
- and f.attrs["calling_conv"].value ==
CallingConv.DEVICE_KERNEL_LAUNCH
- ),
- tvm.tir.transform.LowerWarpMemory(),
- tvm.tir.transform.Simplify(),
- tvm.tir.transform.LowerDeviceStorageAccessInfo(),
- tvm.tir.transform.LowerCustomDatatypes(),
- tvm.tir.transform.LowerIntrin(),
- ]
- )
- mod_dev = opt_device(mod_mixed)
-
- # host optimizations
- opt_host = tvm.transform.Sequential(
- [
- tvm.tir.transform.Filter(
- lambda f: "calling_conv" not in f.attrs
- or f.attrs["calling_conv"].value !=
CallingConv.DEVICE_KERNEL_LAUNCH
- ),
- tvm.tir.transform.Apply(lambda f: f.with_attr("target",
target_host)),
- tvm.tir.transform.LowerTVMBuiltin(),
- tvm.tir.transform.LowerDeviceStorageAccessInfo(),
- tvm.tir.transform.LowerCustomDatatypes(),
- tvm.tir.transform.LowerIntrin(),
- tvm.tir.transform.CombineContextCall(),
- ]
- )
- mod_host = opt_host(mod_mixed)
-
- if device_type == ndarray.cpu(0).device_type and target_host == target:
- assert len(mod_dev.functions) == 0
- if "gpu" in target.keys and len(mod_dev.functions) == 0:
- warnings.warn(
- "Specified target %s, but cannot find device code, did you do "
"bind?" % target
- )
-
- rt_mod_dev = codegen.build_module(mod_dev, target) if
len(mod_dev.functions) != 0 else None
- return mod_host, rt_mod_dev
-
-
def build(
inputs: Union[schedule.Schedule, PrimFunc, IRModule, Mapping[str,
IRModule]],
args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None,
@@ -237,7 +141,8 @@ def build(
Parameters
----------
- inputs : Union[tvm.te.schedule.Schedule, tvm.tir.PrimFunc, IRModule,
Mapping[str, IRModule]]
+ inputs : Union[tvm.te.schedule.Schedule,
+ tvm.tir.PrimFunc, IRModule, Mapping[str, IRModule]]
The input to be built
args : Optional[List[Union[tvm.tir.Buffer, tensor.Tensor, Var]]]
@@ -253,7 +158,7 @@ def build(
setup the dimensions and parameters correctly.
target_host is used to specify the host side codegen target.
By default, llvm is used if it is enabled,
- otherwise a stackvm intepreter is used.
+ otherwise a stackvm interpreter is used.
name : Optional[str]
The name of result function.
@@ -350,21 +255,11 @@ def build(
target_input_mod, target_host
)
- mod_host_all = tvm.IRModule({})
-
- device_modules = []
- for tar, input_mod in target_input_mod.items():
- mod_host, mdev = _build_for_device(input_mod, tar, target_host)
- mod_host_all.update(mod_host)
- device_modules.append(mdev)
-
- # Generate a unified host module.
- rt_mod_host = codegen.build_module(mod_host_all, target_host)
+ rt_mod_host = _driver_ffi.finalize_module(target_input_mod, target_host)
- # Import all modules.
- for mdev in device_modules:
- if mdev:
- rt_mod_host.import_module(mdev)
+ target_input_mod, target_host = Target.check_and_update_host_consist(
+ target_input_mod, target_host
+ )
if not isinstance(target_host, Target):
target_host = Target(target_host)
diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py
index c67ac1d..f1686d2 100644
--- a/python/tvm/relay/build_module.py
+++ b/python/tvm/relay/build_module.py
@@ -123,7 +123,7 @@ class BuildModule(object):
to setup the dimensions and parameters correctly.
target_host is used to specify the host side codegen target.
By default, llvm is used if it is enabled,
- otherwise a stackvm intepreter is used.
+ otherwise a stackvm interpreter is used.
params : dict of str to NDArray
Input parameters to the graph that do not change
@@ -303,7 +303,7 @@ def build(ir_mod, target=None, target_host=None,
params=None, mod_name="default"
setup the dimensions and parameters correctly.
target_host is used to specify the host side codegen target.
By default, llvm is used if it is enabled,
- otherwise a stackvm intepreter is used.
+ otherwise a stackvm interpreter is used.
params : dict of str to NDArray
Input parameters to the graph that do not change
@@ -452,7 +452,7 @@ def bind_params_by_name(func, params):
class GraphExecutor(_interpreter.Executor):
"""Wrapper around Executor interface.
- This executor is used for debug and testing purpoes.
+ This executor is used for debug and testing purposes.
Parameters
----------
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index 2c6fbc2..e659421 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -42,17 +42,26 @@
TVM_REGISTER_PASS_CONFIG_OPTION("tir.detect_global_barrier", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;
+using tvm::Array;
+using tvm::transform::Pass;
bool LLVMEnabled() {
const runtime::PackedFunc* pf = runtime::Registry::Get("target.build.llvm");
return pf != nullptr;
}
+bool ShouldAnnotateEntryFunc(const Target target, const IRModule mod) {
+ const bool aot_executor = (target->GetAttr<String>("executor").value_or("")
== "aot");
+ const bool single_entry_func = (mod->functions.size() == 1);
+ return single_entry_func && !aot_executor;
+}
+
/*! \return The default host target for a given device target */
Target DefaultTargetHost(Target target) {
if (target.defined() && target->kind->device_type == kDLCPU) {
@@ -155,6 +164,13 @@ transform::Pass BindTarget(Target target) {
return tir::transform::CreatePrimFuncPass(fpass, 0, "BindTarget", {});
}
+static transform::Pass AnnotateEntryFunc(bool b) {  // sets tir::attr::kIsEntryFunc = b on every PrimFunc
+ auto fpass = [b](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
+ return WithAttr(std::move(f), tir::attr::kIsEntryFunc, Bool(b));
+ };
+ return tir::transform::CreatePrimFuncPass(fpass, 0, "AnnotateEntryFunc", {});
+}
+
template <typename FCond>
transform::Pass Filter(FCond fcond) {
auto fpass = [fcond](tir::PrimFunc f, IRModule m, transform::PassContext
ctx) {
@@ -184,7 +200,7 @@ Array<tvm::transform::Pass> CreatePassList(bool
disable_loop_partition) {
Array<transform::Pass> user_lower_phase2 = Array<transform::Pass>();
Array<transform::Pass> user_lower_phase3 = Array<transform::Pass>();
- // phase pasees is of the form
+ // phase passes is of the form
// [[phase_number, pass], [phase_number, pass]... ]
for (Array<ObjectRef> phase_pass : add_lower_pass) {
const IntImmNode* phase_num = phase_pass[0].as<IntImmNode>();
@@ -266,6 +282,11 @@ IRModule LowerWithPassList(IRModule mod,
Array<tvm::transform::Pass> pass_list)
return mod;
}
+IRModule ApplyPasses(IRModule mod, transform::Sequential seq) {
+ mod = seq(std::move(mod));
+ return mod;
+}
+
IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args,
const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>&
binds) {
// Convert te schedule to IRModule
@@ -373,97 +394,96 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule")
return LowerSchedule(std::move(sch), args, name, c_binds, simple_mode);
});
-std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule mod_mixed, const
Target& target_arg,
- const Target& target_host_arg,
- const transform::PassContext&
pass_ctx) {
+/**
+ * This function takes the input module that contains both the device and host
code.
+ * It first applies the mixed-module passes to the original module, then splits
it into separate modules for
+ * device and host, applying the host- and device-specific passes to each of the
resulting modules.
+ */
+std::pair<IRModule, IRModule> SplitMixedModule(IRModule mod_mixed, const
Target& target_arg,
+ const Target& target_host_arg) {
Target target = target_arg, target_host = target_host_arg;
CheckAndUpdateHostConsistency(&target, &target_host);
- Array<tvm::transform::Pass> mixed_pass_list = {BindTarget(target),
-
tir::transform::VerifyMemory()};
-
mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations());
- if (pass_ctx->GetConfig<Bool>("tir.detect_global_barrier",
Bool(false)).value()) {
- mixed_pass_list.push_back(tir::transform::ThreadSync("global"));
- }
- mixed_pass_list.push_back(tir::transform::ThreadSync("shared"));
- mixed_pass_list.push_back(tir::transform::ThreadSync("warp"));
- mixed_pass_list.push_back(tir::transform::InferFragment());
- mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());
+ ICHECK(mod_mixed.defined()) << "This module must be defined";
- if (target->GetAttr<Bool>("unpacked-api").value_or(Bool(false))) {
- mixed_pass_list.push_back(tir::transform::MakeUnpackedAPI());
- } else {
- mixed_pass_list.push_back(tir::transform::MakePackedAPI(-1));
- }
+ mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed,
target));
- mixed_pass_list.push_back(tir::transform::SplitHostDevice());
+ IRModule host_mod = ApplyPasses(mod_mixed, HostModulePassManager(mod_mixed,
target_host));
- auto opt_mixed = transform::Sequential(mixed_pass_list);
- mod_mixed = opt_mixed(std::move(mod_mixed));
-
- // We make an assumption here that the overriden host target
- // can be used alongside the default host codegen based on device type
- // this is so the correct code generator is used later instead of overriding
the target.
- // We need better support for inserting multiple kDLCPU targets as our
current options
- // are kDeviceKernelLaunch or not
- Target overriden_host_target = target_host;
- if (target->kind->device_type == target_host->kind->device_type) {
- overriden_host_target = target;
- }
- auto host_pass_list = {
- Filter([](const tir::PrimFunc& f) {
- return f->GetAttr<Integer>(tvm::attr::kCallingConv,
Integer(CallingConv::kDefault)) !=
- CallingConv::kDeviceKernelLaunch;
- }),
- BindTarget(overriden_host_target),
- tir::transform::LowerTVMBuiltin(),
- tir::transform::LowerCustomDatatypes(),
- tir::transform::LowerIntrin(),
- tir::transform::LowerDeviceStorageAccessInfo(),
- tir::transform::CombineContextCall(),
- };
- auto opt_host = transform::Sequential(host_pass_list);
- ICHECK(mod_mixed.defined()) << "This module must be defined";
- auto mhost = opt_host(mod_mixed);
-
- // device pipeline
- auto device_pass_list = {
- Filter([](const tir::PrimFunc& f) {
- return f->GetAttr<Integer>(tvm::attr::kCallingConv,
Integer(CallingConv::kDefault)) ==
- CallingConv::kDeviceKernelLaunch;
- }),
- BindTarget(target),
- tir::transform::LowerWarpMemory(),
- tir::transform::Simplify(),
- tir::transform::LowerCustomDatatypes(),
- tir::transform::LowerIntrin(),
- tir::transform::LowerDeviceStorageAccessInfo(),
- };
- auto opt_device = transform::Sequential(device_pass_list);
- auto mdevice = opt_device(mod_mixed);
+ IRModule device_mod = ApplyPasses(mod_mixed,
DeviceModulePassManager(mod_mixed, target));
- // some final misc checks.
auto keys = target->GetKeys();
+
+ CheckAndUpdateHostConsistency(&target, &target_host);
+
bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") !=
keys.end();
- if (target_is_gpu && mdevice->functions.size() == 0) {
- LOG(WARNING) << "Specified target " << target->str()
- << " but cannot find device code. Did you forget to bind?";
+ if (target_is_gpu && device_mod->functions.size() == 0) {
+ LOG(WARNING) << "Specified target " << target->str()
+ << " but cannot find device code. Did you forget to bind?";
+ }
+
+ return {host_mod, device_mod};
+}
+
+runtime::Module FinalizeModule(const Map<Target, IRModule>& inputs_arg, const
Target& host_target) {
+ std::vector<runtime::Module> device_modules;
+ Map<Target, IRModule> inputs = inputs_arg;
+ Target target_host = host_target;
+
+ CheckAndUpdateHostConsistency(&inputs, &target_host);
+
+ if (!target_host.defined()) {
+ for (const auto& it : inputs) {
+ if (it.first->kind->device_type == kDLCPU || it.first->kind->device_type
== kDLMicroDev) {
+ target_host = it.first;
+ break;
+ }
+ }
+ }
+
+ if (!target_host.defined()) {
+ target_host = DefaultTargetHost(target_host);
}
- if (target->kind->device_type == kDLCPU && target_host == target) {
- // TODO(@jroesch): This check is no longer true we need to figure out if
we care about this.
- // We need to relax this check for just TIR functions.
- // ICHECK(mdevice->functions.empty()) << "No device code should be
generated when target "
- // << "and host_target are both llvm
target."
- // << "\n";
+ // Update target host for all targets
+ CheckAndUpdateHostConsistency(&inputs, &target_host);
+
+ IRModule mhost_all = IRModule(Map<GlobalVar, BaseFunc>());
+
+ ICHECK(mhost_all.defined()) << "The host module must be defined";
+
+ for (const auto& it : inputs) {
+ if (it.second.defined()) {
+ auto pair = SplitMixedModule(it.second, it.first, target_host);
+ auto& host_mod = pair.first;
+ auto& device_mod = pair.second;
+
+ ICHECK(host_mod.defined()) << "The split host module must be defined";
+
+ ICHECK(mhost_all.defined()) << "The host module must be defined";
+
+ mhost_all->Update(host_mod);
+
+ if (device_mod->functions.size() != 0) {
+ device_modules.push_back(codegen::Build(device_mod, it.first));
+ }
+ }
}
- return {mhost, mdevice};
+ runtime::Module complete_mod = codegen::Build(mhost_all, target_host);
+ for (const auto& it : device_modules) {
+ if (it.operator->()) {
+ complete_mod.Import(it);
+ }
+ }
+ return complete_mod;
}
-// Can we make this take one annotated IRModule?
-//
-// Build for heterogeneous execution.
+TVM_REGISTER_GLOBAL("driver.finalize_module")
+ .set_body_typed([](const Map<Target, IRModule>& inputs_arg, Target
host_target) {
+ return FinalizeModule(inputs_arg, host_target);
+ });
+
runtime::Module build(const Map<Target, IRModule>& inputs_arg, const Target&
target_host_arg) {
auto pass_ctx = transform::PassContext::Current();
@@ -498,11 +518,11 @@ runtime::Module build(const Map<Target, IRModule>&
inputs_arg, const Target& tar
if (it.second.defined()) {
const Target& target = it.first;
const IRModule& ir_module = it.second;
- auto pair = SplitDevHostFuncs(ir_module, target, target_host, pass_ctx);
- auto& mhost = pair.first;
- auto& mdevice = pair.second;
+ auto pair = SplitMixedModule(ir_module, target, target_host);
+ auto& host_mod = pair.first;
+ auto& device_mod = pair.second;
- ICHECK(mhost.defined()) << "The split host module must be defined";
+ ICHECK(host_mod.defined()) << "The split host module must be defined";
ICHECK(mhost_all.defined()) << "The host module must be defined";
@@ -513,19 +533,18 @@ runtime::Module build(const Map<Target, IRModule>&
inputs_arg, const Target& tar
bool overrides_host_target = target->kind->device_type ==
target_host->kind->device_type;
bool non_host_target_kind = target->kind != target_host->kind;
if (overrides_host_target && non_host_target_kind) {
- device_modules.push_back(codegen::Build(mhost, it.first));
+ device_modules.push_back(codegen::Build(host_mod, it.first));
} else {
- mhost_all->Update(mhost);
+ mhost_all->Update(host_mod);
}
- if (mdevice->functions.size() != 0) {
- device_modules.push_back(codegen::Build(mdevice, it.first));
+ if (device_mod->functions.size() != 0) {
+ device_modules.push_back(codegen::Build(device_mod, it.first));
}
}
}
runtime::Module mhost = codegen::Build(mhost_all, target_host);
- // Import all modules
for (const auto& it : device_modules) {
if (it.operator->()) {
mhost.Import(it);
@@ -556,8 +575,97 @@ runtime::Module build(const IRModule& funcs, const Target&
target_arg,
const Target& target_host_arg) {
auto target = target_arg, target_host = target_host_arg;
CheckAndUpdateHostConsistency(&target, &target_host);
+ // More maps of target and target host
Map<Target, IRModule> inputs = {{target, funcs}};
return build(inputs, target_host);
}
+transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target
target) {
+ transform::PassContext pass_ctx = transform::PassContext::Current();
+
+ Array<Pass> mixed_pass_list;
+
+ mixed_pass_list.push_back(BindTarget(target));
+
+ mixed_pass_list.push_back(tir::transform::VerifyMemory());
+
mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations());
+
+ if (ShouldAnnotateEntryFunc(target, mixed_mod)) {
+ mixed_pass_list.push_back(AnnotateEntryFunc(true));
+ }
+
+ bool detect_global_barrier =
+ pass_ctx->GetConfig<Bool>("tir.detect_global_barrier",
Bool(false)).value();
+ if (detect_global_barrier) {
+ mixed_pass_list.push_back(tir::transform::ThreadSync("global"));
+ }
+
+ mixed_pass_list.push_back(tir::transform::ThreadSync("shared"));
+ mixed_pass_list.push_back(tir::transform::ThreadSync("warp"));
+ mixed_pass_list.push_back(tir::transform::InferFragment());
+ mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());
+
+ if (target->GetAttr<Bool>("unpacked-api").value_or(Bool(false))) {
+ mixed_pass_list.push_back(tir::transform::MakeUnpackedAPI());
+ } else {
+ mixed_pass_list.push_back(tir::transform::MakePackedAPI(-1));
+ }
+ mixed_pass_list.push_back(tir::transform::SplitHostDevice());
+
+ return transform::Sequential(mixed_pass_list);
+}
+
+TVM_REGISTER_GLOBAL("driver.mixed_mod_passes")
+ .set_body_typed([](IRModule mixed_mod, Target target) {
+ return MixedModulePassManager(mixed_mod, target);
+ });
+
+transform::Sequential HostModulePassManager(IRModule mixed_mod, Target
target_host) {
+ Array<tvm::transform::Pass> host_pass_list;
+ host_pass_list.push_back(Filter([](const tir::PrimFunc& f) {
+ return f->GetAttr<Integer>(tvm::attr::kCallingConv,
Integer(CallingConv::kDefault)) !=
+ CallingConv::kDeviceKernelLaunch;
+ }));
+
+ ICHECK(mixed_mod.defined()) << "This module must be defined";
+
+ host_pass_list.push_back(BindTarget(target_host));
+
+ host_pass_list.push_back(tir::transform::LowerTVMBuiltin());
+ host_pass_list.push_back(tir::transform::LowerCustomDatatypes());
+ host_pass_list.push_back(tir::transform::LowerIntrin());
+ host_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo());
+ host_pass_list.push_back(tir::transform::CombineContextCall());
+
+ return transform::Sequential(host_pass_list);
+}
+
+TVM_REGISTER_GLOBAL("driver.host_mod_passes")
+ .set_body_typed([](IRModule mixed_mod, Target target_host) {
+ return HostModulePassManager(mixed_mod, target_host);
+ });
+
+transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target
target) {
+ Array<Pass> device_pass_list;
+ device_pass_list.push_back(Filter([](const tir::PrimFunc& f) {
+ return f->GetAttr<Integer>(tvm::attr::kCallingConv,
Integer(CallingConv::kDefault)) ==
+ CallingConv::kDeviceKernelLaunch;
+ }));
+
+ device_pass_list.push_back(BindTarget(target));
+
+ device_pass_list.push_back(tir::transform::LowerWarpMemory());
+ device_pass_list.push_back(tir::transform::Simplify());
+ device_pass_list.push_back(tir::transform::LowerCustomDatatypes());
+ device_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo());
+ device_pass_list.push_back(tir::transform::LowerIntrin());
+
+ return transform::Sequential(device_pass_list);
+}
+
+TVM_REGISTER_GLOBAL("driver.device_mod_passes")
+ .set_body_typed([](IRModule mixed_mod, Target target) {
+ return DeviceModulePassManager(mixed_mod, target);
+ });
+
} // namespace tvm
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index 36cd0c7..70ad2cc 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -824,7 +824,7 @@ class VMFunctionCompiler :
DeviceAwareExprFunctor<void(const Expr& n)> {
/*!
* \brief Compile a pattern match expression
- * It first converts the pattern match expression into a desicision tree,
the condition
+ * It first converts the pattern match expression into a decision tree, the
condition
* could be object comparison or variable binding. If any of the condition
fails in a clause,
* the decision tree switches to check the conditions of next clause and so
on. If no clause
* matches the value, a fatal node is inserted.