This is an automated email from the ASF dual-hosted git repository.

yongwww pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 1c35c39264 [Unity] Add Relax multi-device e2e cases  (#15823)
1c35c39264 is described below

commit 1c35c392648e4336fc5e00ab91abb37af997cd59
Author: Yong Wu <[email protected]>
AuthorDate: Wed Dec 20 13:52:56 2023 -0800

    [Unity] Add Relax multi-device e2e cases  (#15823)
    
    * [Unity] filter out non-GPU primfuncs in default_gpu_schedule
    
    * Add relax heterogeneous e2e case
    
    * Remove get_prim_func_device
    
    * Update test cases
    
    * Fix flake8
    
    * fix lint
    
    * Add test case for change of default_gpu_schedule
    
    * fix comment
---
 python/tvm/driver/build_module.py                  |  27 ++-
 python/tvm/relax/utils.py                          |  26 ++-
 python/tvm/relax/vm_build.py                       |  32 ++--
 python/tvm/runtime/relax_vm.py                     |   7 +-
 python/tvm/testing/utils.py                        |  20 +++
 src/ir/module.cc                                   |   2 +-
 src/relax/transform/call_tir_rewrite.cc            |  39 ++++-
 src/relax/transform/legalize_ops.cc                |  42 +++++
 src/relax/transform/utils.h                        |  11 ++
 src/runtime/relax_vm/vm.cc                         |   3 -
 src/script/printer/relax/utils.h                   |   1 -
 src/tir/transforms/default_gpu_schedule.cc         |  49 ++++--
 tests/python/relax/test_frontend_stablehlo.py      |   4 +-
 tests/python/relax/test_vm_multi_device.py         | 186 +++++++++++++++++++++
 .../test_transform_default_gpu_schedule.py         |  73 ++++++++
 15 files changed, 471 insertions(+), 51 deletions(-)

diff --git a/python/tvm/driver/build_module.py 
b/python/tvm/driver/build_module.py
index 9389e7fbee..52303123c1 100644
--- a/python/tvm/driver/build_module.py
+++ b/python/tvm/driver/build_module.py
@@ -243,20 +243,33 @@ def build(
 
     if not isinstance(inputs, (dict, container.Map)):
         target = Target.current() if target is None else target
-        target = target if target else "llvm"
-        target_input_mod = {target: input_mod}
+        if target is None and isinstance(input_mod, tvm.IRModule):
+            target_mod = {}
+            for gvar, func in input_mod.functions.items():
+                tgt = func.attrs["target"] if func.attrs and "target" in 
func.attrs else "llvm"
+                if tgt not in target_mod:
+                    target_mod[tgt] = {}
+                target_mod[tgt][gvar] = func
+
+            target_input_mod = {}
+            for tgt in target_mod.keys():
+                tir_mod = tvm.IRModule(target_mod[tgt])
+                tir_mod.with_attrs(input_mod.attrs)
+                target_input_mod[tgt] = tir_mod
+        else:
+            target_input_mod = {target: input_mod}
     else:
-        target_input_mod = inputs
+        target_input_mod = {tgt: lower(mod) for tgt, mod in inputs.items()}
 
     # Because modules can be created from a variety of sources, we annotate 
them
     # with the relevant attributes here to ensure they propagate
     annotated_mods = {}
-    for tar, mod in target_input_mod.items():
-        if not isinstance(tar, (str, Target)):
+    for tgt, mod in target_input_mod.items():
+        if not isinstance(tgt, (str, Target)):
             raise ValueError("The key of inputs must be str or " "Target when 
inputs is dict.")
         if not isinstance(mod, tvm.IRModule):
-            raise ValueError("inputs must be Schedule, IRModule," "or dict of 
str to IRModule.")
-        annotated_mods[tar] = mod.with_attr("runtime", runtime)
+            raise ValueError("inputs must be Schedule, IRModule, " "or dict of 
str to IRModule.")
+        annotated_mods[tgt] = mod.with_attr("runtime", runtime)
 
     # TODO(mbs): Both CompilationConfig and TIRToRuntime implement the same 
host target
     #  defaulting logic, but there's currently no way to get back the decided 
host.
diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py
index a1fa9cafe8..b720a727f6 100644
--- a/python/tvm/relax/utils.py
+++ b/python/tvm/relax/utils.py
@@ -28,7 +28,7 @@ from . import _ffi_api
 from .expr import Tuple as rx_Tuple
 from .expr import Expr, ShapeExpr, Function, PrimValue, StringImm, te_tensor
 from ..te import Tensor as te_Tensor, create_prim_func
-from ..ir import Array, Attrs, Type, Map
+from ..ir import Array, Attrs, Type, Map, VDevice
 from .struct_info import PrimStructInfo, ShapeStructInfo, TensorStructInfo
 
 
@@ -418,6 +418,24 @@ def gen_call_tir_inputs(
         diff = used_vars - bound_vars
         return list(diff)
 
+    def _get_vdevice(arg: Any) -> Optional[VDevice]:
+        """get the virtual device from arguments."""
+        vdevice = None
+        if isinstance(arg, Expr):  # type: ignore
+            if isinstance(arg.struct_info, TensorStructInfo):
+                vdevice = arg.struct_info.vdevice
+        elif isinstance(arg, (list, Array, tuple)):
+            for x in arg:
+                vdevice = _get_vdevice(x)
+                if vdevice is not None:
+                    return vdevice
+        elif isinstance(arg, (dict, Map)):
+            for k in arg:
+                vdevice = _get_vdevice(arg[k])
+                if vdevice is not None:
+                    return vdevice
+        return vdevice
+
     def _shape_with_old_tir_var(
         shape_values: List[tir.PrimExpr], tir_var_inverse_map: Dict[tir.Var, 
tir.PrimExpr]
     ):
@@ -456,7 +474,11 @@ def gen_call_tir_inputs(
     tir_var_inverse_map = {v: k for k, v in tir_var_map.items()}
 
     output_sinfo = [
-        TensorStructInfo(_shape_with_old_tir_var(out.shape, 
tir_var_inverse_map), out.dtype)
+        TensorStructInfo(
+            _shape_with_old_tir_var(out.shape, tir_var_inverse_map),
+            out.dtype,
+            _get_vdevice(args),
+        )
         for out in outs
     ]
 
diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py
index 62760b3417..9120f74e13 100644
--- a/python/tvm/relax/vm_build.py
+++ b/python/tvm/relax/vm_build.py
@@ -79,7 +79,7 @@ class Executable:
             vm = tvm.relax.VirtualMachine(rt_mod, tvm.cuda())
         """
 
-        # TODO(tvm-team): Update runtime.Module interfac
+        # TODO(tvm-team): Update runtime.Module interface
         # to query these properties as bitmask.
         def _not_runnable(x):
             return x.type_key in ("c", "static_library")
@@ -179,13 +179,17 @@ def _vmcodegen(
     raise ValueError(f"Unknown exec_mode {exec_mode}")
 
 
-def _autodetect_system_lib_req(target: tvm.target.Target, system_lib):
+def _autodetect_system_lib_req(
+    target: Optional[tvm.target.Target] = None, system_lib: Optional[bool] = 
None
+):
     """Automatically detect system lib requirement"""
-    host = target if target.host is None else target.host
-    if system_lib is None:
-        system_lib = False
-        if "wasm" in host.attrs.get("mtriple", ""):
-            system_lib = True
+    if target is not None:
+        host = target if target.host is None else target.host
+        if system_lib is None:
+            system_lib = False
+            if "wasm" in host.attrs.get("mtriple", ""):
+                system_lib = True
+
     if system_lib:
         # use packed-func to avoid relay dep.
         return tvm.get_global_func("relay.backend.CreateRuntime")("cpp", 
{"system-lib": system_lib})
@@ -194,7 +198,7 @@ def _autodetect_system_lib_req(target: tvm.target.Target, 
system_lib):
 
 def _vmlink(
     builder: "relax.ExecBuilder",
-    target: Union[str, tvm.target.Target],
+    target: Optional[Union[str, tvm.target.Target]],
     tir_mod: Optional[tvm.IRModule] = None,
     ext_libs: List[tvm.runtime.Module] = None,
     params: Optional[Dict[str, list]] = None,
@@ -213,8 +217,10 @@ def _vmlink(
     builder: relax.ExecBuilder
         Builder used to collect executables.
 
-    target : Union[str, tvm.target.Target]
+    target : Optional[Union[str, tvm.target.Target]]
         A build target which can have optional host side compilation target.
+        If the target is not specified, the target in the vdevice list will be 
used.
+        For multi-target compilation, the vdevice should be annotated.
 
     tir_mod: IRModule
         The input TIR IRModule to be linked together.
@@ -239,14 +245,16 @@ def _vmlink(
     lib = None
     if tir_mod is not None:
         lib = tvm.build(
-            tir_mod, target=target, runtime=_autodetect_system_lib_req(target, 
system_lib)
+            tir_mod,
+            target=target,
+            runtime=_autodetect_system_lib_req(target, system_lib),
         )
     return Executable(_ffi_api.VMLink(builder, target, lib, ext_libs, params)) 
 # type: ignore
 
 
 def build(
     mod: tvm.IRModule,
-    target: Union[str, tvm.target.Target],
+    target: Optional[Union[str, tvm.target.Target]] = None,
     params: Optional[Dict[str, list]] = None,
     pipeline: Union[None, str, tvm.transform.Pass] = "default_build",
     exec_mode: str = "bytecode",
@@ -261,7 +269,7 @@ def build(
     mod: IRModule
         The input IRModule to be built.
 
-    target : Union[str, tvm.target.Target]
+    target : Optional[Union[str, tvm.target.Target]]
         A build target which can have optional host side compilation target.
 
         When TVM compiles device specific program such as CUDA,
diff --git a/python/tvm/runtime/relax_vm.py b/python/tvm/runtime/relax_vm.py
index a925e048b2..5b8bbe6d33 100644
--- a/python/tvm/runtime/relax_vm.py
+++ b/python/tvm/runtime/relax_vm.py
@@ -54,7 +54,7 @@ class VirtualMachine(object):
 
         Parameters
         ----------
-        mod: Union[tvm.runtime.Module, tvm.relax.Executable]
+        rt_mod: Union[tvm.runtime.Module, tvm.relax.Executable]
             Runtime module exported by the result of build.
 
         device : Union[Device, List[Device]]
@@ -107,11 +107,6 @@ class VirtualMachine(object):
                 )
             devs = [dev]
 
-        if any(dev.device_type % RPC_SESS_MASK == tvm.cpu().device_type for 
dev in devs[:-1]):
-            raise RuntimeError(
-                "CPU host is required to be the last element of the device 
list if provided."
-            )
-
         # CPU is required for executing shape functions
         if devs[-1].device_type % RPC_SESS_MASK != tvm.cpu().device_type:
             devs.append(tvm.cpu())
diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index 29c9463ba5..ccad989c33 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -832,6 +832,16 @@ def _any_gpu_exists():
     )
 
 
+def _multi_gpu_exists():
+    return (
+        (tvm.cuda(0).exist and tvm.cuda(1).exist)
+        or (tvm.rocm(0).exist and tvm.rocm(1).exist)
+        or (tvm.opencl(0).exist and tvm.opencl(1).exist)
+        or (tvm.metal(0).exist and tvm.metal(1).exist)
+        or (tvm.vulkan(0).exist and tvm.vulkan(1).exist)
+    )
+
+
 # Mark a test as requiring llvm to run
 requires_llvm = Feature(
     "llvm", "LLVM", cmake_flag="USE_LLVM", target_kind_enabled="llvm", 
target_kind_hardware="llvm"
@@ -847,6 +857,16 @@ requires_gpu = Feature("gpu", 
run_time_check=_any_gpu_exists)
 # :py:func:`tvm.testing.requires_gpu`.
 uses_gpu = requires_gpu(support_required="optional")
 
+# Mark a test as requiring multiple GPUs to run.
+requires_multi_gpu = Feature("multi_gpu", run_time_check=_multi_gpu_exists)
+
+# Mark to differentiate tests that use multiple GPUs in some capacity.
+#
+# These tests will be run on test nodes with multiple GPUs.
+# To mark a test that must have multiple GPUs present to run, use
+# :py:func:`tvm.testing.requires_multi_gpu`.
+uses_multi_gpu = requires_multi_gpu(support_required="optional")
+
 # Mark a test as requiring the x86 Architecture to run.
 requires_x86 = Feature(
     "x86", "x86 Architecture", run_time_check=lambda: platform.machine() == 
"x86_64"
diff --git a/src/ir/module.cc b/src/ir/module.cc
index c016612c15..156158a85f 100644
--- a/src/ir/module.cc
+++ b/src/ir/module.cc
@@ -324,7 +324,7 @@ void IRModuleNode::Update(const IRModule& mod) {
 
 IRModule IRModuleNode::ShallowCopy() {
   return IRModule(this->functions, this->type_definitions, this->Imports(), 
this->source_map,
-                  this->attrs);
+                  this->attrs, this->global_infos);
 }
 
 std::pair<IRModule, GlobalVar> IRModule::FromExprInContext(
diff --git a/src/relax/transform/call_tir_rewrite.cc 
b/src/relax/transform/call_tir_rewrite.cc
index e040ccea14..760d04a220 100644
--- a/src/relax/transform/call_tir_rewrite.cc
+++ b/src/relax/transform/call_tir_rewrite.cc
@@ -18,7 +18,8 @@
  */
 /*!
  * \file src/relax/transform/call_tir_rewrite.cc
- * \brief Perform explicit tensor allocation for call_tir.
+ * \brief Perform explicit tensor allocation for call_tir,
+ *        call_tir_inplace, and call_dps_packed.
  */
 #include <tvm/relax/attrs/op.h>
 #include <tvm/relax/expr_functor.h>
@@ -28,6 +29,7 @@
 #include <tvm/tir/op.h>
 
 #include "../../relay/transforms/pattern_utils.h"
+#include "utils.h"
 
 namespace tvm {
 namespace relax {
@@ -43,6 +45,19 @@ namespace relax {
 
 class CallTIRMutator : public ExprMutator {
  public:
+  explicit CallTIRMutator(const IRModule& mod) : ExprMutator(mod), 
mod_(std::move(mod)) {}
+
+  IRModule Run() {
+    for (const auto& [gv, func] : mod_->functions) {
+      if (func->IsInstance<FunctionNode>()) {
+        auto updated_func = Downcast<Function>(this->VisitExpr(func));
+        builder_->UpdateFunction(gv, Downcast<BaseFunc>(updated_func));
+      }
+    }
+    return builder_->GetContextIRModule();
+  }
+
+ private:
   using ExprMutator::VisitExpr_;
   Expr VisitExpr_(const CallNode* call) override {
     // post-order mutation
@@ -65,11 +80,15 @@ class CallTIRMutator : public ExprMutator {
         const TensorStructInfo& tensor_sinfo = _tensor_sinfo.value();
         ICHECK(tensor_sinfo->shape.defined())
             << "the TensorStructInfo shape of call_tir has not populated";
+        int dev_index = 0;
+        if (tensor_sinfo->vdevice.defined()) {
+          dev_index = GetDeviceIndex(mod_, tensor_sinfo->vdevice.value());
+        }
         if (!is_inplace) {
           outs.push_back(
-              builder_->Emit(Call(alloc_tensor_op,  //
+              builder_->Emit(Call(alloc_tensor_op,
                                   
{Downcast<ShapeExpr>(tensor_sinfo->shape.value()),
-                                   DataTypeImm(tensor_sinfo->dtype), 
PrimValue::Int64(0)},  //
+                                   DataTypeImm(tensor_sinfo->dtype), 
PrimValue::Int64(dev_index)},
                                   Attrs()),
                              "alloc"));
         } else {
@@ -150,16 +169,20 @@ class CallTIRMutator : public ExprMutator {
 
     return GetRef<Expr>(call);
   }
-};
 
-Expr CallTIRRewrite(const Expr& e) { return CallTIRMutator().VisitExpr(e); }
+  /*! \brief The context IRModule. */
+  IRModule mod_;
+};
 
 namespace transform {
 
 Pass CallTIRRewrite() {
-  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> 
pass_func =
-      [=](Function f, IRModule m, PassContext pc) { return 
Downcast<Function>(CallTIRRewrite(f)); };
-  return CreateFunctionPass(pass_func, 0, "CallTIRRewrite", {});
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+      [=](IRModule mod, PassContext pc) { return CallTIRMutator(mod).Run(); };
+  return CreateModulePass(/*pass_function=*/pass_func,
+                          /*opt_level=*/0,
+                          /*pass_name=*/"CallTIRRewrite",
+                          /*required=*/{});
 }
 
 
TVM_REGISTER_GLOBAL("relax.transform.CallTIRRewrite").set_body_typed(CallTIRRewrite);
diff --git a/src/relax/transform/legalize_ops.cc 
b/src/relax/transform/legalize_ops.cc
index a557a41f8e..c8fba59dab 100644
--- a/src/relax/transform/legalize_ops.cc
+++ b/src/relax/transform/legalize_ops.cc
@@ -26,6 +26,7 @@
 #include <tvm/relax/analysis.h>
 #include <tvm/relax/expr_functor.h>
 #include <tvm/relax/op_attr_types.h>
+#include <tvm/relax/struct_info.h>
 #include <tvm/relax/transform.h>
 
 namespace tvm {
@@ -72,6 +73,14 @@ class LegalizeMutator : public ExprMutator {
         builder_->UpdateFunction(gv, Downcast<BaseFunc>(updated_func));
       }
     }
+    // Fill the "kTarget" attribute of PrimFunc
+    for (const auto& [gv, func] : builder_->GetContextIRModule()->functions) {
+      const tir::PrimFuncNode* prim_func;
+      if (tmap_.count(gv) && (prim_func = func.as<tir::PrimFuncNode>())) {
+        auto f = WithAttr(GetRef<tir::PrimFunc>(prim_func), 
tvm::attr::kTarget, tmap_[gv]);
+        builder_->UpdateFunction(gv, f);
+      }
+    }
     return builder_->GetContextIRModule();
   }
 
@@ -109,6 +118,33 @@ class LegalizeMutator : public ExprMutator {
     return Call(call_pure_packed_op, ret_args, ret->attrs, ret->sinfo_args);
   }
 
+  Target GetTarget(const Array<StructInfo>& sinfos) {
+    for (auto sinfo : sinfos) {
+      if (const auto* tinfo = sinfo.as<TensorStructInfoNode>()) {
+        if (tinfo->vdevice.defined()) {
+          auto vdevice = tinfo->vdevice.value();
+          if (vdevice->target.defined()) {
+            return vdevice->target;
+          }
+        }
+      } else if (const auto* tup_sinfo = sinfo.as<TupleStructInfoNode>()) {
+        return GetTarget(tup_sinfo->fields);
+      }
+    }
+    return Target();
+  }
+
+  void SaveTarget(const Expr& expr) {
+    if (expr->IsInstance<CallNode>()) {
+      auto call = Downcast<Call>(expr);
+      auto target = GetTarget(call->sinfo_args);
+      const GlobalVarNode* gvar_node;
+      if (target.defined() && (gvar_node = call->args[0].as<GlobalVarNode>())) 
{
+        this->tmap_.Set(GetRef<GlobalVar>(gvar_node), target);
+      }
+    }
+  }
+
   Expr VisitExpr_(const CallNode* call) final {
     Call visited_call = Downcast<Call>(this->VisitExprPostOrder_(call));
     static const auto& legalize_map = Op::GetAttrMap<FLegalize>("FLegalize");
@@ -164,6 +200,10 @@ class LegalizeMutator : public ExprMutator {
       builder_->BeginBindingBlock();
     }
     Expr legalized = legalization_func(builder_, visited_call);
+
+    // Save the expected target info. into tmap_
+    SaveTarget(legalized);
+
     legalized = builder_->Normalize(legalized);
 
     BindingBlock prologue = builder_->EndBlock();
@@ -196,6 +236,8 @@ class LegalizeMutator : public ExprMutator {
   IRModule mod_;
   /*! \brief The customized legalization function map. */
   Map<String, PackedFunc> cmap_;
+  /*! \brief The map from GlobalVar of PrimFunc to compilation Target. */
+  Map<GlobalVar, Target> tmap_;
   /*!
    * \brief A boolean value indicating if to print warnings for CallNode whose 
op's
    * legalization function is not registered.
diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h
index 2226e62763..8b3525c628 100644
--- a/src/relax/transform/utils.h
+++ b/src/relax/transform/utils.h
@@ -383,6 +383,17 @@ inline String GetCodegenName(const std::string& 
composite_name) {
   return composite_name.substr(0, delim_pos);
 }
 
+inline int GetDeviceIndex(const IRModule& mod, const VDevice& vdevice) {
+  Array<GlobalInfo> vdevices = mod->global_infos["vdevice"];
+  for (int i = 0; i < static_cast<int>(vdevices.size()); ++i) {
+    if (vdevices[i] == vdevice) {
+      return i;
+    }
+  }
+  LOG(FATAL) << "The vdevice is not in the ir_module.";
+  return -1;
+}
+
 /* \brief Eliminate common subexpressions
  *
  * Utility for simplifying relax expressions by removing common
diff --git a/src/runtime/relax_vm/vm.cc b/src/runtime/relax_vm/vm.cc
index b31268e697..d7f943d5f4 100644
--- a/src/runtime/relax_vm/vm.cc
+++ b/src/runtime/relax_vm/vm.cc
@@ -440,9 +440,6 @@ void 
VirtualMachineImpl::LoadExecutable(ObjectPtr<Executable> exec) {
 
 void VirtualMachineImpl::Init(const std::vector<Device>& devices,
                               const std::vector<AllocatorType>& alloc_types) {
-  // TODO(@yuchen): support multi-device heterogeneous execution
-  ICHECK_LT(devices.size(), 3)
-      << "Currently relax vm only supports at most 2 devices (host + device)";
   ICHECK_EQ(devices.size(), alloc_types.size());
 
   this->devices.reserve(devices.size());
diff --git a/src/script/printer/relax/utils.h b/src/script/printer/relax/utils.h
index e0b5348d73..58b8bf4431 100644
--- a/src/script/printer/relax/utils.h
+++ b/src/script/printer/relax/utils.h
@@ -109,7 +109,6 @@ inline int FindVDeviceIndexByTargetKind(const VDevice& 
vdevice, const IRDocsifie
       kind_index++;
     }
   }
-  LOG(WARNING) << "The VDevice was not found in the global_infos map: " << 
vdevice;
   return -1;
 }
 
diff --git a/src/tir/transforms/default_gpu_schedule.cc 
b/src/tir/transforms/default_gpu_schedule.cc
index 5a22d0b0d9..6cf7f6e067 100644
--- a/src/tir/transforms/default_gpu_schedule.cc
+++ b/src/tir/transforms/default_gpu_schedule.cc
@@ -98,24 +98,53 @@ IRModule MarkScheduled(const IRModule& mod) {
                   mod->type_definitions,  // type_definitions
                   mod->import_set_,       // import_set
                   mod->source_map,        // map
-                  mod->attrs);            // attrs);
+                  mod->attrs,             // attrs
+                  mod->global_infos);     // global_infos
+}
+
+bool IsScheduledOnGPU(const BaseFunc& func) {
+  // the target from context.
+  tvm::Target target = tvm::Target::Current();
+  // the Target in kTarget attribute of PrimFunc
+  Optional<tvm::Target> func_target = 
func->attrs.GetAttr<tvm::Target>(tvm::attr::kTarget);
+  if (func_target.defined()) {
+    target = func_target.value();
+  }
+
+  if (target.defined()) {
+    int dev_type = target->GetTargetDeviceType();
+    if (dev_type != kDLCUDA) {
+      return false;
+    }
+  }
+  return true;
 }
 
 Pass DefaultGPUSchedule() {
   runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =  //
       [=](IRModule m, PassContext pc) {
-        // get the target from context.
-        tvm::Target target = tvm::Target::Current();
-        ICHECK(target.defined()) << "Target is not set in current context";
-        // get the max thread per block from target.
-        Optional<Integer> opt_max_thread_per_block = 
target->GetAttr<Integer>("max_num_threads");
-        ICHECK(opt_max_thread_per_block.defined())
-            << "max_num_threads is not set for target " << target;
-        int64_t max_thread_per_block = 
opt_max_thread_per_block.value().IntValue();
         tir::Schedule sch = tir::Schedule::Traced(m, /*seed=*/-1, 
/*debug_mask=*/0,
                                                   
tir::ScheduleErrorRenderLevel::kDetail);
         for (const auto& [gv, func] : m->functions) {
-          if (func->IsInstance<tir::PrimFuncNode>() && 
!func->HasNonzeroAttr(attr::kIsScheduled)) {
+          if (func->IsInstance<tir::PrimFuncNode>() && 
!func->HasNonzeroAttr(attr::kIsScheduled) &&
+              IsScheduledOnGPU(func)) {
+            // get the target from context.
+            tvm::Target target = tvm::Target::Current();
+            // get the target from kTarget attribute
+            Optional<tvm::Target> func_target =
+                func->attrs.GetAttr<tvm::Target>(tvm::attr::kTarget);
+            if (func_target.defined()) {
+              target = func_target.value();
+            }
+            ICHECK(target.defined()) << "The target is missing either in the 
current context or in "
+                                        "the prim_func's attribute.";
+            // get the max thread per block from target.
+            Optional<Integer> opt_max_thread_per_block =
+                target->GetAttr<Integer>("max_num_threads");
+            ICHECK(opt_max_thread_per_block.defined())
+                << "max_num_threads is not set for target " << target;
+            int64_t max_thread_per_block = 
opt_max_thread_per_block.value().IntValue();
+
             sch->WorkOn(gv->name_hint);
             Array<tir::BlockRV> blocks = 
meta_schedule::BlockCollector::Collect(sch);
             for (const tir::BlockRV& block : blocks) {
diff --git a/tests/python/relax/test_frontend_stablehlo.py 
b/tests/python/relax/test_frontend_stablehlo.py
index d3068f29c7..f2d0461dda 100644
--- a/tests/python/relax/test_frontend_stablehlo.py
+++ b/tests/python/relax/test_frontend_stablehlo.py
@@ -132,7 +132,7 @@ def check_correctness(
 
     # Multiple ouputs
     assert len(tvm_output) == len(jax_output), "numbers of outputs mismatch"
-    for (tvm_out, jax_out) in zip(tvm_output, jax_output):
+    for tvm_out, jax_out in zip(tvm_output, jax_output):
         tvm.testing.assert_allclose(tvm_out.numpy(), jax_out, rtol=1e-5, 
atol=1e-5)
 
 
@@ -314,7 +314,9 @@ def test_dot_general():
     check_correctness(jax.jit(fn), input_shapes)
 
 
[email protected]()
 @tvm.testing.requires_gpu
+# TODO(yongwww): fix flaky error of "invalid device ordinal"
 def test_conv():
     import jax
     from flax import linen as nn
diff --git a/tests/python/relax/test_vm_multi_device.py 
b/tests/python/relax/test_vm_multi_device.py
new file mode 100644
index 0000000000..ec2fbd1cdf
--- /dev/null
+++ b/tests/python/relax/test_vm_multi_device.py
@@ -0,0 +1,186 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test eliminate common subexpr pass"""
+from typing import List
+import tvm
+from tvm import relax
+import tvm.testing
+from tvm.ir.module import IRModule
+from tvm.script.parser import ir as I, relax as R
+from tvm._ffi.runtime_ctypes import Device
+import numpy as np
+
+
+def compile(
+    mod: IRModule,
+    device: List[Device] = [
+        tvm.cpu(),
+    ],
+) -> relax.VirtualMachine:
+    # compile the model
+    mod = relax.transform.RealizeVDevice()(mod)
+    mod = relax.transform.LegalizeOps()(mod)
+    mod = tvm.tir.transform.DefaultGPUSchedule()(mod)
+    # no need to feed target argument for mult-target compilation
+    ex = relax.build(mod)
+
+    return relax.VirtualMachine(ex, device)
+
+
+def test_multi_cpu():
+    @I.ir_module
+    class Example:
+        I.module_attrs({"attr": 10})
+        I.module_global_infos(
+            {
+                "vdevice": [
+                    I.vdevice("llvm", 0),
+                    I.vdevice("llvm", 1),
+                ]
+            }
+        )
+
+        @R.function
+        def foo(
+            x: R.Tensor((2, 3), "float32"),
+            y: R.Tensor((3, 4), "float32"),
+            z: R.Tensor((4, 5), "float32"),
+        ) -> R.Tensor((2, 5), "float32"):
+            with R.dataflow():
+                lv0: R.Tensor((2, 4), "float32", "llvm:0") = R.matmul(x, y)  # 
noqa: F722
+                lv1: R.Tensor((2, 4), "float32", "llvm:1") = R.to_vdevice(  # 
noqa: F722
+                    lv0, "llvm:1"  # noqa: F722
+                )
+                gv = R.matmul(lv1, z)  # noqa: F722
+                R.output(gv)
+            return gv
+
+    devices = [tvm.cpu(0), tvm.cpu(1)]
+    vm = compile(Example, devices)
+
+    np_ipt0 = np.random.rand(2, 3).astype(np.float32)
+    np_ipt1 = np.random.rand(3, 4).astype(np.float32)
+    np_ipt2 = np.random.rand(4, 5).astype(np.float32)
+    np_res = np.matmul(np.matmul(np_ipt0, np_ipt1), np_ipt2)
+
+    ipt0 = tvm.nd.array(np_ipt0, devices[0])
+    ipt1 = tvm.nd.array(np_ipt1, devices[0])
+    ipt2 = tvm.nd.array(np_ipt2, devices[1])
+    res = vm["foo"](ipt0, ipt1, ipt2)
+    tvm.testing.assert_allclose(res.numpy(), np_res)
+
+
[email protected]_multi_gpu
+def test_multi_gpu():
+    @I.ir_module
+    class Example:
+        I.module_attrs({"attr": 10})
+        I.module_global_infos(
+            {
+                "vdevice": [
+                    I.vdevice("cuda", 1),
+                    I.vdevice("cuda", 0),
+                    I.vdevice("cuda", 2),
+                ]
+            }
+        )
+
+        @R.function
+        def foo(
+            a: R.Tensor((2, 3), "float32"),
+            b: R.Tensor((3, 4), "float32"),
+            c: R.Tensor((4, 5), "float32"),
+            d: R.Tensor((5, 6), "float32"),
+        ) -> R.Tensor((2, 6), "float32"):
+            with R.dataflow():
+                lv0: R.Tensor((2, 4), "float32", "cuda:0") = R.matmul(a, b)  # 
noqa: F722
+                lv1: R.Tensor((2, 4), "float32", "cuda:1") = R.to_vdevice(  # 
noqa: F722
+                    lv0, "cuda:1"  # noqa: F722
+                )
+                lv2: R.Tensor((2, 5), "float32", "cuda:1") = R.matmul(lv1, c)  
# noqa: F722
+                lv3: R.Tensor((2, 5), "float32", "cuda:2") = R.to_vdevice(  # 
noqa: F722
+                    lv2, "cuda:2"  # noqa: F722
+                )
+                gv: R.Tensor((2, 6), "float32", "cuda:2") = R.matmul(lv3, d)  
# noqa: F722
+                R.output(gv)
+            return gv
+
+    # The number and ordering of devices should be identical with the vdevice 
list
+    # defined in global_infos of ir_module
+    devices = [tvm.cuda(1), tvm.cuda(0), tvm.cuda(2)]
+    vm = compile(Example, devices)
+
+    np_ipt0 = np.random.rand(2, 3).astype(np.float32)
+    np_ipt1 = np.random.rand(3, 4).astype(np.float32)
+    np_ipt2 = np.random.rand(4, 5).astype(np.float32)
+    np_ipt3 = np.random.rand(5, 6).astype(np.float32)
+    np_res = np.matmul(np.matmul(np.matmul(np_ipt0, np_ipt1), np_ipt2), 
np_ipt3)
+
+    ipt0 = tvm.nd.array(np_ipt0, devices[0])
+    ipt1 = tvm.nd.array(np_ipt1, devices[0])
+    ipt2 = tvm.nd.array(np_ipt2, devices[1])
+    ipt3 = tvm.nd.array(np_ipt3, devices[2])
+    res = vm["foo"](ipt0, ipt1, ipt2, ipt3)
+    tvm.testing.assert_allclose(res.numpy(), np_res)
+
+
[email protected]_gpu
+def test_multi_device():
+    @I.ir_module
+    class Example:
+        I.module_attrs({"attr": 10})
+        I.module_global_infos(
+            {
+                "vdevice": [
+                    I.vdevice("cuda", 0),
+                    I.vdevice("llvm"),
+                ]
+            }
+        )
+
+        @R.function
+        def foo(
+            x: R.Tensor((2, 3), "float32"),
+            y: R.Tensor((3, 4), "float32"),
+            z: R.Tensor((4, 5), "float32"),
+        ) -> R.Tensor((2, 5), "float32"):
+            with R.dataflow():
+                lv0: R.Tensor((2, 4), "float32", "llvm") = R.matmul(x, y)
+                lv1: R.Tensor((2, 4), "float32", "cuda") = R.to_vdevice(lv0, 
"cuda")
+                gv: R.Tensor((2, 5), "float32", "cuda") = R.matmul(lv1, z)
+                R.output(gv)
+            return gv
+
+    # The number and ordering of devices should be identical with the vdevice 
list
+    # defined in global_infos of ir_module
+    devices = [tvm.cuda(0), tvm.cpu(0)]
+    vm = compile(Example, devices)
+
+    np_ipt0 = np.random.rand(2, 3).astype(np.float32)
+    np_ipt1 = np.random.rand(3, 4).astype(np.float32)
+    np_ipt2 = np.random.rand(4, 5).astype(np.float32)
+    np_res = np.matmul(np.matmul(np_ipt0, np_ipt1), np_ipt2)
+
+    ipt0 = tvm.nd.array(np_ipt0, devices[1])
+    ipt1 = tvm.nd.array(np_ipt1, devices[1])
+    ipt2 = tvm.nd.array(np_ipt2, devices[0])
+    res = vm["foo"](ipt0, ipt1, ipt2)
+    tvm.testing.assert_allclose(res.numpy(), np_res, rtol=1e-4, atol=1e-4)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/tir-transform/test_transform_default_gpu_schedule.py 
b/tests/python/tir-transform/test_transform_default_gpu_schedule.py
index 1af846c9d5..63809beade 100644
--- a/tests/python/tir-transform/test_transform_default_gpu_schedule.py
+++ b/tests/python/tir-transform/test_transform_default_gpu_schedule.py
@@ -88,6 +88,49 @@ def test_matmul():
                         C[v_i, v_j] = T.float16(0)
                     C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]
 
+        @T.prim_func
+        def matmul_gpu(
+            A: T.Buffer((32, 32), "float16"),
+            B: T.Buffer((32, 32), "float16"),
+            C: T.Buffer((32, 32), "float16"),
+        ):
+            T.func_attr({"global_symbol": "main",
+                         "target": T.target({"arch": "sm_86",
+                                             "keys": ["cuda", "gpu"],
+                                             "kind": "cuda",
+                                             "max_num_threads": 1024,
+                                             "tag": "",
+                                             "thread_warp_size": 32}),
+                         "tir.noalias": True})
+            # with T.block("root"):
+            for i, j, k in T.grid(32, 32, 32):
+                with T.block("C"):
+                    v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
+                    T.reads(A[v_i, v_k], B[v_k, v_j])
+                    T.writes(C[v_i, v_j])
+                    with T.init():
+                        C[v_i, v_j] = T.float16(0)
+                    C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]
+
+        @T.prim_func
+        def matmul_cpu(
+            A: T.Buffer((32, 32), "float16"),
+            B: T.Buffer((32, 32), "float16"),
+            C: T.Buffer((32, 32), "float16"),
+        ):
+            T.func_attr({"global_symbol": "main",
+                         "target": T.target({"keys": ["cpu"], "kind": "llvm", "tag": ""}),
+                        "tir.noalias": True})
+            # with T.block("root"):
+            for i, j, k in T.grid(32, 32, 32):
+                with T.block("C"):
+                    v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
+                    T.reads(A[v_i, v_k], B[v_k, v_j])
+                    T.writes(C[v_i, v_j])
+                    with T.init():
+                        C[v_i, v_j] = T.float16(0)
+                    C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]
+
     @tvm.script.ir_module
     class Expected:
         @T.prim_func
@@ -114,6 +157,36 @@ def test_matmul():
                             with T.init():
                                 C[v_i, v_j] = T.float16(0)
                            C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]
+
+        @T.prim_func
+        def matmul_cpu(A: T.Buffer((32, 32), "float16"), B: T.Buffer((32, 32), "float16"), C: T.Buffer((32, 32), "float16")):
+            T.func_attr({"global_symbol": "main", "target": T.target({"keys": ["cpu"], "kind": "llvm", "tag": ""}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for i, j, k in T.grid(32, 32, 32):
+                with T.block("C"):
+                    v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
+                    T.reads(A[v_i, v_k], B[v_k, v_j])
+                    T.writes(C[v_i, v_j])
+                    with T.init():
+                        C[v_i, v_j] = T.float16(0)
+                    C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]
+
+        @T.prim_func
+        def matmul_gpu(A: T.Buffer((32, 32), "float16"), B: T.Buffer((32, 32), "float16"), C: T.Buffer((32, 32), "float16")):
+            T.func_attr({"global_symbol": "main", "target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for i_j_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
+                for i_j_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
+                    for k in range(32):
+                        with T.block("C"):
+                            v_i = T.axis.spatial(32, (i_j_fused_0 * 1024 + i_j_fused_1) // 32)
+                            v_j = T.axis.spatial(32, (i_j_fused_0 * 1024 + i_j_fused_1) % 32)
+                            v_k = T.axis.reduce(32, k)
+                            T.reads(A[v_i, v_k], B[v_k, v_j])
+                            T.writes(C[v_i, v_j])
+                            with T.init():
+                                C[v_i, v_j] = T.float16(0)
+                            C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]
     # fmt: on
     # pylint: enable=no-self-argument,missing-class-docstring,line-too-long
     target = tvm.target.Target("nvidia/geforce-rtx-3070")


Reply via email to