This is an automated email from the ASF dual-hosted git repository.
lmzheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 78bba3c [AutoScheduler] Fix a bug in thread binding (#6683)
78bba3c is described below
commit 78bba3c74785433b0f29589d89112ef06e4bca75
Author: Lianmin Zheng <[email protected]>
AuthorDate: Wed Oct 14 04:13:40 2020 -0700
[AutoScheduler] Fix a bug in thread binding (#6683)
* fix for lstm use case
* update
---
src/auto_scheduler/search_policy/sketch_policy_rules.cc | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.cc b/src/auto_scheduler/search_policy/sketch_policy_rules.cc
index 045ee86..99188d4 100644
--- a/src/auto_scheduler/search_policy/sketch_policy_rules.cc
+++ b/src/auto_scheduler/search_policy/sketch_policy_rules.cc
@@ -702,11 +702,14 @@ PopulationGenerationRule::ResultKind InitVectorization::Apply(SketchPolicyNode*
PopulationGenerationRule::ResultKind InitThreadBind::Apply(SketchPolicyNode* policy, State* state,
std::mt19937* rand_gen) const {
+ // Collect all stages that are roots of stages that perform multi-level tiling.
std::set<int> multi_level_tiling_root_set;
for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) {
if (NeedsMultilevelTiling(policy->search_task, *state, stage_id)) {
const Stage& stage = (*state)->stages[stage_id];
- if (stage->compute_at != ComputeAtKind::kIter) {
+ if (stage->compute_at == ComputeAtKind::kInlined) {
+ continue;
+ } else if (stage->compute_at != ComputeAtKind::kIter) {
// This stage is not multi-level tiled,
// so it must be produced by RuleCrossThreadReduction.
CHECK(HasCrossThreadReduction(*state, stage_id));