lhutton1 commented on code in PR #12525:
URL: https://github.com/apache/tvm/pull/12525#discussion_r1151161849
##########
tests/python/driver/tvmc/test_autotuner.py:
##########
@@ -207,3 +209,27 @@ def test_autotune_pass_context(mock_pc, onnx_mnist, tmpdir_factory):
    # AutoTVM overrides the pass context later in the pipeline to disable AlterOpLayout
    assert mock_pc.call_count == 2
    assert mock_pc.call_args_list[0][1]["opt_level"] == 3
+
+
+def test_filter_tasks_valid():
+    assert filter_tasks(list(range(10)), "list") == ([], True)
+    assert filter_tasks(list(range(10)), "help") == ([], True)
+    assert filter_tasks(list(range(10)), "all") == ([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], False)
+    assert filter_tasks(list(range(10)), "5") == ([5], False)
+    assert filter_tasks(list(range(10)), "1-5") == ([1, 2, 3, 4, 5], False)
+    assert filter_tasks(list(range(10)), "-5") == ([0, 1, 2, 3, 4, 5], False)
+    assert filter_tasks(list(range(10)), "6-") == ([6, 7, 8, 9], False)
+    assert filter_tasks(list(range(10)), "0,1-3,all") == ([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], False)
+    assert filter_tasks(list(range(10)), "0,4-5,9,list") == ([0, 4, 5, 9], True)
+
+
[email protected]
Review Comment:
XFail is normally used to mark a test that is expected to fail (e.g. due to a known bug), which doesn't seem to be the case here. If the test suddenly starts passing, it won't be flagged in CI. I think it would be better to rewrite this test using `pytest.raises`, e.g. something like:
```
@pytest.mark.parametrize(
    "value,err_msg",
    [
        ("10", "my expected error message"),
        ("5,10", "my expected error message 2"),
        ...
    ],
)
def test_filter_tasks_invalid(value, err_msg):
    with pytest.raises(AssertionError, match=err_msg):
        filter_tasks(list(range(10)), value)
```
##########
tests/python/driver/tvmc/test_autotuner.py:
##########
@@ -207,3 +209,27 @@ def test_autotune_pass_context(mock_pc, onnx_mnist, tmpdir_factory):
    # AutoTVM overrides the pass context later in the pipeline to disable AlterOpLayout
    assert mock_pc.call_count == 2
    assert mock_pc.call_args_list[0][1]["opt_level"] == 3
+
+
+def test_filter_tasks_valid():
Review Comment:
Nit: it would be helpful to parametrize these tests similar to the test below, so that failures are reported separately in the CI log.
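For illustration, a parametrized version of the valid-input cases could look something like the sketch below. The `filter_tasks` here is a minimal stand-in written for this example (it only mimics the index/range semantics described in the PR and omits the `list`/`help`-only cases, whose exact return value depends on the real implementation); it is not the actual `tvmc` function:

```python
import pytest


def filter_tasks(tasks, expr):
    """Minimal stand-in for tvmc's filter_tasks (assumed semantics only):
    returns (selected tasks, whether a task listing was requested)."""
    do_list, do_filter, selected = False, False, []
    for item in expr.split(","):
        if item in ("list", "help"):
            do_list = True
        elif item == "all":
            do_filter = True
            selected = list(range(len(tasks)))
        else:
            do_filter = True
            if "-" in item:
                # "a-b", "-b" (from the start) or "a-" (to the end)
                lhs, rhs = item.split("-")[:2]
                lhs = int(lhs) if lhs else 0
                rhs = int(rhs) if rhs else len(tasks) - 1
                selected.extend(range(lhs, rhs + 1))
            else:
                selected.append(int(item))
    if do_filter:
        keep = set(selected)
        tasks = [task for i, task in enumerate(tasks) if i in keep]
    return tasks, do_list


@pytest.mark.parametrize(
    "expr,expected",
    [
        ("all", (list(range(10)), False)),
        ("5", ([5], False)),
        ("1-5", ([1, 2, 3, 4, 5], False)),
        ("-5", ([0, 1, 2, 3, 4, 5], False)),
        ("6-", ([6, 7, 8, 9], False)),
        ("0,4-5,9,list", ([0, 4, 5, 9], True)),
    ],
)
def test_filter_tasks_valid(expr, expected):
    # each case is reported individually by pytest, e.g.
    # test_filter_tasks_valid[1-5-expected2]
    assert filter_tasks(list(range(10)), expr) == expected
```

With `@pytest.mark.parametrize`, each expression becomes its own test item, so a single failing case is pinpointed in the CI log instead of aborting the whole test function.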
##########
python/tvm/driver/tvmc/autotuner.py:
##########
@@ -290,10 +295,82 @@ def drive_tune(args):
        include_simple_tasks=args.include_simple_tasks,
        log_estimated_latency=args.log_estimated_latency,
        additional_target_options=reconstruct_target_args(args),
+        tasks_filter=args.tasks,
        **transform_args,
    )
+def filter_tasks(
+    tasks: Union[List[auto_scheduler.SearchTask], List[autotvm.task.Task]],
+    expr: str,
+):
+    """Utility to filter a list of tasks (AutoTVM or AutoScheduler) based on
+    a user-supplied string expression.
+
+    Parameters
+    ----------
+    tasks: list
+        A list of extracted AutoTVM or AutoScheduler tasks.
+    expr: str
+        User-supplied expression to be used for filtering.
+    """
+    assert isinstance(expr, str), "Expected filter expression of string type"
+    assert len(expr) > 0, "Got empty filter expression"
+
+    # groups of keywords are comma-separated
+    splitted = expr.split(",")
+
+    do_list = False
+    do_filter = False
+    selected = []
+    for item in splitted:
+        if item in ["list", "help"]:
+            do_list = True
+        elif item in ["all"]:
+            selected = list(range(len(tasks)))
+        else:
+            do_filter = True
+            if "-" in item:
+                lhs, rhs = item.split("-")[:2]
+                lhs = int(lhs) if lhs else 0
+                rhs = int(rhs) if rhs else len(tasks) - 1
+                assert 0 <= lhs < len(tasks), "Left-hand side expression out of range"
+                assert 0 <= rhs < len(tasks), "Right-hand side expression out of range"
+                selected.extend(list(range(lhs, rhs + 1)))
+            else:
+                assert isinstance(item, str)
+                idx = int(item)
+                assert 0 <= idx < len(tasks)
+                selected.append(idx)
+
+    if do_filter:
+        # remove duplicates
+        selected = list(set(selected))
+        tasks = [task for i, task in enumerate(tasks) if i in selected]
+
+    return tasks, do_list
+
+
+def print_task_list(tasks, enable_autoscheduler):
+    print("Available Tasks for tuning:")
+
+    def _trunc_helper(text, length):
+        return text if len(text) < length else text[: length - 3] + "..."
+
+    print(
+        "\n".join(
Review Comment:
nit: is it possible to write a test to check this output?
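It should be: pytest's built-in `capsys` fixture captures anything written to stdout, so the printed listing can be asserted on directly. A sketch follows, using a simplified stand-in for `print_task_list` (the real function also truncates long task descriptions via `_trunc_helper` and takes an `enable_autoscheduler` flag):

```python
def format_task_list(tasks):
    # Simplified stand-in: the real tvmc print_task_list also truncates long
    # task descriptions and formats AutoScheduler tasks differently.
    lines = ["Available Tasks for tuning:"]
    lines.extend("  {}. {}".format(i, task) for i, task in enumerate(tasks))
    return "\n".join(lines)


def print_task_list(tasks):
    print(format_task_list(tasks))


def test_print_task_list_output(capsys):
    # capsys is pytest's built-in fixture for capturing stdout/stderr
    print_task_list(["conv2d_nchw", "dense_pack"])
    out = capsys.readouterr().out
    assert "Available Tasks for tuning:" in out
    assert "0. conv2d_nchw" in out
    assert "1. dense_pack" in out
```

Splitting the formatting into a pure helper (`format_task_list` here) also makes the output testable without any stdout capture at all; either approach avoids leaving the print path untested.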
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at: [email protected]