This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 170add2  Add parameter to allow caller to supply a Runner (#8747)
170add2 is described below

commit 170add2f2fbc3507d1cbfc77ff95312dfe3a1ca9
Author: Robert Kimball <[email protected]>
AuthorDate: Sat Aug 14 12:00:46 2021 -0700

    Add parameter to allow caller to supply a Runner (#8747)
    
    * Add parameter to allow caller to supply a Runner
    
    * Add unit test for passing in runner to graph tuner
---
 python/tvm/autotvm/graph_tuner/base_graph_tuner.py |  6 ++-
 .../unittest/test_autotvm_graph_tuner_core.py      | 43 ++++++++++++++++++++++
 2 files changed, 48 insertions(+), 1 deletion(-)
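
For context: benchmark_layout_transform previously always constructed its own
LocalRunner, so the measurement settings could not be customized. With this
change a caller can hand in a preconfigured runner instead. A minimal sketch of
the new call pattern (the DPTuner setup mirrors the unit test below; the runner
settings are illustrative, not mandated by the commit):

    from tvm import autotvm
    from tvm.autotvm.graph_tuner import DPTuner

    # graph, dshape, records, target_ops and ltf_records are assumed to come
    # from the surrounding tuning workflow (see _create_data in the test file).
    executor = DPTuner(graph, {"data": dshape}, records, target_ops, target="llvm")

    # Supply a runner with custom measurement settings instead of relying on
    # the default LocalRunner(number=min_exec_num, repeat=1, timeout=timeout).
    runner = autotvm.LocalRunner(number=100, repeat=1, timeout=10)
    executor.benchmark_layout_transform(layout_records=ltf_records, runner=runner)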

diff --git a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py
index 780e6c9..beb1aa0 100644
--- a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py
+++ b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py
@@ -375,6 +375,7 @@ class BaseGraphTuner(object):
         layout_records=None,
         target_host=None,
         infer_layout=False,
+        runner=None,
     ):
         """Benchmark all possible layout transformation in the graph,
         given a set of schedule candidates for each workload of target 
operator.
@@ -438,6 +439,8 @@ class BaseGraphTuner(object):
             of benchmarking on target device.
 
             This might bring performance loss compared to benchmarking layout transformation.
+        runner : Runner, optional
+            A user-supplied runner to use for measurement; if None, a LocalRunner is created by default.
         """
         self._logger.info("Start to benchmark layout transformation...")
         self._target, target_host = Target.check_and_update_host_consist(self._target, target_host)
@@ -483,7 +486,6 @@ class BaseGraphTuner(object):
             return _callback
 
         builder = autotvm.LocalBuilder(n_parallel=n_parallel, build_func=build_func)
-        runner = autotvm.LocalRunner(number=min_exec_num, repeat=1, timeout=timeout)
         if use_rpc:
             if device_key is None:
                 raise RuntimeError("device_key need to be set to use rpc tracker mode.")
@@ -496,6 +498,8 @@ class BaseGraphTuner(object):
                 repeat=1,
                 timeout=timeout,
             )
+        elif not runner:
+            runner = autotvm.LocalRunner(number=min_exec_num, repeat=1, timeout=timeout)
         measure_option = autotvm.measure_option(builder=builder, runner=runner)
         for args in args_list:
             data, in_layout, out_layout = args
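
Note on the resulting control flow: the use_rpc branch still builds its own
RPCRunner, so a caller-supplied runner only takes effect on the non-RPC path,
where the new "elif not runner" fallback preserves the old LocalRunner default.
A hypothetical helper mirroring this precedence (tracker host and port are
placeholder assumptions, not taken from the commit):

    from tvm import autotvm

    def select_runner(runner=None, use_rpc=False, device_key=None,
                      min_exec_num=10, timeout=10):
        if use_rpc:
            if device_key is None:
                raise RuntimeError("device_key need to be set to use rpc tracker mode.")
            # benchmark_layout_transform constructs its own RPCRunner here
            return autotvm.RPCRunner(device_key, "127.0.0.1", 9190,
                                     number=min_exec_num, repeat=1, timeout=timeout)
        if runner is None:
            # the old default, now used only when no runner was passed in
            return autotvm.LocalRunner(number=min_exec_num, repeat=1, timeout=timeout)
        return runner  # the caller-supplied runner is used as-is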
diff --git a/tests/python/unittest/test_autotvm_graph_tuner_core.py b/tests/python/unittest/test_autotvm_graph_tuner_core.py
index 3d7d304..bcc4364 100644
--- a/tests/python/unittest/test_autotvm_graph_tuner_core.py
+++ b/tests/python/unittest/test_autotvm_graph_tuner_core.py
@@ -188,6 +188,49 @@ def test_graph_tuner_layout_transform():
         )
 
 
+def test_graph_tuner_layout_transform_runner():
+    log_file = "%s/test_tuner.log" % (os.getcwd())
+    target = "llvm"
+    dshape = (1, 3, 8, 8)
+    dtype = "float32"
+    layout = "NCHW"
+    conv2d = relay.op.get("nn.conv2d")
+    target_ops = [conv2d]
+
+    g, records, ltf_records, ltf_keys, _ = _create_data(target, dshape, dtype, layout)
+    executor = DPTuner(g, {"data": dshape}, records, target_ops, target=target, log_file=log_file)
+    runner = autotvm.LocalRunner(number=100, repeat=1, timeout=10)
+    executor.benchmark_layout_transform(
+        layout_records=ltf_records, infer_layout=True, runner=runner
+    )
+    out = executor._layout_transform_perf_records
+
+    num_flops = 0
+    total_time = 0
+    for record in ltf_records:
+        ltf_wkl = record[0].task.workload
+        input_shape = ltf_wkl[1][1]
+        flops = np.prod(input_shape)
+        num_flops += flops
+        total_time += record[1].costs[0]
+    avg_time = total_time / num_flops
+
+    for ltf_workload in out:
+        input_shape = ltf_workload[1][1]
+        flops = 1
+        for i in input_shape:
+            flops *= i
+        expected_time = flops * avg_time
+        out_time = out[ltf_workload][1].costs[0]
+        assert (
+            expected_time == out_time
+        ), "Inferred layout transformation time mismatch for %s: " "expecting 
%f but got %f" % (
+            str(ltf_workload),
+            expected_time,
+            out_time,
+        )
+
+
 def test_DPTuner_run():
     log_file = "%s/test_tuner.log" % (os.getcwd())
     target = "llvm"
