This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 7c06de52a1 [Fix][MetaSchedule] Fix redundant stages in async pipeline
for mlt (#14143)
7c06de52a1 is described below
commit 7c06de52a1a93d32b925543b16b7c514a6ddccbd
Author: Tian Xia <[email protected]>
AuthorDate: Wed Mar 1 01:29:55 2023 +0800
[Fix][MetaSchedule] Fix redundant stages in async pipeline for mlt (#14143)
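This PR fixes redundant async pipeline stages that accumulated when
`InitializeWithTuneContext` was invoked multiple times: `stages` is now
assigned `{4, 5}` instead of having `{4, 5}` appended on every call.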
---
.../schedule_rule/multi_level_tiling.cc | 2 +-
...ule_space_cuda_async_multiple_initialization.py | 88 ++++++++++++++++++++++
2 files changed, 89 insertions(+), 1 deletion(-)
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
index 54407c46c8..779114e9cf 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -96,7 +96,7 @@ void MultiLevelTilingNode::InitializeWithTuneContext(const TuneContext& context)
if (std::stoi(sm) >= 80) {
// only stage = 4 & 5 is tested. all integer that is bigger than 2
// is theoretically feasible, but no guarantee for great performance.
- this->stages.insert(this->stages.end(), {4, 5});
+ this->stages = {4, 5};
}
} catch (const std::invalid_argument& e) {
LOG(WARNING) << "ValueError: Unable to parse `target.arch`: " << sm
diff --git a/tests/python/unittest/test_meta_schedule_space_cuda_async_multiple_initialization.py b/tests/python/unittest/test_meta_schedule_space_cuda_async_multiple_initialization.py
new file mode 100644
index 0000000000..e7b3789257
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_space_cuda_async_multiple_initialization.py
@@ -0,0 +1,88 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Tests for MetaSchedule search space on CUDA"""
+from typing import List, Optional, Tuple, Union
+
+# isort: off
+from typing_extensions import Literal
+
+# isort: on
+from tvm.meta_schedule.testing.space_generation import get_rules
+from tvm import meta_schedule as ms
+from tvm.meta_schedule.testing.te_workload import create_te_workload
+from tvm.target import Target
+from tvm.ir import IRModule
+from tvm.tir import Schedule
+
+
+def generate_design_space(
+ kind: Literal["llvm", "cuda", "cuda-tensorcore", "hexagon"],
+ mod: IRModule,
+ target: Target,
+ types: Union[type, Tuple[type, ...]],
+ sch_rules: Optional[List[ms.ScheduleRule]] = None,
+ initialize_time: int = 1,
+) -> List[Schedule]:
+ if sch_rules is None:
+ sch_rules = get_rules(kind, types)
+ else:
+ assert types is None
+ ctx = ms.TuneContext(
+ mod=mod,
+ target=target,
+ space_generator=ms.space_generator.PostOrderApply(
+ sch_rules=sch_rules,
+ postprocs=[],
+ mutator_probs={},
+ ),
+ task_name="test",
+ )
+ # each time cloning will trigger one more initialization
+ for _ in range(initialize_time - 1):
+ ctx = ctx.clone()
+ return ctx.generate_design_space()
+
+
+def _target():
+ return Target("nvidia/geforce-rtx-3070")
+
+
+def _design_space(mod):
+ return generate_design_space(
+ kind="cuda",
+ mod=mod,
+ target=_target(),
+ types=ms.ScheduleRule,
+ initialize_time=100,
+ )
+
+
+def test_c2d():
+ mod = create_te_workload("C2D", 0)
+ actual = _design_space(mod)
+ assert len(actual) == 3
+
+
+def test_gmm():
+ mod = create_te_workload("GMM", 0)
+ actual = _design_space(mod)
+ assert len(actual) == 3
+
+
+if __name__ == "__main__":
+ test_c2d()
+ test_gmm()