areusch commented on a change in pull request #9276:
URL: https://github.com/apache/tvm/pull/9276#discussion_r808706394



##########
File path: tests/python/unittest/test_link_params.py
##########
@@ -181,54 +183,53 @@ def _add_decl(name, dtype):
 
 
 @tvm.testing.requires_llvm
-def test_llvm_link_params():
-    for dtype in LINKABLE_DTYPES:
-        ir_mod, param_init = _make_mod_and_params(dtype)
-        rand_input = _make_random_tensor(dtype, INPUT_SHAPE)
-        main_func = ir_mod["main"]
-        target = "llvm --runtime=c --system-lib --link-params"
-        with tvm.transform.PassContext(opt_level=3):
-            lib = tvm.relay.build(ir_mod, target, params=param_init)
-
-            # NOTE: Need to export_library() and load_library() to link all the Module(llvm, ...)
-            # against one another.
-            temp_dir = tempfile.mkdtemp()
-            export_file = os.path.join(temp_dir, "lib.so")
-            lib.lib.export_library(export_file)
-            mod = tvm.runtime.load_module(export_file)
-            assert set(lib.params.keys()) == {"p0", "p1"}  # NOTE: op folded
-            assert mod.get_function("TVMSystemLibEntryPoint") != None
-
-            graph = json.loads(lib.graph_json)
-            for p in lib.params:
-                _verify_linked_param(dtype, lib, mod, graph, p) or found_one
-
-            # Wrap in function to explicitly deallocate the runtime.
-            def _run_linked(lib, mod):
-                graph_json, _, _ = lib
-                graph_rt = tvm.contrib.graph_executor.create(graph_json, mod, tvm.cpu(0))
-                graph_rt.set_input("rand_input", rand_input)  # NOTE: params not required.
-                graph_rt.run()
-                return graph_rt.get_output(0)
-
-            linked_output = _run_linked(lib, mod)
-
-        with tvm.transform.PassContext(opt_level=3):
-            lib = tvm.relay.build(ir_mod, "llvm --system-lib", params=param_init)
-
-            def _run_unlinked(lib):
-                graph_json, mod, lowered_params = lib
-                graph_rt = tvm.contrib.graph_executor.create(graph_json, mod, tvm.cpu(0))
-                graph_rt.set_input("rand_input", rand_input, **lowered_params)
-                graph_rt.run()
-                return graph_rt.get_output(0)
-
-            unlinked_output = _run_unlinked(lib)
-
-        if "int" in dtype:
-            np.testing.assert_equal(unlinked_output.numpy(), linked_output.numpy())
-        else:
-            np.testing.assert_allclose(unlinked_output.numpy(), linked_output.numpy())
+def test_llvm_link_params(linkable_dtype):
+    ir_mod, param_init = _make_mod_and_params(linkable_dtype)
+    rand_input = _make_random_tensor(linkable_dtype, INPUT_SHAPE)
+    main_func = ir_mod["main"]
+    target = "llvm --runtime=c --system-lib --link-params"
+    with tvm.transform.PassContext(opt_level=3):
+        lib = tvm.relay.build(ir_mod, target, params=param_init)
+
+        # NOTE: Need to export_library() and load_library() to link all the Module(llvm, ...)
+        # against one another.
+        temp_dir = tempfile.mkdtemp()

Review comment:
       fixed this one




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to