This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 594f23d  [Core][Build] Move build module transformations and utilities to C++ (#9103)
594f23d is described below

commit 594f23d976f09d4aa300f84de0c9a7906b71eeee
Author: Michalis Papadimitriou <mikepapa...@users.noreply.github.com>
AuthorDate: Thu Oct 14 11:09:02 2021 +0300

    [Core][Build] Move build module transformations and utilities to C++ (#9103)
    
    * Initial investigation
    
    * More progress!
    
    * More progress / notes
    
    * rewrite build_for_device mostly in c++
    
    * More progress
    
    * Initial split of transformations applied to device and host as a post-split action on the mixed module
    
    * Combine duplicate passes after splitting mod on AOT and VM flows
    
    * Minor cleanup
    
    * Move target mangling to driver_api.cc
    
    * Move more build utilities to cpp driver api
    
    * [Build][WIP] Moving build utilities to C++ from Python
    
    * [Build] Remove comments
    
    * [lint] Pass black
    
    * More formatting
    
    * Move more build functionality into cpp
    
    * Remove comments
    
    * Remove unused defs and imports
    
    * Address PR comments
    
    * More PR comments
    
    * More comments
    
    * More comments
    
    * Add comments on the new split function
    
    * Fix PR comments on clarity
    
    * Test CI
    
    * Fix format
    
    * Refactor build
    
    * Expose split composite passes to Python
    
    * Format files
    
    * Test fix
    
    * Fix for annotating entry funcs on code targeting CPU
    
    * Prevent entry funcs from being annotated when compiling for CPU with C runtime enabled
    
    * Guard for aot executor entry
    
    * Sphinx format
    
    * Sanity fix
    
    * Sphinx fix
    
    Co-authored-by: electriclilies <lilyorthsm...@gmail.com>
---
 include/tvm/driver/driver_api.h   |  30 ++++
 python/tvm/driver/build_module.py | 125 ++---------------
 python/tvm/relay/build_module.py  |   6 +-
 src/driver/driver_api.cc          | 280 ++++++++++++++++++++++++++------------
 src/relay/backend/vm/compiler.cc  |   2 +-
 5 files changed, 238 insertions(+), 205 deletions(-)

diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h
index 418d532..45a9382 100644
--- a/include/tvm/driver/driver_api.h
+++ b/include/tvm/driver/driver_api.h
@@ -30,6 +30,7 @@
 #define TVM_DRIVER_DRIVER_API_H_
 
 #include <tvm/ir/module.h>
+#include <tvm/ir/transform.h>
 #include <tvm/runtime/packed_func.h>
 #include <tvm/support/with.h>
 #include <tvm/target/target.h>
@@ -43,6 +44,34 @@
 #include <vector>
 
 namespace tvm {
+using tvm::transform::Pass;
+
+/*!
+ * \brief Configures and returns the composite Pass for the fused module (pre 
split) that contains
+ * device and host code.
+ * \param mixed_mod The original mixed module.
+ * \param target The device Target.
+ * \return The composite Pass for the fused module.
+//  */
+TVM_DLL transform::Sequential MixedModulePassManager(IRModule mixed_mod, 
Target target);
+
+/*!
+ * \brief Configures and returns the composite Pass for the device Target 
after device/host from
+ * mixed module.
+ * \param mixed_mod The optimized mixed module.
+ * \param target The device Target.
+ * \return The composite Pass for the device module.
+ */
+TVM_DLL transform::Sequential DeviceModulePassManager(IRModule mixed_mod, 
Target target);
+
+/*!
+ * \brief Configures and returns the composite Pass for the host Target after 
device/host from mixed
+ * module.
+ * \param mixed_mod The optimized mixed module.
+ * \param target_host The host Target.
+ * \return The composite Pass for the host module.
+ */
+TVM_DLL transform::Sequential HostModulePassManager(IRModule mixed_mod, Target 
target_host);
 
 /*!
  * \brief Lower an IRModule (optimize with it with the pass list defined in 
CreatePassList)
@@ -136,6 +165,7 @@ TVM_DLL runtime::Module build(const Map<Target, IRModule>& 
input, const Target&
  * \return The built module that contains code for different processors.
  */
 TVM_DLL runtime::Module build(const Map<String, IRModule>& input, const 
Target& target_host);
+
 }  // namespace tvm
 
 #endif  // TVM_DRIVER_DRIVER_API_H_
diff --git a/python/tvm/driver/build_module.py 
b/python/tvm/driver/build_module.py
index a7ebc00..429b3e1 100644
--- a/python/tvm/driver/build_module.py
+++ b/python/tvm/driver/build_module.py
@@ -16,27 +16,23 @@
 # under the License.
 
 # pylint: disable=invalid-name
-"""The build utils in python.
-"""
+"""The build utils in python."""
 
 from typing import Union, Optional, List, Mapping
-import warnings
 
 import tvm.tir
 
 from tvm.runtime import Module
 from tvm.runtime import ndarray
 from tvm.ir import container
-from tvm.ir import CallingConv
 from tvm.tir import PrimFunc
 from tvm.ir.module import IRModule
-from tvm.ir.transform import PassContext
-from tvm.target import codegen
 from tvm.te import tensor
 from tvm.te import schedule
 from tvm.target import Target
 from tvm.tir.buffer import Buffer
 from tvm.tir.expr import Var
+from tvm.driver import _ffi_api as _driver_ffi
 
 from . import _ffi_api as ffi
 
@@ -104,8 +100,8 @@ def lower(
 
     args : Optional[List[Union[tvm.tir.Buffer, tensor.Tensor, Var]]]
         The argument lists to the function for TE schedule.
-        It should be None if we want to lower TensorIR.
 
+        It should be None if we want to lower TensorIR.
     name : str
         The name of the result function.
 
@@ -132,98 +128,6 @@ def lower(
     raise ValueError("Expected input to be an IRModule, PrimFunc or Schedule, 
but got, ", type(inp))
 
 
-def _build_for_device(input_mod, target, target_host):
-    """Build the lowered functions for a device with the given compilation
-    target.
-
-    Parameters
-    ----------
-    input_mod : IRModule
-        The schedule to be built.
-
-    target : str or :any:`tvm.target.Target`
-        The target and option of the compilation.
-
-    target_host : str or :any:`tvm.target.Target`
-        The host compilation target.
-
-    Returns
-    -------
-    fhost : IRModule
-        The host IRModule.
-
-    mdev : tvm.module
-        A module that contains device code.
-    """
-    target, target_host = Target.check_and_update_host_consist(target, 
target_host)
-    device_type = ndarray.device(target.kind.name, 0).device_type
-
-    mod_mixed = input_mod
-    mod_mixed = tvm.tir.transform.Apply(lambda f: f.with_attr("target", 
target))(mod_mixed)
-
-    opt_mixed = [
-        tvm.tir.transform.VerifyMemory(),
-        tvm.tir.transform.MergeDynamicSharedMemoryAllocations(),
-    ]
-    if len(mod_mixed.functions) == 1:
-        opt_mixed += [tvm.tir.transform.Apply(lambda f: 
f.with_attr("tir.is_entry_func", True))]
-
-    if PassContext.current().config.get("tir.detect_global_barrier", False):
-        opt_mixed += [tvm.tir.transform.ThreadSync("global")]
-    opt_mixed += [
-        tvm.tir.transform.ThreadSync("shared"),
-        tvm.tir.transform.ThreadSync("warp"),
-        tvm.tir.transform.InferFragment(),
-        tvm.tir.transform.LowerThreadAllreduce(),
-        tvm.tir.transform.MakePackedAPI(),
-        tvm.tir.transform.SplitHostDevice(),
-    ]
-    mod_mixed = tvm.transform.Sequential(opt_mixed)(mod_mixed)
-
-    # device optimizations
-    opt_device = tvm.transform.Sequential(
-        [
-            tvm.tir.transform.Filter(
-                lambda f: "calling_conv" in f.attrs
-                and f.attrs["calling_conv"].value == 
CallingConv.DEVICE_KERNEL_LAUNCH
-            ),
-            tvm.tir.transform.LowerWarpMemory(),
-            tvm.tir.transform.Simplify(),
-            tvm.tir.transform.LowerDeviceStorageAccessInfo(),
-            tvm.tir.transform.LowerCustomDatatypes(),
-            tvm.tir.transform.LowerIntrin(),
-        ]
-    )
-    mod_dev = opt_device(mod_mixed)
-
-    # host optimizations
-    opt_host = tvm.transform.Sequential(
-        [
-            tvm.tir.transform.Filter(
-                lambda f: "calling_conv" not in f.attrs
-                or f.attrs["calling_conv"].value != 
CallingConv.DEVICE_KERNEL_LAUNCH
-            ),
-            tvm.tir.transform.Apply(lambda f: f.with_attr("target", 
target_host)),
-            tvm.tir.transform.LowerTVMBuiltin(),
-            tvm.tir.transform.LowerDeviceStorageAccessInfo(),
-            tvm.tir.transform.LowerCustomDatatypes(),
-            tvm.tir.transform.LowerIntrin(),
-            tvm.tir.transform.CombineContextCall(),
-        ]
-    )
-    mod_host = opt_host(mod_mixed)
-
-    if device_type == ndarray.cpu(0).device_type and target_host == target:
-        assert len(mod_dev.functions) == 0
-    if "gpu" in target.keys and len(mod_dev.functions) == 0:
-        warnings.warn(
-            "Specified target %s, but cannot find device code, did you do " 
"bind?" % target
-        )
-
-    rt_mod_dev = codegen.build_module(mod_dev, target) if 
len(mod_dev.functions) != 0 else None
-    return mod_host, rt_mod_dev
-
-
 def build(
     inputs: Union[schedule.Schedule, PrimFunc, IRModule, Mapping[str, 
IRModule]],
     args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None,
@@ -237,7 +141,8 @@ def build(
 
     Parameters
     ----------
-    inputs : Union[tvm.te.schedule.Schedule, tvm.tir.PrimFunc, IRModule, 
Mapping[str, IRModule]]
+    inputs : Union[tvm.te.schedule.Schedule,
+        tvm.tir.PrimFunc, IRModule, Mapping[str, IRModule]]
         The input to be built
 
     args : Optional[List[Union[tvm.tir.Buffer, tensor.Tensor, Var]]]
@@ -253,7 +158,7 @@ def build(
         setup the dimensions and parameters correctly.
         target_host is used to specify the host side codegen target.
         By default, llvm is used if it is enabled,
-        otherwise a stackvm intepreter is used.
+        otherwise a stackvm interpreter is used.
 
     name : Optional[str]
         The name of result function.
@@ -350,21 +255,11 @@ def build(
         target_input_mod, target_host
     )
 
-    mod_host_all = tvm.IRModule({})
-
-    device_modules = []
-    for tar, input_mod in target_input_mod.items():
-        mod_host, mdev = _build_for_device(input_mod, tar, target_host)
-        mod_host_all.update(mod_host)
-        device_modules.append(mdev)
-
-    # Generate a unified host module.
-    rt_mod_host = codegen.build_module(mod_host_all, target_host)
+    rt_mod_host = _driver_ffi.finalize_module(target_input_mod, target_host)
 
-    # Import all modules.
-    for mdev in device_modules:
-        if mdev:
-            rt_mod_host.import_module(mdev)
+    target_input_mod, target_host = Target.check_and_update_host_consist(
+        target_input_mod, target_host
+    )
 
     if not isinstance(target_host, Target):
         target_host = Target(target_host)
diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py
index c67ac1d..f1686d2 100644
--- a/python/tvm/relay/build_module.py
+++ b/python/tvm/relay/build_module.py
@@ -123,7 +123,7 @@ class BuildModule(object):
             to setup the dimensions and parameters correctly.
             target_host is used to specify the host side codegen target.
             By default, llvm is used if it is enabled,
-            otherwise a stackvm intepreter is used.
+            otherwise a stackvm interpreter is used.
 
         params : dict of str to NDArray
             Input parameters to the graph that do not change
@@ -303,7 +303,7 @@ def build(ir_mod, target=None, target_host=None, 
params=None, mod_name="default"
         setup the dimensions and parameters correctly.
         target_host is used to specify the host side codegen target.
         By default, llvm is used if it is enabled,
-        otherwise a stackvm intepreter is used.
+        otherwise a stackvm interpreter is used.
 
     params : dict of str to NDArray
         Input parameters to the graph that do not change
@@ -452,7 +452,7 @@ def bind_params_by_name(func, params):
 class GraphExecutor(_interpreter.Executor):
     """Wrapper around Executor interface.
 
-    This executor is used for debug and testing purpoes.
+    This executor is used for debug and testing purposes.
 
     Parameters
     ----------
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index 2c6fbc2..e659421 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -42,17 +42,26 @@ 
TVM_REGISTER_PASS_CONFIG_OPTION("tir.detect_global_barrier", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
 
 using runtime::PackedFunc;
 using runtime::TVMArgs;
 using runtime::TVMRetValue;
+using tvm::Array;
+using tvm::transform::Pass;
 
 bool LLVMEnabled() {
   const runtime::PackedFunc* pf = runtime::Registry::Get("target.build.llvm");
   return pf != nullptr;
 }
 
+bool ShouldAnnotateEntryFunc(const Target target, const IRModule mod) {
+  const bool aot_executor = (target->GetAttr<String>("executor").value_or("") 
== "aot");
+  const bool single_entry_func = (mod->functions.size() == 1);
+  return single_entry_func && !aot_executor;
+}
+
 /*! \return The default host target for a given device target */
 Target DefaultTargetHost(Target target) {
   if (target.defined() && target->kind->device_type == kDLCPU) {
@@ -155,6 +164,13 @@ transform::Pass BindTarget(Target target) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "BindTarget", {});
 }
 
+static transform::Pass AnnotateEntryFunc(bool b) {
+  auto fpass = [b](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
+    return WithAttr(std::move(f), tir::attr::kIsEntryFunc, Bool(true));
+  };
+  return tir::transform::CreatePrimFuncPass(fpass, 0, "AnnotateEntryFunc", {});
+}
+
 template <typename FCond>
 transform::Pass Filter(FCond fcond) {
   auto fpass = [fcond](tir::PrimFunc f, IRModule m, transform::PassContext 
ctx) {
@@ -184,7 +200,7 @@ Array<tvm::transform::Pass> CreatePassList(bool 
disable_loop_partition) {
   Array<transform::Pass> user_lower_phase2 = Array<transform::Pass>();
   Array<transform::Pass> user_lower_phase3 = Array<transform::Pass>();
 
-  // phase pasees is of the form
+  // phase passes is of the form
   // [[phase_number, pass], [phase_number, pass]... ]
   for (Array<ObjectRef> phase_pass : add_lower_pass) {
     const IntImmNode* phase_num = phase_pass[0].as<IntImmNode>();
@@ -266,6 +282,11 @@ IRModule LowerWithPassList(IRModule mod, 
Array<tvm::transform::Pass> pass_list)
   return mod;
 }
 
+IRModule ApplyPasses(IRModule mod, transform::Sequential seq) {
+  mod = seq(std::move(mod));
+  return mod;
+}
+
 IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, 
const std::string& name,
                           const std::unordered_map<te::Tensor, tir::Buffer>& 
binds) {
   // Convert te schedule to IRModule
@@ -373,97 +394,96 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule")
       return LowerSchedule(std::move(sch), args, name, c_binds, simple_mode);
     });
 
-std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule mod_mixed, const 
Target& target_arg,
-                                                const Target& target_host_arg,
-                                                const transform::PassContext& 
pass_ctx) {
+/**
+ * This function takes the input module that contains both the device and host 
opts.
+ * Then, it applies transformation on the original module before splitting 
into separate modules for
+ * device and host. Then it also applies transformations on the new splitted 
modules.
+ */
+std::pair<IRModule, IRModule> SplitMixedModule(IRModule mod_mixed, const 
Target& target_arg,
+                                               const Target& target_host_arg) {
   Target target = target_arg, target_host = target_host_arg;
   CheckAndUpdateHostConsistency(&target, &target_host);
-  Array<tvm::transform::Pass> mixed_pass_list = {BindTarget(target),
-                                                 
tir::transform::VerifyMemory()};
 
-  
mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations());
-  if (pass_ctx->GetConfig<Bool>("tir.detect_global_barrier", 
Bool(false)).value()) {
-    mixed_pass_list.push_back(tir::transform::ThreadSync("global"));
-  }
-  mixed_pass_list.push_back(tir::transform::ThreadSync("shared"));
-  mixed_pass_list.push_back(tir::transform::ThreadSync("warp"));
-  mixed_pass_list.push_back(tir::transform::InferFragment());
-  mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());
+  ICHECK(mod_mixed.defined()) << "This module must be defined";
 
-  if (target->GetAttr<Bool>("unpacked-api").value_or(Bool(false))) {
-    mixed_pass_list.push_back(tir::transform::MakeUnpackedAPI());
-  } else {
-    mixed_pass_list.push_back(tir::transform::MakePackedAPI(-1));
-  }
+  mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed, 
target));
 
-  mixed_pass_list.push_back(tir::transform::SplitHostDevice());
+  IRModule host_mod = ApplyPasses(mod_mixed, HostModulePassManager(mod_mixed, 
target_host));
 
-  auto opt_mixed = transform::Sequential(mixed_pass_list);
-  mod_mixed = opt_mixed(std::move(mod_mixed));
-
-  // We make an assumption here that the overriden host target
-  // can be used alongside the default host codegen based on device type
-  // this is so the correct code generator is used later instead of overriding 
the target.
-  // We need better support for inserting multiple kDLCPU targets as our 
current options
-  // are kDeviceKernelLaunch or not
-  Target overriden_host_target = target_host;
-  if (target->kind->device_type == target_host->kind->device_type) {
-    overriden_host_target = target;
-  }
-  auto host_pass_list = {
-      Filter([](const tir::PrimFunc& f) {
-        return f->GetAttr<Integer>(tvm::attr::kCallingConv, 
Integer(CallingConv::kDefault)) !=
-               CallingConv::kDeviceKernelLaunch;
-      }),
-      BindTarget(overriden_host_target),
-      tir::transform::LowerTVMBuiltin(),
-      tir::transform::LowerCustomDatatypes(),
-      tir::transform::LowerIntrin(),
-      tir::transform::LowerDeviceStorageAccessInfo(),
-      tir::transform::CombineContextCall(),
-  };
-  auto opt_host = transform::Sequential(host_pass_list);
-  ICHECK(mod_mixed.defined()) << "This module must be defined";
-  auto mhost = opt_host(mod_mixed);
-
-  // device pipeline
-  auto device_pass_list = {
-      Filter([](const tir::PrimFunc& f) {
-        return f->GetAttr<Integer>(tvm::attr::kCallingConv, 
Integer(CallingConv::kDefault)) ==
-               CallingConv::kDeviceKernelLaunch;
-      }),
-      BindTarget(target),
-      tir::transform::LowerWarpMemory(),
-      tir::transform::Simplify(),
-      tir::transform::LowerCustomDatatypes(),
-      tir::transform::LowerIntrin(),
-      tir::transform::LowerDeviceStorageAccessInfo(),
-  };
-  auto opt_device = transform::Sequential(device_pass_list);
-  auto mdevice = opt_device(mod_mixed);
+  IRModule device_mod = ApplyPasses(mod_mixed, 
DeviceModulePassManager(mod_mixed, target));
 
-  // some final misc checks.
   auto keys = target->GetKeys();
+
+  CheckAndUpdateHostConsistency(&target, &target_host);
+
   bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != 
keys.end();
-  if (target_is_gpu && mdevice->functions.size() == 0) {
-    LOG(WARNING) << "Specified target " << target->str()
-                 << " but cannot find device code. Did you forget to bind?";
+  if (target_is_gpu && device_mod->functions.size() == 0) {
+    DLOG(WARNING) << "Specified target " << target->str()
+                  << " but cannot find device code. Did you forget to bind?";
+  }
+
+  return {host_mod, device_mod};
+}
+
+runtime::Module FinalizeModule(const Map<Target, IRModule>& inputs_arg, const 
Target& host_target) {
+  std::vector<runtime::Module> device_modules;
+  Map<Target, IRModule> inputs = inputs_arg;
+  Target target_host = host_target;
+
+  CheckAndUpdateHostConsistency(&inputs, &target_host);
+
+  if (!target_host.defined()) {
+    for (const auto& it : inputs) {
+      if (it.first->kind->device_type == kDLCPU || it.first->kind->device_type 
== kDLMicroDev) {
+        target_host = it.first;
+        break;
+      }
+    }
+  }
+
+  if (!target_host.defined()) {
+    target_host = DefaultTargetHost(target_host);
   }
 
-  if (target->kind->device_type == kDLCPU && target_host == target) {
-    // TODO(@jroesch): This check is no longer true we need to figure out if 
we care about this.
-    // We need to relax this check for just TIR functions.
-    // ICHECK(mdevice->functions.empty()) << "No device code should be 
generated when target "
-    //                                   << "and host_target are both llvm 
target."
-    //                                   << "\n";
+  // Update target host for all targets
+  CheckAndUpdateHostConsistency(&inputs, &target_host);
+
+  IRModule mhost_all = IRModule(Map<GlobalVar, BaseFunc>());
+
+  ICHECK(mhost_all.defined()) << "The host module must be defined";
+
+  for (const auto& it : inputs) {
+    if (it.second.defined()) {
+      auto pair = SplitMixedModule(it.second, it.first, target_host);
+      auto& host_mod = pair.first;
+      auto& device_mod = pair.second;
+
+      ICHECK(host_mod.defined()) << "The split host module must be defined";
+
+      ICHECK(mhost_all.defined()) << "The host module must be defined";
+
+      mhost_all->Update(host_mod);
+
+      if (device_mod->functions.size() != 0) {
+        device_modules.push_back(codegen::Build(device_mod, it.first));
+      }
+    }
   }
 
-  return {mhost, mdevice};
+  runtime::Module complete_mod = codegen::Build(mhost_all, target_host);
+  for (const auto& it : device_modules) {
+    if (it.operator->()) {
+      complete_mod.Import(it);
+    }
+  }
+  return complete_mod;
 }
 
-// Can we make this take one annotated IRModule?
-//
-// Build for heterogeneous execution.
+TVM_REGISTER_GLOBAL("driver.finalize_module")
+    .set_body_typed([](const Map<Target, IRModule>& inputs_arg, Target 
host_target) {
+      return FinalizeModule(inputs_arg, host_target);
+    });
+
 runtime::Module build(const Map<Target, IRModule>& inputs_arg, const Target& 
target_host_arg) {
   auto pass_ctx = transform::PassContext::Current();
 
@@ -498,11 +518,11 @@ runtime::Module build(const Map<Target, IRModule>& 
inputs_arg, const Target& tar
     if (it.second.defined()) {
       const Target& target = it.first;
       const IRModule& ir_module = it.second;
-      auto pair = SplitDevHostFuncs(ir_module, target, target_host, pass_ctx);
-      auto& mhost = pair.first;
-      auto& mdevice = pair.second;
+      auto pair = SplitMixedModule(ir_module, target, target_host);
+      auto& host_mod = pair.first;
+      auto& device_mod = pair.second;
 
-      ICHECK(mhost.defined()) << "The split host module must be defined";
+      ICHECK(host_mod.defined()) << "The split host module must be defined";
 
       ICHECK(mhost_all.defined()) << "The host module must be defined";
 
@@ -513,19 +533,18 @@ runtime::Module build(const Map<Target, IRModule>& 
inputs_arg, const Target& tar
       bool overrides_host_target = target->kind->device_type == 
target_host->kind->device_type;
       bool non_host_target_kind = target->kind != target_host->kind;
       if (overrides_host_target && non_host_target_kind) {
-        device_modules.push_back(codegen::Build(mhost, it.first));
+        device_modules.push_back(codegen::Build(host_mod, it.first));
       } else {
-        mhost_all->Update(mhost);
+        mhost_all->Update(host_mod);
       }
 
-      if (mdevice->functions.size() != 0) {
-        device_modules.push_back(codegen::Build(mdevice, it.first));
+      if (device_mod->functions.size() != 0) {
+        device_modules.push_back(codegen::Build(device_mod, it.first));
       }
     }
   }
 
   runtime::Module mhost = codegen::Build(mhost_all, target_host);
-  // Import all modules
   for (const auto& it : device_modules) {
     if (it.operator->()) {
       mhost.Import(it);
@@ -556,8 +575,97 @@ runtime::Module build(const IRModule& funcs, const Target& 
target_arg,
                       const Target& target_host_arg) {
   auto target = target_arg, target_host = target_host_arg;
   CheckAndUpdateHostConsistency(&target, &target_host);
+  // More maps of target and target host
   Map<Target, IRModule> inputs = {{target, funcs}};
   return build(inputs, target_host);
 }
 
+transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target 
target) {
+  transform::PassContext pass_ctx = transform::PassContext::Current();
+
+  Array<Pass> mixed_pass_list;
+
+  mixed_pass_list.push_back(BindTarget(target));
+
+  mixed_pass_list.push_back(tir::transform::VerifyMemory());
+  
mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations());
+
+  if (ShouldAnnotateEntryFunc(target, mixed_mod)) {
+    mixed_pass_list.push_back(AnnotateEntryFunc(true));
+  }
+
+  bool detect_global_barrier =
+      pass_ctx->GetConfig<Bool>("tir.detect_global_barrier", 
Bool(false)).value();
+  if (detect_global_barrier) {
+    mixed_pass_list.push_back(tir::transform::ThreadSync("global"));
+  }
+
+  mixed_pass_list.push_back(tir::transform::ThreadSync("shared"));
+  mixed_pass_list.push_back(tir::transform::ThreadSync("warp"));
+  mixed_pass_list.push_back(tir::transform::InferFragment());
+  mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());
+
+  if (target->GetAttr<Bool>("unpacked-api").value_or(Bool(false))) {
+    mixed_pass_list.push_back(tir::transform::MakeUnpackedAPI());
+  } else {
+    mixed_pass_list.push_back(tir::transform::MakePackedAPI(-1));
+  }
+  mixed_pass_list.push_back(tir::transform::SplitHostDevice());
+
+  return transform::Sequential(mixed_pass_list);
+}
+
+TVM_REGISTER_GLOBAL("driver.mixed_mod_passes")
+    .set_body_typed([](IRModule mixed_mod, Target target) {
+      return MixedModulePassManager(mixed_mod, target);
+    });
+
+transform::Sequential HostModulePassManager(IRModule mixed_mod, Target 
target_host) {
+  Array<tvm::transform::Pass> host_pass_list;
+  host_pass_list.push_back(Filter([](const tir::PrimFunc& f) {
+    return f->GetAttr<Integer>(tvm::attr::kCallingConv, 
Integer(CallingConv::kDefault)) !=
+           CallingConv::kDeviceKernelLaunch;
+  }));
+
+  ICHECK(mixed_mod.defined()) << "This module must be defined";
+
+  host_pass_list.push_back(BindTarget(target_host));
+
+  host_pass_list.push_back(tir::transform::LowerTVMBuiltin());
+  host_pass_list.push_back(tir::transform::LowerCustomDatatypes());
+  host_pass_list.push_back(tir::transform::LowerIntrin());
+  host_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo());
+  host_pass_list.push_back(tir::transform::CombineContextCall());
+
+  return transform::Sequential(host_pass_list);
+}
+
+TVM_REGISTER_GLOBAL("driver.host_mod_passes")
+    .set_body_typed([](IRModule mixed_mod, Target target_host) {
+      return HostModulePassManager(mixed_mod, target_host);
+    });
+
+transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target 
target) {
+  Array<Pass> device_pass_list;
+  device_pass_list.push_back(Filter([](const tir::PrimFunc& f) {
+    return f->GetAttr<Integer>(tvm::attr::kCallingConv, 
Integer(CallingConv::kDefault)) ==
+           CallingConv::kDeviceKernelLaunch;
+  }));
+
+  device_pass_list.push_back(BindTarget(target));
+
+  device_pass_list.push_back(tir::transform::LowerWarpMemory());
+  device_pass_list.push_back(tir::transform::Simplify());
+  device_pass_list.push_back(tir::transform::LowerCustomDatatypes());
+  device_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo());
+  device_pass_list.push_back(tir::transform::LowerIntrin());
+
+  return transform::Sequential(device_pass_list);
+}
+
+TVM_REGISTER_GLOBAL("driver.device_mod_passes")
+    .set_body_typed([](IRModule mixed_mod, Target target_host) {
+      return DeviceModulePassManager(mixed_mod, target_host);
+    });
+
 }  // namespace tvm
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index 36cd0c7..70ad2cc 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -824,7 +824,7 @@ class VMFunctionCompiler : 
DeviceAwareExprFunctor<void(const Expr& n)> {
 
   /*!
    * \brief Compile a pattern match expression
-   * It first converts the pattern match expression into a desicision tree, 
the condition
+   * It first converts the pattern match expression into a decision tree, the 
condition
    * could be object comparison or variable binding. If any of the condition 
fails in a clause,
    * the decision tree switches to check the conditions of next clause and so 
on. If no clause
    * matches the value, a fatal node is inserted.

Reply via email to