trevor-m commented on a change in pull request #7016:
URL: https://github.com/apache/tvm/pull/7016#discussion_r534501808



##########
File path: tests/python/contrib/test_tensorrt.py
##########
@@ -1050,26 +1050,26 @@ def test_tensorrt_dynamic_batch():
     batches_to_test = [1, 1, 0, 2, 3, 0, 1, 3, 2]
     x_shape = (relay.Any(), 1, 8, 8)
     x_data = np.ones([max(batches_to_test)] + 
list(x_shape)[1:]).astype("float32")
-    result_dict = {}
     for use_trt in [True, False]:
+        result_dict = {}
         x = relay.var("x", shape=x_shape, dtype="float32")
         out = relay.nn.relu(x)
         f = relay.Function([x], out)
         mod = tvm.IRModule()
         mod["main"] = f
         if use_trt:
-            mod = relay.tensorrt.EnableTrt(mod)
+            mod, _ = tensorrt.partition_for_tensorrt(mod)
 
         if not skip_runtime_test():
             with relay.build_config(opt_level=3):
                 relay_exec = relay.create_executor("vm", mod=mod, 
ctx=tvm.cpu(0), target="llvm")
 
             for i, batch_size in enumerate(batches_to_test):
-                result_dict[(i, use_trt)] = 
relay_exec.evaluate()(x_data[:batch_size, ...])
+                result_dict[use_trt] = 
relay_exec.evaluate()(x_data[:batch_size, ...])

Review comment:
      I think we still need to include `i` in the result dict key; otherwise the 
results from all TRT runs will overwrite each other, and we would only be 
checking the last run.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to