This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 8931cfa  fix relay.build to not change the module argument in place (#5822)
8931cfa is described below

commit 8931cfa624269c5c558c0ae6faa3504764d04c64
Author: Thomas Viehmann <[email protected]>
AuthorDate: Tue Jun 16 23:13:40 2020 +0200

    fix relay.build to not change the module argument in place (#5822)
---
 src/relay/backend/build_module.cc           | 3 ++-
 tests/python/relay/test_cpp_build_module.py | 7 ++++++-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc
index f9ce24d..dea923d 100644
--- a/src/relay/backend/build_module.cc
+++ b/src/relay/backend/build_module.cc
@@ -244,7 +244,8 @@ class RelayBuildModule : public runtime::ModuleNode {
       GlobalVar main_glb_var = relay_module->GetGlobalVar("main");
       Function main_func = Downcast<Function>(relay_module->Lookup(main_glb_var));
       auto new_main = BindParamsByName(main_func, params);
-      relay_module->Update(main_glb_var, new_main);
+      IRModuleNode* relay_module_ptr = relay_module.CopyOnWrite();
+      relay_module_ptr->Update(main_glb_var, new_main);
     }
 
     Array<Pass> pass_seqs;
diff --git a/tests/python/relay/test_cpp_build_module.py b/tests/python/relay/test_cpp_build_module.py
index 8d54384..fa56eb0 100644
--- a/tests/python/relay/test_cpp_build_module.py
+++ b/tests/python/relay/test_cpp_build_module.py
@@ -44,7 +44,12 @@ def test_basic_build():
     targets = {
         tvm.tir.IntImm("int32", ctx.device_type): tgt
     }
-    g_json, mmod, params = relay.build(tvm.IRModule.from_expr(func), targets, "llvm", params=params)
+    mod = tvm.IRModule.from_expr(func)
+    func_in_mod = mod["main"]
+    assert mod["main"] == func_in_mod, "cannot compare function to itself"
+
+    g_json, mmod, params = relay.build(mod, targets, "llvm", params=params)
+    assert mod["main"] == func_in_mod, "relay.build changed module in-place"
 
     # test
     rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx)

Reply via email to