This is an automated email from the ASF dual-hosted git repository.
yongwww pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 1c35c39264 [Unity] Add Relax multi-device e2e cases (#15823)
1c35c39264 is described below
commit 1c35c392648e4336fc5e00ab91abb37af997cd59
Author: Yong Wu <[email protected]>
AuthorDate: Wed Dec 20 13:52:56 2023 -0800
[Unity] Add Relax multi-device e2e cases (#15823)
* [Unity] filter out non-GPU primfuncs in default_gpu_schedule
* Add relax heterogeneous e2e case
* Remove get_prim_func_device
* Update test cases
* Fix flake8
* fix lint
* Add test case for change of default_gpu_schedule
* fix comment
---
python/tvm/driver/build_module.py | 27 ++-
python/tvm/relax/utils.py | 26 ++-
python/tvm/relax/vm_build.py | 32 ++--
python/tvm/runtime/relax_vm.py | 7 +-
python/tvm/testing/utils.py | 20 +++
src/ir/module.cc | 2 +-
src/relax/transform/call_tir_rewrite.cc | 39 ++++-
src/relax/transform/legalize_ops.cc | 42 +++++
src/relax/transform/utils.h | 11 ++
src/runtime/relax_vm/vm.cc | 3 -
src/script/printer/relax/utils.h | 1 -
src/tir/transforms/default_gpu_schedule.cc | 49 ++++--
tests/python/relax/test_frontend_stablehlo.py | 4 +-
tests/python/relax/test_vm_multi_device.py | 186 +++++++++++++++++++++
.../test_transform_default_gpu_schedule.py | 73 ++++++++
15 files changed, 471 insertions(+), 51 deletions(-)
diff --git a/python/tvm/driver/build_module.py
b/python/tvm/driver/build_module.py
index 9389e7fbee..52303123c1 100644
--- a/python/tvm/driver/build_module.py
+++ b/python/tvm/driver/build_module.py
@@ -243,20 +243,33 @@ def build(
if not isinstance(inputs, (dict, container.Map)):
target = Target.current() if target is None else target
- target = target if target else "llvm"
- target_input_mod = {target: input_mod}
+ if target is None and isinstance(input_mod, tvm.IRModule):
+ target_mod = {}
+ for gvar, func in input_mod.functions.items():
+ tgt = func.attrs["target"] if func.attrs and "target" in
func.attrs else "llvm"
+ if tgt not in target_mod:
+ target_mod[tgt] = {}
+ target_mod[tgt][gvar] = func
+
+ target_input_mod = {}
+ for tgt in target_mod.keys():
+ tir_mod = tvm.IRModule(target_mod[tgt])
+ tir_mod.with_attrs(input_mod.attrs)
+ target_input_mod[tgt] = tir_mod
+ else:
+ target_input_mod = {target: input_mod}
else:
- target_input_mod = inputs
+ target_input_mod = {tgt: lower(mod) for tgt, mod in inputs.items()}
# Because modules can be created from a variety of sources, we annotate
them
# with the relevant attributes here to ensure they propagate
annotated_mods = {}
- for tar, mod in target_input_mod.items():
- if not isinstance(tar, (str, Target)):
+ for tgt, mod in target_input_mod.items():
+ if not isinstance(tgt, (str, Target)):
raise ValueError("The key of inputs must be str or " "Target when
inputs is dict.")
if not isinstance(mod, tvm.IRModule):
- raise ValueError("inputs must be Schedule, IRModule," "or dict of
str to IRModule.")
- annotated_mods[tar] = mod.with_attr("runtime", runtime)
+ raise ValueError("inputs must be Schedule, IRModule, " "or dict of
str to IRModule.")
+ annotated_mods[tgt] = mod.with_attr("runtime", runtime)
# TODO(mbs): Both CompilationConfig and TIRToRuntime implement the same
host target
# defaulting logic, but there's currently no way to get back the decided
host.
diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py
index a1fa9cafe8..b720a727f6 100644
--- a/python/tvm/relax/utils.py
+++ b/python/tvm/relax/utils.py
@@ -28,7 +28,7 @@ from . import _ffi_api
from .expr import Tuple as rx_Tuple
from .expr import Expr, ShapeExpr, Function, PrimValue, StringImm, te_tensor
from ..te import Tensor as te_Tensor, create_prim_func
-from ..ir import Array, Attrs, Type, Map
+from ..ir import Array, Attrs, Type, Map, VDevice
from .struct_info import PrimStructInfo, ShapeStructInfo, TensorStructInfo
@@ -418,6 +418,24 @@ def gen_call_tir_inputs(
diff = used_vars - bound_vars
return list(diff)
+ def _get_vdevice(arg: Any) -> Optional[VDevice]:
+ """get the virtual device from arguments."""
+ vdevice = None
+ if isinstance(arg, Expr): # type: ignore
+ if isinstance(arg.struct_info, TensorStructInfo):
+ vdevice = arg.struct_info.vdevice
+ elif isinstance(arg, (list, Array, tuple)):
+ for x in arg:
+ vdevice = _get_vdevice(x)
+ if vdevice is not None:
+ return vdevice
+ elif isinstance(arg, (dict, Map)):
+ for k in arg:
+ vdevice = _get_vdevice(arg[k])
+ if vdevice is not None:
+ return vdevice
+ return vdevice
+
def _shape_with_old_tir_var(
shape_values: List[tir.PrimExpr], tir_var_inverse_map: Dict[tir.Var,
tir.PrimExpr]
):
@@ -456,7 +474,11 @@ def gen_call_tir_inputs(
tir_var_inverse_map = {v: k for k, v in tir_var_map.items()}
output_sinfo = [
- TensorStructInfo(_shape_with_old_tir_var(out.shape,
tir_var_inverse_map), out.dtype)
+ TensorStructInfo(
+ _shape_with_old_tir_var(out.shape, tir_var_inverse_map),
+ out.dtype,
+ _get_vdevice(args),
+ )
for out in outs
]
diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py
index 62760b3417..9120f74e13 100644
--- a/python/tvm/relax/vm_build.py
+++ b/python/tvm/relax/vm_build.py
@@ -79,7 +79,7 @@ class Executable:
vm = tvm.relax.VirtualMachine(rt_mod, tvm.cuda())
"""
- # TODO(tvm-team): Update runtime.Module interfac
+ # TODO(tvm-team): Update runtime.Module interface
# to query these properties as bitmask.
def _not_runnable(x):
return x.type_key in ("c", "static_library")
@@ -179,13 +179,17 @@ def _vmcodegen(
raise ValueError(f"Unknown exec_mode {exec_mode}")
-def _autodetect_system_lib_req(target: tvm.target.Target, system_lib):
+def _autodetect_system_lib_req(
+ target: Optional[tvm.target.Target] = None, system_lib: Optional[bool] =
None
+):
"""Automatically detect system lib requirement"""
- host = target if target.host is None else target.host
- if system_lib is None:
- system_lib = False
- if "wasm" in host.attrs.get("mtriple", ""):
- system_lib = True
+ if target is not None:
+ host = target if target.host is None else target.host
+ if system_lib is None:
+ system_lib = False
+ if "wasm" in host.attrs.get("mtriple", ""):
+ system_lib = True
+
if system_lib:
# use packed-func to avoid relay dep.
return tvm.get_global_func("relay.backend.CreateRuntime")("cpp",
{"system-lib": system_lib})
@@ -194,7 +198,7 @@ def _autodetect_system_lib_req(target: tvm.target.Target,
system_lib):
def _vmlink(
builder: "relax.ExecBuilder",
- target: Union[str, tvm.target.Target],
+ target: Optional[Union[str, tvm.target.Target]],
tir_mod: Optional[tvm.IRModule] = None,
ext_libs: List[tvm.runtime.Module] = None,
params: Optional[Dict[str, list]] = None,
@@ -213,8 +217,10 @@ def _vmlink(
builder: relax.ExecBuilder
Builder used to collect executables.
- target : Union[str, tvm.target.Target]
+ target : Optional[Union[str, tvm.target.Target]]
A build target which can have optional host side compilation target.
+ If the target is not specified, the target in the vdevice list will be
used.
+ For multi-target compilation, the vdevice should be annotated.
tir_mod: IRModule
The input TIR IRModule to be linked together.
@@ -239,14 +245,16 @@ def _vmlink(
lib = None
if tir_mod is not None:
lib = tvm.build(
- tir_mod, target=target, runtime=_autodetect_system_lib_req(target,
system_lib)
+ tir_mod,
+ target=target,
+ runtime=_autodetect_system_lib_req(target, system_lib),
)
return Executable(_ffi_api.VMLink(builder, target, lib, ext_libs, params))
# type: ignore
def build(
mod: tvm.IRModule,
- target: Union[str, tvm.target.Target],
+ target: Optional[Union[str, tvm.target.Target]] = None,
params: Optional[Dict[str, list]] = None,
pipeline: Union[None, str, tvm.transform.Pass] = "default_build",
exec_mode: str = "bytecode",
@@ -261,7 +269,7 @@ def build(
mod: IRModule
The input IRModule to be built.
- target : Union[str, tvm.target.Target]
+ target : Optional[Union[str, tvm.target.Target]]
A build target which can have optional host side compilation target.
When TVM compiles device specific program such as CUDA,
diff --git a/python/tvm/runtime/relax_vm.py b/python/tvm/runtime/relax_vm.py
index a925e048b2..5b8bbe6d33 100644
--- a/python/tvm/runtime/relax_vm.py
+++ b/python/tvm/runtime/relax_vm.py
@@ -54,7 +54,7 @@ class VirtualMachine(object):
Parameters
----------
- mod: Union[tvm.runtime.Module, tvm.relax.Executable]
+ rt_mod: Union[tvm.runtime.Module, tvm.relax.Executable]
Runtime module exported by the result of build.
device : Union[Device, List[Device]]
@@ -107,11 +107,6 @@ class VirtualMachine(object):
)
devs = [dev]
- if any(dev.device_type % RPC_SESS_MASK == tvm.cpu().device_type for
dev in devs[:-1]):
- raise RuntimeError(
- "CPU host is required to be the last element of the device
list if provided."
- )
-
# CPU is required for executing shape functions
if devs[-1].device_type % RPC_SESS_MASK != tvm.cpu().device_type:
devs.append(tvm.cpu())
diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index 29c9463ba5..ccad989c33 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -832,6 +832,16 @@ def _any_gpu_exists():
)
+def _multi_gpu_exists():
+ return (
+ (tvm.cuda(0).exist and tvm.cuda(1).exist)
+ or (tvm.rocm(0).exist and tvm.rocm(1).exist)
+ or (tvm.opencl(0).exist and tvm.opencl(1).exist)
+ or (tvm.metal(0).exist and tvm.metal(1).exist)
+ or (tvm.vulkan(0).exist and tvm.vulkan(1).exist)
+ )
+
+
# Mark a test as requiring llvm to run
requires_llvm = Feature(
"llvm", "LLVM", cmake_flag="USE_LLVM", target_kind_enabled="llvm",
target_kind_hardware="llvm"
@@ -847,6 +857,16 @@ requires_gpu = Feature("gpu",
run_time_check=_any_gpu_exists)
# :py:func:`tvm.testing.requires_gpu`.
uses_gpu = requires_gpu(support_required="optional")
+# Mark a test as requiring multiple GPUs to run.
+requires_multi_gpu = Feature("multi_gpu", run_time_check=_multi_gpu_exists)
+
+# Mark to differentiate tests that use multiple GPUs in some capacity.
+#
+# These tests will be run on test nodes with multiple GPUs.
+# To mark a test that must have multiple GPUs present to run, use
+# :py:func:`tvm.testing.requires_multi_gpu`.
+uses_multi_gpu = requires_multi_gpu(support_required="optional")
+
# Mark a test as requiring the x86 Architecture to run.
requires_x86 = Feature(
"x86", "x86 Architecture", run_time_check=lambda: platform.machine() ==
"x86_64"
diff --git a/src/ir/module.cc b/src/ir/module.cc
index c016612c15..156158a85f 100644
--- a/src/ir/module.cc
+++ b/src/ir/module.cc
@@ -324,7 +324,7 @@ void IRModuleNode::Update(const IRModule& mod) {
IRModule IRModuleNode::ShallowCopy() {
return IRModule(this->functions, this->type_definitions, this->Imports(),
this->source_map,
- this->attrs);
+ this->attrs, this->global_infos);
}
std::pair<IRModule, GlobalVar> IRModule::FromExprInContext(
diff --git a/src/relax/transform/call_tir_rewrite.cc
b/src/relax/transform/call_tir_rewrite.cc
index e040ccea14..760d04a220 100644
--- a/src/relax/transform/call_tir_rewrite.cc
+++ b/src/relax/transform/call_tir_rewrite.cc
@@ -18,7 +18,8 @@
*/
/*!
* \file src/relax/transform/call_tir_rewrite.cc
- * \brief Perform explicit tensor allocation for call_tir.
+ * \brief Perform explicit tensor allocation for call_tir,
+ * call_tir_inplace, and call_dps_packed.
*/
#include <tvm/relax/attrs/op.h>
#include <tvm/relax/expr_functor.h>
@@ -28,6 +29,7 @@
#include <tvm/tir/op.h>
#include "../../relay/transforms/pattern_utils.h"
+#include "utils.h"
namespace tvm {
namespace relax {
@@ -43,6 +45,19 @@ namespace relax {
class CallTIRMutator : public ExprMutator {
public:
+ explicit CallTIRMutator(const IRModule& mod) : ExprMutator(mod),
mod_(std::move(mod)) {}
+
+ IRModule Run() {
+ for (const auto& [gv, func] : mod_->functions) {
+ if (func->IsInstance<FunctionNode>()) {
+ auto updated_func = Downcast<Function>(this->VisitExpr(func));
+ builder_->UpdateFunction(gv, Downcast<BaseFunc>(updated_func));
+ }
+ }
+ return builder_->GetContextIRModule();
+ }
+
+ private:
using ExprMutator::VisitExpr_;
Expr VisitExpr_(const CallNode* call) override {
// post-order mutation
@@ -65,11 +80,15 @@ class CallTIRMutator : public ExprMutator {
const TensorStructInfo& tensor_sinfo = _tensor_sinfo.value();
ICHECK(tensor_sinfo->shape.defined())
<< "the TensorStructInfo shape of call_tir has not populated";
+ int dev_index = 0;
+ if (tensor_sinfo->vdevice.defined()) {
+ dev_index = GetDeviceIndex(mod_, tensor_sinfo->vdevice.value());
+ }
if (!is_inplace) {
outs.push_back(
- builder_->Emit(Call(alloc_tensor_op, //
+ builder_->Emit(Call(alloc_tensor_op,
{Downcast<ShapeExpr>(tensor_sinfo->shape.value()),
- DataTypeImm(tensor_sinfo->dtype),
PrimValue::Int64(0)}, //
+ DataTypeImm(tensor_sinfo->dtype),
PrimValue::Int64(dev_index)},
Attrs()),
"alloc"));
} else {
@@ -150,16 +169,20 @@ class CallTIRMutator : public ExprMutator {
return GetRef<Expr>(call);
}
-};
-Expr CallTIRRewrite(const Expr& e) { return CallTIRMutator().VisitExpr(e); }
+ /*! \brief The context IRModule. */
+ IRModule mod_;
+};
namespace transform {
Pass CallTIRRewrite() {
- runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>
pass_func =
- [=](Function f, IRModule m, PassContext pc) { return
Downcast<Function>(CallTIRRewrite(f)); };
- return CreateFunctionPass(pass_func, 0, "CallTIRRewrite", {});
+ runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+ [=](IRModule mod, PassContext pc) { return CallTIRMutator(mod).Run(); };
+ return CreateModulePass(/*pass_function=*/pass_func,
+ /*opt_level=*/0,
+ /*pass_name=*/"CallTIRRewrite",
+ /*required=*/{});
}
TVM_REGISTER_GLOBAL("relax.transform.CallTIRRewrite").set_body_typed(CallTIRRewrite);
diff --git a/src/relax/transform/legalize_ops.cc
b/src/relax/transform/legalize_ops.cc
index a557a41f8e..c8fba59dab 100644
--- a/src/relax/transform/legalize_ops.cc
+++ b/src/relax/transform/legalize_ops.cc
@@ -26,6 +26,7 @@
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/op_attr_types.h>
+#include <tvm/relax/struct_info.h>
#include <tvm/relax/transform.h>
namespace tvm {
@@ -72,6 +73,14 @@ class LegalizeMutator : public ExprMutator {
builder_->UpdateFunction(gv, Downcast<BaseFunc>(updated_func));
}
}
+ // Fill the "kTarget" attribute of PrimFunc
+ for (const auto& [gv, func] : builder_->GetContextIRModule()->functions) {
+ const tir::PrimFuncNode* prim_func;
+ if (tmap_.count(gv) && (prim_func = func.as<tir::PrimFuncNode>())) {
+ auto f = WithAttr(GetRef<tir::PrimFunc>(prim_func),
tvm::attr::kTarget, tmap_[gv]);
+ builder_->UpdateFunction(gv, f);
+ }
+ }
return builder_->GetContextIRModule();
}
@@ -109,6 +118,33 @@ class LegalizeMutator : public ExprMutator {
return Call(call_pure_packed_op, ret_args, ret->attrs, ret->sinfo_args);
}
+ Target GetTarget(const Array<StructInfo>& sinfos) {
+ for (auto sinfo : sinfos) {
+ if (const auto* tinfo = sinfo.as<TensorStructInfoNode>()) {
+ if (tinfo->vdevice.defined()) {
+ auto vdevice = tinfo->vdevice.value();
+ if (vdevice->target.defined()) {
+ return vdevice->target;
+ }
+ }
+ } else if (const auto* tup_sinfo = sinfo.as<TupleStructInfoNode>()) {
+ return GetTarget(tup_sinfo->fields);
+ }
+ }
+ return Target();
+ }
+
+ void SaveTarget(const Expr& expr) {
+ if (expr->IsInstance<CallNode>()) {
+ auto call = Downcast<Call>(expr);
+ auto target = GetTarget(call->sinfo_args);
+ const GlobalVarNode* gvar_node;
+ if (target.defined() && (gvar_node = call->args[0].as<GlobalVarNode>()))
{
+ this->tmap_.Set(GetRef<GlobalVar>(gvar_node), target);
+ }
+ }
+ }
+
Expr VisitExpr_(const CallNode* call) final {
Call visited_call = Downcast<Call>(this->VisitExprPostOrder_(call));
static const auto& legalize_map = Op::GetAttrMap<FLegalize>("FLegalize");
@@ -164,6 +200,10 @@ class LegalizeMutator : public ExprMutator {
builder_->BeginBindingBlock();
}
Expr legalized = legalization_func(builder_, visited_call);
+
+ // Save the expected target info. into tmap_
+ SaveTarget(legalized);
+
legalized = builder_->Normalize(legalized);
BindingBlock prologue = builder_->EndBlock();
@@ -196,6 +236,8 @@ class LegalizeMutator : public ExprMutator {
IRModule mod_;
/*! \brief The customized legalization function map. */
Map<String, PackedFunc> cmap_;
+ /*! \brief The map from GlobalVar of PrimFunc to compilation Target. */
+ Map<GlobalVar, Target> tmap_;
/*!
* \brief A boolean value indicating if to print warnings for CallNode whose
op's
* legalization function is not registered.
diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h
index 2226e62763..8b3525c628 100644
--- a/src/relax/transform/utils.h
+++ b/src/relax/transform/utils.h
@@ -383,6 +383,17 @@ inline String GetCodegenName(const std::string&
composite_name) {
return composite_name.substr(0, delim_pos);
}
+inline int GetDeviceIndex(const IRModule& mod, const VDevice& vdevice) {
+ Array<GlobalInfo> vdevices = mod->global_infos["vdevice"];
+ for (int i = 0; i < static_cast<int>(vdevices.size()); ++i) {
+ if (vdevices[i] == vdevice) {
+ return i;
+ }
+ }
+ LOG(FATAL) << "The vdevice is not in the ir_module.";
+ return -1;
+}
+
/* \brief Eliminate common subexpressions
*
* Utility for simplifying relax expressions by removing common
diff --git a/src/runtime/relax_vm/vm.cc b/src/runtime/relax_vm/vm.cc
index b31268e697..d7f943d5f4 100644
--- a/src/runtime/relax_vm/vm.cc
+++ b/src/runtime/relax_vm/vm.cc
@@ -440,9 +440,6 @@ void
VirtualMachineImpl::LoadExecutable(ObjectPtr<Executable> exec) {
void VirtualMachineImpl::Init(const std::vector<Device>& devices,
const std::vector<AllocatorType>& alloc_types) {
- // TODO(@yuchen): support multi-device heterogeneous execution
- ICHECK_LT(devices.size(), 3)
- << "Currently relax vm only supports at most 2 devices (host + device)";
ICHECK_EQ(devices.size(), alloc_types.size());
this->devices.reserve(devices.size());
diff --git a/src/script/printer/relax/utils.h b/src/script/printer/relax/utils.h
index e0b5348d73..58b8bf4431 100644
--- a/src/script/printer/relax/utils.h
+++ b/src/script/printer/relax/utils.h
@@ -109,7 +109,6 @@ inline int FindVDeviceIndexByTargetKind(const VDevice&
vdevice, const IRDocsifie
kind_index++;
}
}
- LOG(WARNING) << "The VDevice was not found in the global_infos map: " <<
vdevice;
return -1;
}
diff --git a/src/tir/transforms/default_gpu_schedule.cc
b/src/tir/transforms/default_gpu_schedule.cc
index 5a22d0b0d9..6cf7f6e067 100644
--- a/src/tir/transforms/default_gpu_schedule.cc
+++ b/src/tir/transforms/default_gpu_schedule.cc
@@ -98,24 +98,53 @@ IRModule MarkScheduled(const IRModule& mod) {
mod->type_definitions, // type_definitions
mod->import_set_, // import_set
mod->source_map, // map
- mod->attrs); // attrs);
+ mod->attrs, // attrs
+ mod->global_infos); // global_infos
+}
+
+bool IsScheduledOnGPU(const BaseFunc& func) {
+ // the target from context.
+ tvm::Target target = tvm::Target::Current();
+ // the Target in kTarget attribute of PrimFunc
+ Optional<tvm::Target> func_target =
func->attrs.GetAttr<tvm::Target>(tvm::attr::kTarget);
+ if (func_target.defined()) {
+ target = func_target.value();
+ }
+
+ if (target.defined()) {
+ int dev_type = target->GetTargetDeviceType();
+ if (dev_type != kDLCUDA) {
+ return false;
+ }
+ }
+ return true;
}
Pass DefaultGPUSchedule() {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = //
[=](IRModule m, PassContext pc) {
- // get the target from context.
- tvm::Target target = tvm::Target::Current();
- ICHECK(target.defined()) << "Target is not set in current context";
- // get the max thread per block from target.
- Optional<Integer> opt_max_thread_per_block =
target->GetAttr<Integer>("max_num_threads");
- ICHECK(opt_max_thread_per_block.defined())
- << "max_num_threads is not set for target " << target;
- int64_t max_thread_per_block =
opt_max_thread_per_block.value().IntValue();
tir::Schedule sch = tir::Schedule::Traced(m, /*seed=*/-1,
/*debug_mask=*/0,
tir::ScheduleErrorRenderLevel::kDetail);
for (const auto& [gv, func] : m->functions) {
- if (func->IsInstance<tir::PrimFuncNode>() &&
!func->HasNonzeroAttr(attr::kIsScheduled)) {
+ if (func->IsInstance<tir::PrimFuncNode>() &&
!func->HasNonzeroAttr(attr::kIsScheduled) &&
+ IsScheduledOnGPU(func)) {
+ // get the target from context.
+ tvm::Target target = tvm::Target::Current();
+ // get the target from kTarget attribute
+ Optional<tvm::Target> func_target =
+ func->attrs.GetAttr<tvm::Target>(tvm::attr::kTarget);
+ if (func_target.defined()) {
+ target = func_target.value();
+ }
+ ICHECK(target.defined()) << "The target is missing either in the
current context or in "
+ "the prim_func's attribute.";
+ // get the max thread per block from target.
+ Optional<Integer> opt_max_thread_per_block =
+ target->GetAttr<Integer>("max_num_threads");
+ ICHECK(opt_max_thread_per_block.defined())
+ << "max_num_threads is not set for target " << target;
+ int64_t max_thread_per_block =
opt_max_thread_per_block.value().IntValue();
+
sch->WorkOn(gv->name_hint);
Array<tir::BlockRV> blocks =
meta_schedule::BlockCollector::Collect(sch);
for (const tir::BlockRV& block : blocks) {
diff --git a/tests/python/relax/test_frontend_stablehlo.py
b/tests/python/relax/test_frontend_stablehlo.py
index d3068f29c7..f2d0461dda 100644
--- a/tests/python/relax/test_frontend_stablehlo.py
+++ b/tests/python/relax/test_frontend_stablehlo.py
@@ -132,7 +132,7 @@ def check_correctness(
# Multiple ouputs
assert len(tvm_output) == len(jax_output), "numbers of outputs mismatch"
- for (tvm_out, jax_out) in zip(tvm_output, jax_output):
+ for tvm_out, jax_out in zip(tvm_output, jax_output):
tvm.testing.assert_allclose(tvm_out.numpy(), jax_out, rtol=1e-5,
atol=1e-5)
@@ -314,7 +314,9 @@ def test_dot_general():
check_correctness(jax.jit(fn), input_shapes)
[email protected]()
@tvm.testing.requires_gpu
+# TODO(yongwww): fix flaky error of "invalid device ordinal"
def test_conv():
import jax
from flax import linen as nn
diff --git a/tests/python/relax/test_vm_multi_device.py
b/tests/python/relax/test_vm_multi_device.py
new file mode 100644
index 0000000000..ec2fbd1cdf
--- /dev/null
+++ b/tests/python/relax/test_vm_multi_device.py
@@ -0,0 +1,186 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test multi-device execution with the Relax virtual machine"""
+from typing import List
+import tvm
+from tvm import relax
+import tvm.testing
+from tvm.ir.module import IRModule
+from tvm.script.parser import ir as I, relax as R
+from tvm._ffi.runtime_ctypes import Device
+import numpy as np
+
+
+def compile(
+ mod: IRModule,
+ device: List[Device] = [
+ tvm.cpu(),
+ ],
+) -> relax.VirtualMachine:
+ # compile the model
+ mod = relax.transform.RealizeVDevice()(mod)
+ mod = relax.transform.LegalizeOps()(mod)
+ mod = tvm.tir.transform.DefaultGPUSchedule()(mod)
+ # no need to feed target argument for multi-target compilation
+ ex = relax.build(mod)
+
+ return relax.VirtualMachine(ex, device)
+
+
+def test_multi_cpu():
+ @I.ir_module
+ class Example:
+ I.module_attrs({"attr": 10})
+ I.module_global_infos(
+ {
+ "vdevice": [
+ I.vdevice("llvm", 0),
+ I.vdevice("llvm", 1),
+ ]
+ }
+ )
+
+ @R.function
+ def foo(
+ x: R.Tensor((2, 3), "float32"),
+ y: R.Tensor((3, 4), "float32"),
+ z: R.Tensor((4, 5), "float32"),
+ ) -> R.Tensor((2, 5), "float32"):
+ with R.dataflow():
+ lv0: R.Tensor((2, 4), "float32", "llvm:0") = R.matmul(x, y) #
noqa: F722
+ lv1: R.Tensor((2, 4), "float32", "llvm:1") = R.to_vdevice( #
noqa: F722
+ lv0, "llvm:1" # noqa: F722
+ )
+ gv = R.matmul(lv1, z) # noqa: F722
+ R.output(gv)
+ return gv
+
+ devices = [tvm.cpu(0), tvm.cpu(1)]
+ vm = compile(Example, devices)
+
+ np_ipt0 = np.random.rand(2, 3).astype(np.float32)
+ np_ipt1 = np.random.rand(3, 4).astype(np.float32)
+ np_ipt2 = np.random.rand(4, 5).astype(np.float32)
+ np_res = np.matmul(np.matmul(np_ipt0, np_ipt1), np_ipt2)
+
+ ipt0 = tvm.nd.array(np_ipt0, devices[0])
+ ipt1 = tvm.nd.array(np_ipt1, devices[0])
+ ipt2 = tvm.nd.array(np_ipt2, devices[1])
+ res = vm["foo"](ipt0, ipt1, ipt2)
+ tvm.testing.assert_allclose(res.numpy(), np_res)
+
+
[email protected]_multi_gpu
+def test_multi_gpu():
+ @I.ir_module
+ class Example:
+ I.module_attrs({"attr": 10})
+ I.module_global_infos(
+ {
+ "vdevice": [
+ I.vdevice("cuda", 1),
+ I.vdevice("cuda", 0),
+ I.vdevice("cuda", 2),
+ ]
+ }
+ )
+
+ @R.function
+ def foo(
+ a: R.Tensor((2, 3), "float32"),
+ b: R.Tensor((3, 4), "float32"),
+ c: R.Tensor((4, 5), "float32"),
+ d: R.Tensor((5, 6), "float32"),
+ ) -> R.Tensor((2, 6), "float32"):
+ with R.dataflow():
+ lv0: R.Tensor((2, 4), "float32", "cuda:0") = R.matmul(a, b) #
noqa: F722
+ lv1: R.Tensor((2, 4), "float32", "cuda:1") = R.to_vdevice( #
noqa: F722
+ lv0, "cuda:1" # noqa: F722
+ )
+ lv2: R.Tensor((2, 5), "float32", "cuda:1") = R.matmul(lv1, c)
# noqa: F722
+ lv3: R.Tensor((2, 5), "float32", "cuda:2") = R.to_vdevice( #
noqa: F722
+ lv2, "cuda:2" # noqa: F722
+ )
+ gv: R.Tensor((2, 6), "float32", "cuda:2") = R.matmul(lv3, d)
# noqa: F722
+ R.output(gv)
+ return gv
+
+ # The number and ordering of devices should be identical with the vdevice
list
+ # defined in global_infos of ir_module
+ devices = [tvm.cuda(1), tvm.cuda(0), tvm.cuda(2)]
+ vm = compile(Example, devices)
+
+ np_ipt0 = np.random.rand(2, 3).astype(np.float32)
+ np_ipt1 = np.random.rand(3, 4).astype(np.float32)
+ np_ipt2 = np.random.rand(4, 5).astype(np.float32)
+ np_ipt3 = np.random.rand(5, 6).astype(np.float32)
+ np_res = np.matmul(np.matmul(np.matmul(np_ipt0, np_ipt1), np_ipt2),
np_ipt3)
+
+ ipt0 = tvm.nd.array(np_ipt0, devices[0])
+ ipt1 = tvm.nd.array(np_ipt1, devices[0])
+ ipt2 = tvm.nd.array(np_ipt2, devices[1])
+ ipt3 = tvm.nd.array(np_ipt3, devices[2])
+ res = vm["foo"](ipt0, ipt1, ipt2, ipt3)
+ tvm.testing.assert_allclose(res.numpy(), np_res)
+
+
[email protected]_gpu
+def test_multi_device():
+ @I.ir_module
+ class Example:
+ I.module_attrs({"attr": 10})
+ I.module_global_infos(
+ {
+ "vdevice": [
+ I.vdevice("cuda", 0),
+ I.vdevice("llvm"),
+ ]
+ }
+ )
+
+ @R.function
+ def foo(
+ x: R.Tensor((2, 3), "float32"),
+ y: R.Tensor((3, 4), "float32"),
+ z: R.Tensor((4, 5), "float32"),
+ ) -> R.Tensor((2, 5), "float32"):
+ with R.dataflow():
+ lv0: R.Tensor((2, 4), "float32", "llvm") = R.matmul(x, y)
+ lv1: R.Tensor((2, 4), "float32", "cuda") = R.to_vdevice(lv0,
"cuda")
+ gv: R.Tensor((2, 5), "float32", "cuda") = R.matmul(lv1, z)
+ R.output(gv)
+ return gv
+
+ # The number and ordering of devices should be identical with the vdevice
list
+ # defined in global_infos of ir_module
+ devices = [tvm.cuda(0), tvm.cpu(0)]
+ vm = compile(Example, devices)
+
+ np_ipt0 = np.random.rand(2, 3).astype(np.float32)
+ np_ipt1 = np.random.rand(3, 4).astype(np.float32)
+ np_ipt2 = np.random.rand(4, 5).astype(np.float32)
+ np_res = np.matmul(np.matmul(np_ipt0, np_ipt1), np_ipt2)
+
+ ipt0 = tvm.nd.array(np_ipt0, devices[1])
+ ipt1 = tvm.nd.array(np_ipt1, devices[1])
+ ipt2 = tvm.nd.array(np_ipt2, devices[0])
+ res = vm["foo"](ipt0, ipt1, ipt2)
+ tvm.testing.assert_allclose(res.numpy(), np_res, rtol=1e-4, atol=1e-4)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/python/tir-transform/test_transform_default_gpu_schedule.py
b/tests/python/tir-transform/test_transform_default_gpu_schedule.py
index 1af846c9d5..63809beade 100644
--- a/tests/python/tir-transform/test_transform_default_gpu_schedule.py
+++ b/tests/python/tir-transform/test_transform_default_gpu_schedule.py
@@ -88,6 +88,49 @@ def test_matmul():
C[v_i, v_j] = T.float16(0)
C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]
+ @T.prim_func
+ def matmul_gpu(
+ A: T.Buffer((32, 32), "float16"),
+ B: T.Buffer((32, 32), "float16"),
+ C: T.Buffer((32, 32), "float16"),
+ ):
+ T.func_attr({"global_symbol": "main",
+ "target": T.target({"arch": "sm_86",
+ "keys": ["cuda", "gpu"],
+ "kind": "cuda",
+ "max_num_threads": 1024,
+ "tag": "",
+ "thread_warp_size": 32}),
+ "tir.noalias": True})
+ # with T.block("root"):
+ for i, j, k in T.grid(32, 32, 32):
+ with T.block("C"):
+ v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
+ T.reads(A[v_i, v_k], B[v_k, v_j])
+ T.writes(C[v_i, v_j])
+ with T.init():
+ C[v_i, v_j] = T.float16(0)
+ C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]
+
+ @T.prim_func
+ def matmul_cpu(
+ A: T.Buffer((32, 32), "float16"),
+ B: T.Buffer((32, 32), "float16"),
+ C: T.Buffer((32, 32), "float16"),
+ ):
+ T.func_attr({"global_symbol": "main",
+ "target": T.target({"keys": ["cpu"], "kind": "llvm",
"tag": ""}),
+ "tir.noalias": True})
+ # with T.block("root"):
+ for i, j, k in T.grid(32, 32, 32):
+ with T.block("C"):
+ v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
+ T.reads(A[v_i, v_k], B[v_k, v_j])
+ T.writes(C[v_i, v_j])
+ with T.init():
+ C[v_i, v_j] = T.float16(0)
+ C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]
+
@tvm.script.ir_module
class Expected:
@T.prim_func
@@ -114,6 +157,36 @@ def test_matmul():
with T.init():
C[v_i, v_j] = T.float16(0)
C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k,
v_j]
+
+ @T.prim_func
+ def matmul_cpu(A: T.Buffer((32, 32), "float16"), B: T.Buffer((32, 32),
"float16"), C: T.Buffer((32, 32), "float16")):
+ T.func_attr({"global_symbol": "main", "target": T.target({"keys":
["cpu"], "kind": "llvm", "tag": ""}), "tir.is_scheduled": T.bool(True),
"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ for i, j, k in T.grid(32, 32, 32):
+ with T.block("C"):
+ v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
+ T.reads(A[v_i, v_k], B[v_k, v_j])
+ T.writes(C[v_i, v_j])
+ with T.init():
+ C[v_i, v_j] = T.float16(0)
+ C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]
+
+ @T.prim_func
+ def matmul_gpu(A: T.Buffer((32, 32), "float16"), B: T.Buffer((32, 32),
"float16"), C: T.Buffer((32, 32), "float16")):
+ T.func_attr({"global_symbol": "main", "target": T.target({"arch":
"sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024,
"tag": "", "thread_warp_size": 32}), "tir.is_scheduled": T.bool(True),
"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ for i_j_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
+ for i_j_fused_1 in T.thread_binding(1024,
thread="threadIdx.x"):
+ for k in range(32):
+ with T.block("C"):
+ v_i = T.axis.spatial(32, (i_j_fused_0 * 1024 +
i_j_fused_1) // 32)
+ v_j = T.axis.spatial(32, (i_j_fused_0 * 1024 +
i_j_fused_1) % 32)
+ v_k = T.axis.reduce(32, k)
+ T.reads(A[v_i, v_k], B[v_k, v_j])
+ T.writes(C[v_i, v_j])
+ with T.init():
+ C[v_i, v_j] = T.float16(0)
+ C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k,
v_j]
# fmt: on
# pylint: enable=no-self-argument,missing-class-docstring,line-too-long
target = tvm.target.Target("nvidia/geforce-rtx-3070")