This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 768becd [AutoScheduler] Fix FLOPS estimation (#8695)
768becd is described below
commit 768becd0db2be3ae1857a01d3b8574deba89f63f
Author: Cody Yu <[email protected]>
AuthorDate: Mon Aug 9 21:42:08 2021 -0700
[AutoScheduler] Fix FLOPS estimation (#8695)
---
src/auto_scheduler/compute_dag.cc | 12 ++++++++----
tests/python/unittest/test_auto_scheduler_compute_dag.py | 5 +++++
2 files changed, 13 insertions(+), 4 deletions(-)
diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc
index abbcba2..e82830f 100644
--- a/src/auto_scheduler/compute_dag.cc
+++ b/src/auto_scheduler/compute_dag.cc
@@ -611,10 +611,14 @@ class FlopEstimator : public ExprFunctor<double(const PrimExpr& n)> {
std::max(VisitExpr(op->true_value), VisitExpr(op->false_value));
}
-#define VisitBinary(Node) \
- double VisitExpr_(const Node* op) final { \
- double base = op->dtype.code() == cur_type_code_ ? 1.0 : 0.0; \
- return base + VisitExpr(op->a) + VisitExpr(op->b); \
+// Index calculations (e.g., the "i + j" expression in A[i + j]) are not counted in FLOPS.
+#define VisitBinary(Node)
\
+ double VisitExpr_(const Node* op) final {
\
+ double base = 1.0;
\
+ if ((op->a->dtype.code() != cur_type_code_) && (op->b->dtype.code() !=
cur_type_code_)) { \
+ base = 0.0;
\
+ }
\
+ return base + VisitExpr(op->a) + VisitExpr(op->b);
\
}
#define VisitUnary(Node) \
diff --git a/tests/python/unittest/test_auto_scheduler_compute_dag.py b/tests/python/unittest/test_auto_scheduler_compute_dag.py
index b303ef5..e394115 100644
--- a/tests/python/unittest/test_auto_scheduler_compute_dag.py
+++ b/tests/python/unittest/test_auto_scheduler_compute_dag.py
@@ -62,6 +62,11 @@ def test_estimate_flop():
dag = auto_scheduler.ComputeDAG([A, B, F])
assert abs(dag.flop_ct - (2 * N ** 3 + 1234)) < 0.5
+ A = te.placeholder((N, N), dtype="float32", name="A")
+ F = te.compute((N, N), lambda i, j: te.if_then_else(A[i, j] > 0, A[i, j],
0))
+ dag = auto_scheduler.ComputeDAG([A, F])
+ assert abs(dag.flop_ct - N ** 2) < 0.5
+
def test_stage_order():
"""Test if the stage order is preserved when recovering a DAG."""