This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new aa49fc3c2b [Unity][Fix] Annotate TIR op pattern could have no stores. (#14420)
aa49fc3c2b is described below
commit aa49fc3c2b8bb96188ae03749117352a616fe234
Author: Prakalp Srivastava <[email protected]>
AuthorDate: Wed Mar 29 18:52:20 2023 -0400
[Unity][Fix] Annotate TIR op pattern could have no stores. (#14420)
* [Unity][Fix] Annotate TIR op pattern could have no stores.
* Add test case for context
---
src/relax/analysis/tir_op_pattern_kind.cc | 8 ++++++++
.../relax/test_transform_annotate_tir_op_pattern.py | 19 +++++++++++++++++++
2 files changed, 27 insertions(+)
diff --git a/src/relax/analysis/tir_op_pattern_kind.cc b/src/relax/analysis/tir_op_pattern_kind.cc
index aed984781c..d7d84c1973 100644
--- a/src/relax/analysis/tir_op_pattern_kind.cc
+++ b/src/relax/analysis/tir_op_pattern_kind.cc
@@ -77,6 +77,14 @@ class PatternKindAnalyzer : public StmtExprVisitor {
store_ = NullOpt;
// Step 2. Visit block body.
StmtVisitor::VisitStmt(op->body);
+
+ // We support exactly one buffer store in a block (usually generated by TE compute)
+ // If we have not seen any store in the current block, classify as Opaque.
+ if (!store_.defined()) {
+ kind_ = relay::kOpaque;
+ return;
+ }
+
BufferStore store = store_.value();
// Step 3. Checking load store indices pattern
diff --git a/tests/python/relax/test_transform_annotate_tir_op_pattern.py b/tests/python/relax/test_transform_annotate_tir_op_pattern.py
index 5fc8c99367..c2f0e7af5f 100644
--- a/tests/python/relax/test_transform_annotate_tir_op_pattern.py
+++ b/tests/python/relax/test_transform_annotate_tir_op_pattern.py
@@ -383,5 +383,24 @@ def test_sum_sqsum():
assert new_mod["sum_sqsum"].attrs["op_pattern"] == OpPatternKind.kCommReduce
+def test_no_buffer_stores():
+ @tvm.script.ir_module
+ class Module:
+ @T.prim_func
+ def no_buffer_stores(A: T.Buffer((32, 64), "float32"), vsum: T.Buffer((32,), "float32")):
+ for ax0, k0 in T.grid(32, 64):
+ with T.block("block"):
+ v_ax0, v_k0 = T.axis.remap("SR", [ax0, k0])
+ T.reads(A[v_ax0, v_k0])
+ T.writes(vsum[v_ax0])
+ # absence of buffer stores usually happens when there is an external call for
+ # computation. We assume opaque in all such cases.
+ T.call_packed("some_func")
+
+ mod = Module
+ new_mod = relax.transform.AnnotateTIROpPattern()(mod)
+ assert new_mod["no_buffer_stores"].attrs["op_pattern"] == OpPatternKind.kOpaque
+
+
if __name__ == "__main__":
tvm.testing.main()