This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 8931cfa fix relay.build to not change the module argument in place (#5822)
8931cfa is described below
commit 8931cfa624269c5c558c0ae6faa3504764d04c64
Author: Thomas Viehmann <[email protected]>
AuthorDate: Tue Jun 16 23:13:40 2020 +0200
fix relay.build to not change the module argument in place (#5822)
---
src/relay/backend/build_module.cc | 3 ++-
tests/python/relay/test_cpp_build_module.py | 7 ++++++-
2 files changed, 8 insertions(+), 2 deletions(-)
diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc
index f9ce24d..dea923d 100644
--- a/src/relay/backend/build_module.cc
+++ b/src/relay/backend/build_module.cc
@@ -244,7 +244,8 @@ class RelayBuildModule : public runtime::ModuleNode {
GlobalVar main_glb_var = relay_module->GetGlobalVar("main");
Function main_func =
Downcast<Function>(relay_module->Lookup(main_glb_var));
auto new_main = BindParamsByName(main_func, params);
- relay_module->Update(main_glb_var, new_main);
+ IRModuleNode* relay_module_ptr = relay_module.CopyOnWrite();
+ relay_module_ptr->Update(main_glb_var, new_main);
}
Array<Pass> pass_seqs;
diff --git a/tests/python/relay/test_cpp_build_module.py b/tests/python/relay/test_cpp_build_module.py
index 8d54384..fa56eb0 100644
--- a/tests/python/relay/test_cpp_build_module.py
+++ b/tests/python/relay/test_cpp_build_module.py
@@ -44,7 +44,12 @@ def test_basic_build():
targets = {
tvm.tir.IntImm("int32", ctx.device_type): tgt
}
- g_json, mmod, params = relay.build(tvm.IRModule.from_expr(func), targets,
"llvm", params=params)
+ mod = tvm.IRModule.from_expr(func)
+ func_in_mod = mod["main"]
+ assert mod["main"] == func_in_mod, "cannot compare function to itself"
+
+ g_json, mmod, params = relay.build(mod, targets, "llvm", params=params)
+ assert mod["main"] == func_in_mod, "relay.build changed module in-place"
# test
rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx)