comaniac commented on a change in pull request #6657:
URL: https://github.com/apache/incubator-tvm/pull/6657#discussion_r502756049
##########
File path: python/tvm/auto_scheduler/measure_record.py
##########
@@ -159,3 +175,38 @@ def load_best(filename, workload_key=None, target=None):
best_res = res
return best_inp, best_res
+
+
+def correct_measure_input(inp, rebuild_state=False):
Review comment:
I see, but the name is still not quite clear. How about
`clone_measure_input`, `recover_measure_input`, or something similar?
##########
File path: tests/python/unittest/test_auto_scheduler_measure.py
##########
@@ -167,53 +167,76 @@ def test_record_pragma_storage_align_rfactor():
record_common(dag, s)
-def test_measure_local_builder_runner(enable_cpu_cache_flush=False):
+def test_correct_measure_input():
+ task = auto_scheduler.create_task(matmul_auto_scheduler_test, [512, 512,
512], "llvm")
+
+ inp = auto_scheduler.measure.MeasureInput(task,
task.compute_dag.init_state)
+ res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1)
+
+ with tempfile.NamedTemporaryFile() as fp:
+ auto_scheduler.save_records(fp.name, [inp], [res])
+
+ log_reader = auto_scheduler.RecordReader(fp.name)
+ inputs, results = log_reader.read_lines()
+ assert len(inputs) == 1
+
+ raw_inp = inputs[0]
+
+ correct_inp =
auto_scheduler.measure_record.correct_measure_input(raw_inp)
+ assert str(correct_inp.task.compute_dag) == str(inp.task.compute_dag)
+
+ correct_inp = auto_scheduler.measure_record.correct_measure_input(
+ raw_inp, rebuild_state=True
+ )
+ assert str(correct_inp.state) == str(inp.state)
+
+
+def test_measure_local_builder_runner():
if not tvm.testing.device_enabled("llvm"):
return
- dag, s0 = get_tiled_matmul()
- tgt = tvm.target.Target("llvm")
- task = auto_scheduler.SearchTask(dag, "test", tgt)
+ task = auto_scheduler.create_task(matmul_auto_scheduler_test, [512, 512,
512], "llvm")
- minp = auto_scheduler.MeasureInput(task, s0)
- local_builder = auto_scheduler.LocalBuilder()
- local_runner = auto_scheduler.LocalRunner(
- timeout=60, enable_cpu_cache_flush=enable_cpu_cache_flush
- )
+ for enable_cpu_cache_flush in [True, False]:
+ minp = auto_scheduler.MeasureInput(task, task.compute_dag.init_state)
+ local_builder = auto_scheduler.LocalBuilder()
+ local_runner = auto_scheduler.LocalRunner(
+ timeout=60, enable_cpu_cache_flush=enable_cpu_cache_flush
+ )
- bress = local_builder.build([minp])
- assert bress[0].error_no == 0
- mress = local_runner.run([minp], bress)
- assert mress[0].error_no == 0
+ bress = local_builder.build([minp])
+ assert bress[0].error_no == 0
+ mress = local_runner.run([minp], bress)
+ assert mress[0].error_no == 0
-def test_measure_local_builder_rpc_runner(enable_cpu_cache_flush=False):
+def test_measure_local_builder_rpc_runner():
if not tvm.testing.device_enabled("llvm"):
return
- dag, s0 = get_tiled_matmul()
- tgt = tvm.target.Target("llvm")
- task = auto_scheduler.SearchTask(dag, "test", tgt)
+ task = auto_scheduler.create_task(matmul_auto_scheduler_test, [512, 512,
512], "llvm")
- minp = auto_scheduler.MeasureInput(task, s0)
- local_builder = auto_scheduler.LocalBuilder()
- measure_ctx = auto_scheduler.LocalRPCMeasureContext(
- timeout=60, enable_cpu_cache_flush=enable_cpu_cache_flush
- )
- rpc_runner = measure_ctx.runner
+ for enable_cpu_cache_flush in [True, False]:
Review comment:
I just want to avoid redundant testing. If another test already covers this
feature (`enable_cpu_cache_flush`), we could simplify this test.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org