This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a5d7dfa583 [FIX,TIR] Handle LetStmt in EstimateTIRFLops (#12138)
a5d7dfa583 is described below
commit a5d7dfa5833075492009014079e23f6ce1a1c64d
Author: Tristan Konolige <[email protected]>
AuthorDate: Mon Aug 1 05:35:25 2022 -0700
[FIX,TIR] Handle LetStmt in EstimateTIRFLops (#12138)
* [FIX,TIR] Handle LetStmt in EstimateTIRFLops
Add flops for the right hand side of let statements to the total flops.
* assert flops
---
.../postproc/rewrite_parallel_vectorize_unroll.cc | 13 ++++++++-----
src/tir/analysis/estimate_flops.cc | 6 ++++++
.../unittest/test_tir_analysis_estimate_tir_flops.py | 14 ++++++++++++++
3 files changed, 28 insertions(+), 5 deletions(-)
diff --git a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
index 001c97645b..f3c2b1328b 100644
--- a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
+++ b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
@@ -303,11 +303,14 @@ bool FindAnnotatedRootBlock(const Schedule& sch, ParsedAnnotation* parsed, Block
const GlobalVar& g_var = kv.first;
const BaseFunc& base_func = kv.second;
if (const auto* prim_func = base_func.as<PrimFuncNode>()) {
- Block block = Downcast<BlockRealize>(prim_func->body)->block;
- if (ParseAnnotation(block, parsed)) {
- *root_rv = sch->GetBlock(block->name_hint, g_var->name_hint);
- RemoveParsedAnn(sch, *root_rv, *parsed);
- return true;
+ const BlockRealizeNode* block_realize = prim_func->body.as<BlockRealizeNode>();
+ if (block_realize != nullptr) {
+ Block block = block_realize->block;
+ if (ParseAnnotation(block, parsed)) {
+ *root_rv = sch->GetBlock(block->name_hint, g_var->name_hint);
+ RemoveParsedAnn(sch, *root_rv, *parsed);
+ return true;
+ }
}
}
}
diff --git a/src/tir/analysis/estimate_flops.cc b/src/tir/analysis/estimate_flops.cc
index 895ae798e3..44a58f8792 100644
--- a/src/tir/analysis/estimate_flops.cc
+++ b/src/tir/analysis/estimate_flops.cc
@@ -152,6 +152,12 @@ class FlopEstimator : private ExprFunctor<TResult(const PrimExpr& n)>,
return cond;
}
+ TResult VisitStmt_(const LetStmtNode* let) override {
+ TResult value = VisitExpr(let->value);
+ value += VisitStmt(let->body);
+ return value;
+ }
+
TResult VisitExpr_(const SelectNode* op) override {
TResult cond = VisitExpr(op->condition);
cond += VisitExpr(op->true_value).MaxWith(VisitExpr(op->false_value));
diff --git a/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py b/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py
index 7aa015831a..1cba1a739c 100644
--- a/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py
+++ b/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py
@@ -18,9 +18,11 @@
import sys
import pytest
+
import tvm.testing
from tvm.ir import IRModule
from tvm.meta_schedule.testing.te_workload import create_te_workload
+from tvm.script import tir as T
from tvm.tir.analysis import estimate_tir_flops
@@ -48,5 +50,17 @@ def test_te_workload(workload, flops):
assert float(flops) == estimate_tir_flops(mod)
+@T.prim_func
+def flops_with_let(a: T.Buffer[16, "float32"]):
+ for i in range(8):
+ j = i + 8
+ a[j] = a[i]
+
+
+def test_flops_with_let():
+ flops = estimate_tir_flops(IRModule({"main": flops_with_let}))
+ assert flops == 8
+
+
if __name__ == "__main__":
tvm.testing.main()