While switching to TVMC, I noticed a `virtual_device` property on the top-level
Relay module function. It was not properly propagated through my Relay passes,
which caused an assertion failure when lowering to TE:
Check failed: (!virtual_device->IsFullyUnconstrained()) is false
at:
```
File "/home/user1/mlenv/deps/src/tvm/python/tvm/driver/tvmc/__main__.py",
line 24, in <module>
tvmc.main.main()
File "/home/user1/mlenv/deps/src/tvm/python/tvm/driver/tvmc/main.py", line
115, in main
sys.exit(_main(sys.argv[1:]))
File "/home/user1/mlenv/deps/src/tvm/python/tvm/driver/tvmc/main.py", line
103, in _main
return args.func(args)
File "/home/user1/mlenv/deps/src/tvm/python/tvm/driver/tvmc/compiler.py",
line 173, in drive_compile
compile_model(
File "/home/user1/mlenv/deps/src/tvm/python/tvm/driver/tvmc/compiler.py",
line 337, in compile_model
graph_module = build(
File "/home/user1/mlenv/deps/src/tvm/python/tvm/driver/tvmc/compiler.py",
line 410, in build
return relay.build(
File "/home/user1/mlenv/deps/src/tvm/python/tvm/relay/build_module.py", line
431, in build
graph_json, runtime_mod, params = bld_mod.build(
File "/home/user1/mlenv/deps/src/tvm/python/tvm/relay/build_module.py", line
154, in build
self._build(mod, raw_targets, executor, runtime, workspace_memory_pools,
mod_name)
File "/home/user1/mlenv/deps/src/tvm/python/tvm/_ffi/_ctypes/packed_func.py",
line 237, in __call__
raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
29: TVMFuncCall
28:
tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::backend::RelayBuildModule::GetFunction(std::__cxx11::basic_string<char,
std::char_traits<char>, std::allocator<char> > const&,
tvm::runtime::ObjectPtr<tvm::runtime::Object>
const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#3}>
>::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs,
tvm::runtime::TVMRetValue*)
27: tvm::relay::backend::RelayBuildModule::BuildRelay(tvm::IRModule,
tvm::runtime::String const&)
26:
tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::backend::AOTExecutorCodegenModule::GetFunction(std::__cxx11::basic_string<char,
std::char_traits<char>, std::allocator<char> > const&,
tvm::runtime::ObjectPtr<tvm::runtime::Object>
const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}>
>::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs,
tvm::runtime::TVMRetValue*)
25: tvm::relay::backend::AOTExecutorCodegen::Codegen(tvm::IRModule,
tvm::relay::Function, tvm::runtime::String)
24: tvm::transform::Pass::operator()(tvm::IRModule) const
23: tvm::transform::Pass::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const
22: tvm::transform::SequentialNode::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const
21: tvm::transform::Pass::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const
20: tvm::transform::ModulePassNode::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const
19:
_ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_8IRModuleES5_NS_9transform11PassContextEEE17AssignTypedLambdaIZNS_5relay3tec7LowerTEENS0_6StringENS_17CompilationConfigESt8functionIFvNS_8BaseFuncEEEEUlS5_S7_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SL_SP_
18: tvm::relay::tec::LowerTE(tvm::IRModule const&, tvm::runtime::String
const&, std::function<void (tvm::BaseFunc)>, tvm::CompilationConfig)
17: tvm::transform::Pass::operator()(tvm::IRModule) const
16: tvm::transform::Pass::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const
15: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const
14:
_ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_5relay8FunctionES6_NS_8IRModuleENS_9transform11PassContextEEE17AssignTypedLambdaIZNS5_3tec15LowerTensorExprERKNS0_6StringENSD_10TECompilerESt8functionIFvNS_8BaseFuncEEENS_17CompilationConfigEEUlS6_S7_S9_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SP_ST_
13: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
12: _ZZN3tvm5relay11ExprFuncto
11:
tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::FunctionNode
const*)
10:
tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::FunctionNode
const*)
9: _ZN3tvm5relay9tr
8: tvm::relay::ExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)
7: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
6: _ZZN3tvm5relay11ExprFuncto
5:
tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::LetNode
const*)
4:
tvm::relay::tec::LowerTensorExprMutator::PreVisitLetBinding_(tvm::relay::Var
const&, tvm::RelayExpr const&)
3: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
2: _ZZN3tvm5relay11ExprFuncto
1:
tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::CallNode
const*)
0:
tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::CallNode
const*)
File "/home/user1/mlenv/deps/src/tvm/src/relay/backend/te_compiler.cc", line
885
```
I noticed that this property is sometimes copied over manually when a new copy
of a function is created, for example:
https://github.com/apache/tvm/blob/308d320a66f16abf67c5daf4ae58cec3567decdd/src/relay/ir/expr_functor.cc#L492
However, this is not done consistently, and I had to patch the following places
to get compilation working again:
```
diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc
index 1a16cc9be..d05a30626 100644
--- a/src/relay/ir/transform.cc
+++ b/src/relay/ir/transform.cc
@@ -131,6 +131,7 @@ IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx)
// only process optimizable Relay Functions
if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) {
Function updated_func = pass_func(GetRef<Function>(function_node),
updated_mod, pass_ctx);
+ updated_func->virtual_device_ = GetRef<Function>(function_node)->virtual_device();
updates.push_back({kv.first, std::move(updated_func)});
}
}
diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py
index b9ca7d0e1..889031ed4 100644
--- a/python/tvm/relay/expr_functor.py
+++ b/python/tvm/relay/expr_functor.py
@@ -204,7 +204,10 @@ class ExprMutator(ExprFunctor):
def visit_function(self, fn):
new_params = [self.visit(x) for x in fn.params]
new_body = self.visit(fn.body)
- return Function(list(new_params), new_body, fn.ret_type, fn.type_params, fn.attrs)
+ func = Function(list(new_params), new_body, fn.ret_type, fn.type_params, fn.attrs)
+ from tvm.relay.function import FunctionCopyVirtualDevice
+ FunctionCopyVirtualDevice(func, fn)
+ return func
def visit_let(self, let):
new_var = self.visit(let.var)
diff --git a/python/tvm/relay/function.py b/python/tvm/relay/function.py
index f889f1e59..997fd1776 100644
--- a/python/tvm/relay/function.py
+++ b/python/tvm/relay/function.py
@@ -26,6 +26,10 @@
from . import _ffi_api
+def FunctionCopyVirtualDevice(f1, f2):
+ _ffi_api.FunctionCopyVirtualDevice(f1, f2)
+
+
@tvm._ffi.register_object("relay.Function")
class Function(BaseFunc):
"""A function declaration expression.
diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc
index 63e74144e..bd3906731 100644
--- a/src/relay/ir/function.cc
+++ b/src/relay/ir/function.cc
@@ -127,6 +127,10 @@ TVM_REGISTER_GLOBAL("relay.ir.Function")
tvm::Array<TypeVar> ty_params, tvm::DictAttrs attrs) {
return Function(params, body, ret_type, ty_params, attrs);
});
+TVM_REGISTER_GLOBAL("relay.ir.FunctionCopyVirtualDevice")
+ .set_body_typed([](Function f1, Function f2) {
+ f1->virtual_device_ = f2->virtual_device_;
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FunctionNode>([](const ObjectRef& ref, ReprPrinter* p) {
```
This does not seem like an elegant solution, and I'm wondering why
`virtual_device` is not exposed as a parameter of the Python `Function()`
constructor. Would adding it there be an appropriate solution?
@mbs-octoml @electriclilies
---
[Visit
Topic](https://discuss.tvm.apache.org/t/relay-function-virtual-device-property/12958/1)
to respond.
You are receiving this because you enabled mailing list mode.
To unsubscribe from these emails, [click
here](https://discuss.tvm.apache.org/email/unsubscribe/210ea8fb1ff03f47764a08f3a41ee4fc7c54532725118fa9e53837103d4c1c18).