trevor-m commented on a change in pull request #6955:
URL: https://github.com/apache/tvm/pull/6955#discussion_r532764867
##########
File path: tests/python/contrib/test_tensorrt.py
##########
@@ -1034,5 +1039,119 @@ def set_func_attr(func, compile_name, symbol_name):
tvm.ir.assert_structural_equal(mod_trt, mod_exp, map_free_vars=True)
+def convert_traced_model_to_vm_trt(
+    traced_module: torch.jit.TopLevelTracedModule, np_sample_input: np.ndarray, target: str
+) -> tvm.runtime.vm.Executable:
+    """
+    This function converts a traced pytorch model to VM + TRT.
+    """
+    input_shape = np_sample_input.shape
+    input_name = "input0"
+    shape_list = [(input_name, input_shape)]
+    mod, params = relay.frontend.from_pytorch(traced_module, shape_list)
+    mod, config = tensorrt.partition_for_tensorrt(mod, params, remove_no_mac_subgraphs=True)
+    with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldScaleAxis"]):
+        vm_trt_exec = relay.vm.compile(mod, target=target, params=params)
+
+    return vm_trt_exec
+
+
+def test_maskrcnn_resnet50() -> None:
Review comment:
What properties of this particular model are we trying to test here? Could
we instead write a simpler test case that replicates the same behavior, i.e.
a single conv2d op with a dynamic batch size offloaded to TRT?