This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 49224cb  Fix use of fallback AutoTVM knobs in default scheduling (#8707)
49224cb is described below

commit 49224cb8b81b7e0b857935191019981b22787be3
Author: Andrey Malyshev <[email protected]>
AuthorDate: Sun Aug 15 23:33:04 2021 +0300

    Fix use of fallback AutoTVM knobs in default scheduling (#8707)
    
    * Fix use of fallback AutoTVM knobs
    
    Previously, knob values depended on the order of explicit cfg updates and
    cfg.define_split calls in fallback mode.
    
    * Add test for define_split with fallback-defined values
---
 python/tvm/autotvm/task/space.py            | 5 ++++-
 tests/python/unittest/test_autotvm_space.py | 2 ++
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/python/tvm/autotvm/task/space.py b/python/tvm/autotvm/task/space.py
index afbfb4c..8a707b8 100644
--- a/python/tvm/autotvm/task/space.py
+++ b/python/tvm/autotvm/task/space.py
@@ -824,7 +824,10 @@ class ConfigSpace(object):
 
     def _add_new_transform(self, space_class, name, axes, policy, **kwargs):
         """Add a new transform space in template"""
-        if self._collect:
+        # if we do not have tuned info (_collect == True) but defined KNOB 
value
+        # for "default" scheduling before call of _add_new_transform, in this 
case
+        # no need to create new space and override previously pointed KNOB 
values
+        if self._collect and not (self.is_fallback and name in 
self._entity_map):
             # convert schedule axis to space definition axis
             axes = [x if isinstance(x, (VirtualAxis, Axis)) else self.axis(x) 
for x in axes]
 
diff --git a/tests/python/unittest/test_autotvm_space.py 
b/tests/python/unittest/test_autotvm_space.py
index 2d40371..d56ca9e 100644
--- a/tests/python/unittest/test_autotvm_space.py
+++ b/tests/python/unittest/test_autotvm_space.py
@@ -84,6 +84,8 @@ def test_split():
     cfg = FallbackConfigEntity()
     cfg.define_split("tile_n", cfg.axis(128), num_outputs=3)
     cfg.fallback_split("tile_n", [-1, 8, 4])
+    # verify if define_split override previously manualy defined split params
+    cfg.define_split("tile_n", cfg.axis(128), num_outputs=3)
     assert cfg["tile_n"].size == [4, 8, 4]
 
     cfg = FallbackConfigEntity()

Reply via email to