This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new aa49fc3c2b [Unity][Fix] Annotate TIR op pattern could have no stores. (#14420)
aa49fc3c2b is described below

commit aa49fc3c2b8bb96188ae03749117352a616fe234
Author: Prakalp Srivastava <[email protected]>
AuthorDate: Wed Mar 29 18:52:20 2023 -0400

    [Unity][Fix] Annotate TIR op pattern could have no stores. (#14420)
    
    * [Unity][Fix] Annotate TIR op pattern could have no stores.
    
    * Add test case for context
---
 src/relax/analysis/tir_op_pattern_kind.cc             |  8 ++++++++
 .../relax/test_transform_annotate_tir_op_pattern.py   | 19 +++++++++++++++++++
 2 files changed, 27 insertions(+)

diff --git a/src/relax/analysis/tir_op_pattern_kind.cc b/src/relax/analysis/tir_op_pattern_kind.cc
index aed984781c..d7d84c1973 100644
--- a/src/relax/analysis/tir_op_pattern_kind.cc
+++ b/src/relax/analysis/tir_op_pattern_kind.cc
@@ -77,6 +77,14 @@ class PatternKindAnalyzer : public StmtExprVisitor {
     store_ = NullOpt;
     // Step 2. Visit block body.
     StmtVisitor::VisitStmt(op->body);
+
+    // We support exactly one buffer store in a block (usually generated by TE compute)
+    // If we have not seen any store in the current block, classify as Opaque.
+    if (!store_.defined()) {
+      kind_ = relay::kOpaque;
+      return;
+    }
+
     BufferStore store = store_.value();
 
     // Step 3. Checking load store indices pattern
diff --git a/tests/python/relax/test_transform_annotate_tir_op_pattern.py b/tests/python/relax/test_transform_annotate_tir_op_pattern.py
index 5fc8c99367..c2f0e7af5f 100644
--- a/tests/python/relax/test_transform_annotate_tir_op_pattern.py
+++ b/tests/python/relax/test_transform_annotate_tir_op_pattern.py
@@ -383,5 +383,24 @@ def test_sum_sqsum():
     assert new_mod["sum_sqsum"].attrs["op_pattern"] == OpPatternKind.kCommReduce
 
 
+def test_no_buffer_stores():
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def no_buffer_stores(A: T.Buffer((32, 64), "float32"), vsum: T.Buffer((32,), "float32")):
+            for ax0, k0 in T.grid(32, 64):
+                with T.block("block"):
+                    v_ax0, v_k0 = T.axis.remap("SR", [ax0, k0])
+                    T.reads(A[v_ax0, v_k0])
+                    T.writes(vsum[v_ax0])
+                    # absence of buffer stores usually happens when there is an external call for
+                    # computation. We assume opaque in all such cases.
+                    T.call_packed("some_func")
+
+    mod = Module
+    new_mod = relax.transform.AnnotateTIROpPattern()(mod)
+    assert new_mod["no_buffer_stores"].attrs["op_pattern"] == OpPatternKind.kOpaque
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to