This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 49224cb Fix use of fallback AutoTVM knobs in default scheduling
(#8707)
49224cb is described below
commit 49224cb8b81b7e0b857935191019981b22787be3
Author: Andrey Malyshev <[email protected]>
AuthorDate: Sun Aug 15 23:33:04 2021 +0300
Fix use of fallback AutoTVM knobs in default scheduling (#8707)
* Fix use of fallback AutoTVM knobs
Previously, knob values depended on the order of explicit cfg updates and
cfg.define_split calls in fallback mode.
* Add test for define_split with fallback-defined values
---
python/tvm/autotvm/task/space.py | 5 ++++-
tests/python/unittest/test_autotvm_space.py | 2 ++
2 files changed, 6 insertions(+), 1 deletion(-)
diff --git a/python/tvm/autotvm/task/space.py b/python/tvm/autotvm/task/space.py
index afbfb4c..8a707b8 100644
--- a/python/tvm/autotvm/task/space.py
+++ b/python/tvm/autotvm/task/space.py
@@ -824,7 +824,10 @@ class ConfigSpace(object):
def _add_new_transform(self, space_class, name, axes, policy, **kwargs):
"""Add a new transform space in template"""
- if self._collect:
+ # If there is no tuned info (_collect == True) but a KNOB value for "default"
+ # scheduling was defined before _add_new_transform was called, do not create
+ # a new space, which would override the previously set KNOB values.
+ if self._collect and not (self.is_fallback and name in
self._entity_map):
# convert schedule axis to space definition axis
axes = [x if isinstance(x, (VirtualAxis, Axis)) else self.axis(x)
for x in axes]
diff --git a/tests/python/unittest/test_autotvm_space.py b/tests/python/unittest/test_autotvm_space.py
index 2d40371..d56ca9e 100644
--- a/tests/python/unittest/test_autotvm_space.py
+++ b/tests/python/unittest/test_autotvm_space.py
@@ -84,6 +84,8 @@ def test_split():
cfg = FallbackConfigEntity()
cfg.define_split("tile_n", cfg.axis(128), num_outputs=3)
cfg.fallback_split("tile_n", [-1, 8, 4])
+ # verify that define_split does not override previously manually defined split params
+ cfg.define_split("tile_n", cfg.axis(128), num_outputs=3)
assert cfg["tile_n"].size == [4, 8, 4]
cfg = FallbackConfigEntity()