This is an automated email from the ASF dual-hosted git repository.

wrongtest pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 85556df6d9 [TIR][FIX] update FlopEstimator to include missing nodes (#17598)
85556df6d9 is described below

commit 85556df6d9be2aa51b3c311af2a59d49563eec2b
Author: PatrikPerssonInceptron <[email protected]>
AuthorDate: Tue Jan 28 14:15:24 2025 +0100

    [TIR][FIX] update FlopEstimator to include missing nodes (#17598)
    
    * Updated the FLOP estimator (estimate_flops.cc) to handle AllocateNode
    
    * Added a visitor for AttrStmtNode
    
    * The AttrStmtNode visitor now counts the FLOPs of the node's body
    
    * The AttrStmtNode visitor now also counts the FLOPs of the node's value
---
 src/tir/analysis/estimate_flops.cc | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/src/tir/analysis/estimate_flops.cc 
b/src/tir/analysis/estimate_flops.cc
index ee2869d993..c4851e255f 100644
--- a/src/tir/analysis/estimate_flops.cc
+++ b/src/tir/analysis/estimate_flops.cc
@@ -138,6 +138,11 @@ class FlopEstimator : private ExprFunctor<TResult(const 
PrimExpr& n)>,
   }
 
   TResult VisitExpr_(const BufferLoadNode* op) override { return TResult(); }
+  TResult VisitStmt_(const AttrStmtNode* op) override {
+    TResult result = VisitStmt(op->body);
+    result += VisitExpr(op->value);
+    return result;
+  }
   TResult VisitStmt_(const BufferStoreNode* store) override { return 
VisitExpr(store->value); }
   TResult VisitStmt_(const BlockRealizeNode* block) override {
     return VisitStmt(block->block->body);
@@ -186,6 +191,7 @@ class FlopEstimator : private ExprFunctor<TResult(const 
PrimExpr& n)>,
   TResult VisitExpr_(const FloatImmNode* op) override { return TResult(); }
   TResult VisitExpr_(const CastNode* op) override { return 
VisitExpr(op->value); }
   TResult VisitStmt_(const AllocateConstNode* op) override { return 
VisitStmt(op->body); }
+  TResult VisitStmt_(const AllocateNode* op) override { return 
VisitStmt(op->body); }
   TResult VisitStmt_(const DeclBufferNode* op) override { return 
VisitStmt(op->body); }
 
   TResult VisitStmt_(const SeqStmtNode* seq) override {

Reply via email to