This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new aeda760e5e [TIR] Disallow unused rhs vars in GetAutoTensorizeMapping (#12225)
aeda760e5e is described below
commit aeda760e5e29eddd0a7ddb22c7031f9607440770
Author: Wuwei Lin <[email protected]>
AuthorDate: Fri Jul 29 01:00:21 2022 -0700
[TIR] Disallow unused rhs vars in GetAutoTensorizeMapping (#12225)
---
src/tir/schedule/analysis/analysis.cc | 5 +++++
.../python/unittest/test_tir_schedule_analysis.py | 24 ++++++++++++++++++++++
2 files changed, 29 insertions(+)
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index 569259d061..72b8c12fea 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -2460,6 +2460,7 @@ class AutoTensorizeMappingProposer {
}
// Step 3: Fuse LHS iters mapped to the same RHS iter
+ std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> used_rhs_vars;
for (size_t i = 0; i < extractor_->lhs_iters_.size(); ++i) {
const Var& lhs_iter_var = extractor_->lhs_iters_[i]->var;
const VarSet& rhs_candidates = lhs_feasible_vars_[lhs_iter_var];
@@ -2472,12 +2473,16 @@ class AutoTensorizeMappingProposer {
PrimExpr updated_fused_lhs =
fused_lhs * lhs_iter_extents.at(lhs_iter_var) + index_map_src[i];
fused_lhs_iters.Set(rhs_var, updated_fused_lhs);
+ used_rhs_vars.insert(rhs_var);
} else {
// non-unique mapping is not supported
return {};
}
}
for (const auto& iter : extractor_->rhs_iters_) {
+ if (!used_rhs_vars.count(iter->var)) {
+ return {};
+ }
index_map_tgt.push_back(analyzer_->Simplify(fused_lhs_iters[iter->var]));
}
// At most one mapping is supported.
diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py
index 625343f740..d3e6033e88 100644
--- a/tests/python/unittest/test_tir_schedule_analysis.py
+++ b/tests/python/unittest/test_tir_schedule_analysis.py
@@ -265,6 +265,9 @@ def check_index_map(workload, block_name, intrin_name, expected_index_map):
block = s.get_block(block_name)
desc_func = TensorIntrin.get(intrin_name).desc
info = get_auto_tensorize_mapping_info(s, block, desc_func)
+ if expected_index_map is None:
+ assert info is None
+ return
assert len(info.mappings) == 1
     assert IndexMap.from_func(expected_index_map).is_equivalent_to(info.mappings[0])
@@ -304,5 +307,26 @@ def test_get_auto_tensorize_mapping_info_batch_matmul(b, m, n, k):
)
+@pytest.mark.parametrize(
+ "n,m,k,expected",
+ [
+ (
+ 512,
+ 512,
+ 512,
+ lambda n, m, k: (
+ n,
+ m,
+ k,
+ ),
+ ),
+ (1, 32, 32, None),
+ ],
+)
+def test_get_auto_tensorize_mapping_info_matmul(n, m, k, expected):
+    matmul = create_prim_func(te_workload.matmul(n, m, k, in_dtype="float16", out_dtype="float32"))
+ check_index_map(matmul, "C", WMMA_SYNC_16x16x16_f16f16f32_INTRIN, expected)
+
+
if __name__ == "__main__":
tvm.testing.main()