icemelon9 commented on a change in pull request #6337:
URL: https://github.com/apache/incubator-tvm/pull/6337#discussion_r479448619
##########
File path: python/tvm/runtime/vm.py
##########
@@ -307,8 +307,14 @@ def __init__(self, exe, ctx, memory_cfg=None):
def _setup_ctx(self, ctx, memory_cfg):
"""Init context and allocators."""
- if isinstance(ctx, tvm.runtime.TVMContext):
- ctx = [ctx]
+ ctxs = ctx
+ if not isinstance(ctx, (list, tuple)):
+ assert isinstance(ctx, tvm.runtime.TVMContext)
Review comment:
Please add an error message to this assert, e.g. explaining that `ctx` must be a `tvm.runtime.TVMContext` or a list/tuple of contexts.
##########
File path: src/runtime/vm/executable.cc
##########
@@ -631,9 +653,10 @@ Instruction DeserializeInstruction(const
VMInstructionSerializer& instr) {
dtype.bits = instr.fields[3];
dtype.lanes = instr.fields[4];
- RegName dst = instr.fields[5];
+ Index device_type = instr.fields[5];
+ RegName dst = instr.fields[6];
Review comment:
Please also update the number-of-fields check on line 647, since this instruction now carries an extra `device_type` field.
##########
File path: src/runtime/vm/vm.cc
##########
@@ -68,8 +68,17 @@ inline ObjectRef CopyTo(ObjectRef src, const DLContext& ctx)
{
if (nd_array->ctx.device_type != ctx.device_type) {
return nd_array.CopyTo(ctx);
}
+ return src;
+ } else {
+ CHECK(src->IsInstance<ADTObj>())
+ << "VM data must be NDArray or a list of NDArray, but received: " <<
src->_type_key;
+ std::vector<ObjectRef> ret;
+ ADT adt = Downcast<ADT>(src);
+ for (size_t i = 0; i < adt.size(); i++) {
+ ret.push_back(CopyTo(adt[i], ctx));
+ }
+ return ADT(0, ret.begin(), ret.end());
Review comment:
Why not use `adt->tag` here instead of hard-coding the tag as `0`? Hard-coding it drops the original ADT's tag when copying.
##########
File path: python/tvm/runtime/vm.py
##########
@@ -307,8 +307,14 @@ def __init__(self, exe, ctx, memory_cfg=None):
def _setup_ctx(self, ctx, memory_cfg):
"""Init context and allocators."""
- if isinstance(ctx, tvm.runtime.TVMContext):
- ctx = [ctx]
+ ctxs = ctx
+ if not isinstance(ctx, (list, tuple)):
+ assert isinstance(ctx, tvm.runtime.TVMContext)
+ ctxs = [ctx]
+ # CPU is required for executing shape functions
+ if ctx.device_type != tvm.cpu(0).device_type:
Review comment:
This should probably check all of the contexts in `ctxs` to see whether any of them is a CPU context, rather than only checking the single `ctx`.
##########
File path: src/runtime/vm/vm.cc
##########
@@ -164,18 +178,15 @@ PackedFunc VirtualMachine::GetFunction(const std::string&
name,
}
}
-TVMContext VirtualMachine::GetParamsContext() const {
+TVMContext VirtualMachine::GetContext(Index device_type) const {
Review comment:
Similarly for this function: it should look up the context via a device-type-to-context map built in `Init`, rather than searching through `ctxs_`.
##########
File path: src/runtime/vm/vm.cc
##########
@@ -146,12 +155,17 @@ PackedFunc VirtualMachine::GetFunction(const std::string&
name,
auto func_index = gvit->second;
const auto& vm_func = exec_->functions[func_index];
const auto& param_names = vm_func.params;
- // TODO(icemelon9): For heterogeneous execution, get input device
information
- TVMContext ctx = ctxs_[0];
CHECK_EQ(args.size() - 1, param_names.size())
<< "The number of provided parameters doesn't match the number of
arguments";
+ CHECK_EQ(param_names.size(), vm_func.params_device_type.size())
+ << "The number of provided parameters doesn't match the number of
assigned devices";
std::vector<ObjectRef> func_args(param_names.size());
for (int i = 1; i < args.size(); ++i) {
+ TVMContext ctx;
+ int device_type = vm_func.params_device_type[i - 1];
+ ctx.device_type = DLDeviceType(device_type);
Review comment:
We should create a map from device type to context in `Init`, so that here we can simply look up the corresponding context instead of constructing a new `TVMContext` by hand.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]