This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a5d7dfa583 [FIX,TIR] Handle LetStmt in EstimateTIRFLops (#12138)
a5d7dfa583 is described below

commit a5d7dfa5833075492009014079e23f6ce1a1c64d
Author: Tristan Konolige <[email protected]>
AuthorDate: Mon Aug 1 05:35:25 2022 -0700

    [FIX,TIR] Handle LetStmt in EstimateTIRFLops (#12138)
    
    * [FIX,TIR] Handle LetStmt in EstimateTIRFLops
    
    Add flops for the right hand side of let statements to the total flops.
    
    * assert flops
---
 .../postproc/rewrite_parallel_vectorize_unroll.cc          | 13 ++++++++-----
 src/tir/analysis/estimate_flops.cc                         |  6 ++++++
 .../unittest/test_tir_analysis_estimate_tir_flops.py       | 14 ++++++++++++++
 3 files changed, 28 insertions(+), 5 deletions(-)

diff --git a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
index 001c97645b..f3c2b1328b 100644
--- a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
+++ b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
@@ -303,11 +303,14 @@ bool FindAnnotatedRootBlock(const Schedule& sch, ParsedAnnotation* parsed, Block
     const GlobalVar& g_var = kv.first;
     const BaseFunc& base_func = kv.second;
     if (const auto* prim_func = base_func.as<PrimFuncNode>()) {
-      Block block = Downcast<BlockRealize>(prim_func->body)->block;
-      if (ParseAnnotation(block, parsed)) {
-        *root_rv = sch->GetBlock(block->name_hint, g_var->name_hint);
-        RemoveParsedAnn(sch, *root_rv, *parsed);
-        return true;
+      const BlockRealizeNode* block_realize = prim_func->body.as<BlockRealizeNode>();
+      if (block_realize != nullptr) {
+        Block block = block_realize->block;
+        if (ParseAnnotation(block, parsed)) {
+          *root_rv = sch->GetBlock(block->name_hint, g_var->name_hint);
+          RemoveParsedAnn(sch, *root_rv, *parsed);
+          return true;
+        }
       }
     }
   }
diff --git a/src/tir/analysis/estimate_flops.cc b/src/tir/analysis/estimate_flops.cc
index 895ae798e3..44a58f8792 100644
--- a/src/tir/analysis/estimate_flops.cc
+++ b/src/tir/analysis/estimate_flops.cc
@@ -152,6 +152,12 @@ class FlopEstimator : private ExprFunctor<TResult(const PrimExpr& n)>,
     return cond;
   }
 
+  TResult VisitStmt_(const LetStmtNode* let) override {
+    TResult value = VisitExpr(let->value);
+    value += VisitStmt(let->body);
+    return value;
+  }
+
   TResult VisitExpr_(const SelectNode* op) override {
     TResult cond = VisitExpr(op->condition);
     cond += VisitExpr(op->true_value).MaxWith(VisitExpr(op->false_value));
diff --git a/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py b/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py
index 7aa015831a..1cba1a739c 100644
--- a/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py
+++ b/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py
@@ -18,9 +18,11 @@
 import sys
 
 import pytest
+
 import tvm.testing
 from tvm.ir import IRModule
 from tvm.meta_schedule.testing.te_workload import create_te_workload
+from tvm.script import tir as T
 from tvm.tir.analysis import estimate_tir_flops
 
 
@@ -48,5 +50,17 @@ def test_te_workload(workload, flops):
     assert float(flops) == estimate_tir_flops(mod)
 
 
+@T.prim_func
+def flops_with_let(a: T.Buffer[16, "float32"]):
+    for i in range(8):
+        j = i + 8
+        a[j] = a[i]
+
+
+def test_flops_with_let():
+    flops = estimate_tir_flops(IRModule({"main": flops_with_let}))
+    assert flops == 8
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to