This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 768becd  [AutoScheduler] Fix FLOPS estimation (#8695)
768becd is described below

commit 768becd0db2be3ae1857a01d3b8574deba89f63f
Author: Cody Yu <[email protected]>
AuthorDate: Mon Aug 9 21:42:08 2021 -0700

    [AutoScheduler] Fix FLOPS estimation (#8695)
---
 src/auto_scheduler/compute_dag.cc                        | 12 ++++++++----
 tests/python/unittest/test_auto_scheduler_compute_dag.py |  5 +++++
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc
index abbcba2..e82830f 100644
--- a/src/auto_scheduler/compute_dag.cc
+++ b/src/auto_scheduler/compute_dag.cc
@@ -611,10 +611,14 @@ class FlopEstimator : public ExprFunctor<double(const PrimExpr& n)> {
            std::max(VisitExpr(op->true_value), VisitExpr(op->false_value));
   }
 
-#define VisitBinary(Node)                                         \
-  double VisitExpr_(const Node* op) final {                       \
-    double base = op->dtype.code() == cur_type_code_ ? 1.0 : 0.0; \
-    return base + VisitExpr(op->a) + VisitExpr(op->b);            \
+// Index calculations (e.g., the "i + j" expression in A[i + j]) are not counted in FLOPS.
+#define VisitBinary(Node)                                                                     \
+  double VisitExpr_(const Node* op) final {                                                   \
+    double base = 1.0;                                                                        \
+    if ((op->a->dtype.code() != cur_type_code_) && (op->b->dtype.code() != cur_type_code_)) { \
+      base = 0.0;                                                                             \
+    }                                                                                         \
+    return base + VisitExpr(op->a) + VisitExpr(op->b);                                        \
   }
 
 #define VisitUnary(Node)                                          \
diff --git a/tests/python/unittest/test_auto_scheduler_compute_dag.py b/tests/python/unittest/test_auto_scheduler_compute_dag.py
index b303ef5..e394115 100644
--- a/tests/python/unittest/test_auto_scheduler_compute_dag.py
+++ b/tests/python/unittest/test_auto_scheduler_compute_dag.py
@@ -62,6 +62,11 @@ def test_estimate_flop():
     dag = auto_scheduler.ComputeDAG([A, B, F])
     assert abs(dag.flop_ct - (2 * N ** 3 + 1234)) < 0.5
 
+    A = te.placeholder((N, N), dtype="float32", name="A")
+    F = te.compute((N, N), lambda i, j: te.if_then_else(A[i, j] > 0, A[i, j], 0))
+    dag = auto_scheduler.ComputeDAG([A, F])
+    assert abs(dag.flop_ct - N ** 2) < 0.5
+
 
 def test_stage_order():
     """Test if the stage order is preserved when recovering a DAG."""

Reply via email to