areusch commented on a change in pull request #9276:
URL: https://github.com/apache/tvm/pull/9276#discussion_r808706614
##########
File path: tests/python/unittest/test_link_params.py
##########
@@ -181,54 +183,53 @@ def _add_decl(name, dtype):
@tvm.testing.requires_llvm
-def test_llvm_link_params():
- for dtype in LINKABLE_DTYPES:
- ir_mod, param_init = _make_mod_and_params(dtype)
- rand_input = _make_random_tensor(dtype, INPUT_SHAPE)
- main_func = ir_mod["main"]
- target = "llvm --runtime=c --system-lib --link-params"
- with tvm.transform.PassContext(opt_level=3):
- lib = tvm.relay.build(ir_mod, target, params=param_init)
-
- # NOTE: Need to export_library() and load_library() to link all
the Module(llvm, ...)
- # against one another.
- temp_dir = tempfile.mkdtemp()
- export_file = os.path.join(temp_dir, "lib.so")
- lib.lib.export_library(export_file)
- mod = tvm.runtime.load_module(export_file)
- assert set(lib.params.keys()) == {"p0", "p1"} # NOTE: op folded
- assert mod.get_function("TVMSystemLibEntryPoint") != None
-
- graph = json.loads(lib.graph_json)
- for p in lib.params:
- _verify_linked_param(dtype, lib, mod, graph, p) or found_one
-
- # Wrap in function to explicitly deallocate the runtime.
- def _run_linked(lib, mod):
- graph_json, _, _ = lib
- graph_rt = tvm.contrib.graph_executor.create(graph_json, mod,
tvm.cpu(0))
- graph_rt.set_input("rand_input", rand_input) # NOTE: params
not required.
- graph_rt.run()
- return graph_rt.get_output(0)
-
- linked_output = _run_linked(lib, mod)
-
- with tvm.transform.PassContext(opt_level=3):
- lib = tvm.relay.build(ir_mod, "llvm --system-lib",
params=param_init)
-
- def _run_unlinked(lib):
- graph_json, mod, lowered_params = lib
- graph_rt = tvm.contrib.graph_executor.create(graph_json, mod,
tvm.cpu(0))
- graph_rt.set_input("rand_input", rand_input, **lowered_params)
- graph_rt.run()
- return graph_rt.get_output(0)
-
- unlinked_output = _run_unlinked(lib)
-
- if "int" in dtype:
- np.testing.assert_equal(unlinked_output.numpy(),
linked_output.numpy())
- else:
- np.testing.assert_allclose(unlinked_output.numpy(),
linked_output.numpy())
+def test_llvm_link_params(linkable_dtype):
+ ir_mod, param_init = _make_mod_and_params(linkable_dtype)
Review comment:
I'd like to address this one, but I'm a little short on time right now.
Hopefully we can merge this in the spirit of incremental progress and address
it later.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]