This is an automated email from the ASF dual-hosted git repository.
wrongtest pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 85556df6d9 [TIR][FIX] update FlopEstimator to include missing nodes (#17598)
85556df6d9 is described below
commit 85556df6d9be2aa51b3c311af2a59d49563eec2b
Author: PatrikPerssonInceptron <[email protected]>
AuthorDate: Tue Jan 28 14:15:24 2025 +0100
[TIR][FIX] update FlopEstimator to include missing nodes (#17598)
* Updated FlopEstimator (estimate_flops.cc) to handle AllocateNode
* Added an AttrStmtNode visitor to FlopEstimator
* The AttrStmtNode visitor visits the statement body
* The AttrStmtNode visitor also visits the node's value
---
src/tir/analysis/estimate_flops.cc | 6 ++++++
1 file changed, 6 insertions(+)
diff --git a/src/tir/analysis/estimate_flops.cc b/src/tir/analysis/estimate_flops.cc
index ee2869d993..c4851e255f 100644
--- a/src/tir/analysis/estimate_flops.cc
+++ b/src/tir/analysis/estimate_flops.cc
@@ -138,6 +138,11 @@ class FlopEstimator : private ExprFunctor<TResult(const
PrimExpr& n)>,
}
TResult VisitExpr_(const BufferLoadNode* op) override { return TResult(); }
+ TResult VisitStmt_(const AttrStmtNode* op) override {
+ TResult result = VisitStmt(op->body);
+ result += VisitExpr(op->value);
+ return result;
+ }
TResult VisitStmt_(const BufferStoreNode* store) override { return
VisitExpr(store->value); }
TResult VisitStmt_(const BlockRealizeNode* block) override {
return VisitStmt(block->block->body);
@@ -186,6 +191,7 @@ class FlopEstimator : private ExprFunctor<TResult(const
PrimExpr& n)>,
TResult VisitExpr_(const FloatImmNode* op) override { return TResult(); }
TResult VisitExpr_(const CastNode* op) override { return
VisitExpr(op->value); }
TResult VisitStmt_(const AllocateConstNode* op) override { return
VisitStmt(op->body); }
+ TResult VisitStmt_(const AllocateNode* op) override { return
VisitStmt(op->body); }
TResult VisitStmt_(const DeclBufferNode* op) override { return
VisitStmt(op->body); }
TResult VisitStmt_(const SeqStmtNode* seq) override {