masahi commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r790201656



##########
File path: src/tir/transforms/common_subexpr_elim_tools.cc
##########
@@ -0,0 +1,836 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file common_subexpr_elim_tools.cc
+ * \brief Implementation of analysis tools and utility functions used
+ *        by the Common Subexpression Elimination (CSE) pass.
+ */
+
+#include "common_subexpr_elim_tools.h"
+
+#include <tvm/ir/transform.h>  // For the class Pass and the class PassContext
+#include <tvm/runtime/container/string.h>
+#include <tvm/tir/analysis.h>  // For the ExprDeepEqual analysis
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>  // For the class PrimFunc
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>  // For the declaration of the pass
+
+#include <algorithm>      // For std::find_if
+#include <unordered_map>  // For the hashtable datatype
+#include <vector>
+
+#include "../analysis/check_contains.h"  // For the CheckContains analysis
+
+namespace tvm {
+namespace tir {
+
+// cache_ is a static variable of the class ComputationsDoneBy, and C++ requires such a static
+// attribute to be defined here (outside of the class), otherwise it causes a linking error.
+CacheOfComputations ComputationsDoneBy::cache_;
+
+/* ********************************** Class ComputationsDoneBy **********************************
+*********************************************************************************************** */
+
+/* This utility class of the CSE pass offers a way of knowing the eligible computations done by a
+   statement or expression. A "computation" here is a syntactic entity, represented by a PrimExpr.
+   This analysis returns a hashtable associating each PrimExpr (a computation done) with a number
+   (the number of times that this computation is seen).
+   This analysis is used by the CSE pass in order to find potential candidates for being introduced
+   into new variables (after having merged semantically equivalent computations).
+
+   This analysis is parametrized by two predicates: `is_eligible_computation` and
+   `can_contain_computations`.
+   The first one helps to select only "eligible" computations, and the second one helps to only
+   select computations that are located at appropriate locations (i.e., it tells in which nodes the
+   analysis can recurse). The user of the class must define these notions of "eligible computation"
+   and of "nodes that can contain eligible computations" for their own use case.
+
+   - On a statement, this analysis often returns the union of all the computations that appear in
+   its child nodes (i.e., the union of the results of the recursive calls).
+   For instance, on the input statement [let a = x+y in Mem[i1+i2] = a+b] it will report (x+y)
+   seen once, (i1+i2) seen once, and (a+b) also seen once when used with typical predicates.
+   On some nodes, it will return something more complicated that uses the intersection of the
+   computations done by the child nodes.
+   For instance, on the input statement [if (x+y>z) then a = x+y else a = b-x] it will return
+   (x+y) seen twice, but it won't report (b-x), as it is seen only in the else branch.
+
+   - On an expression, this analysis returns the expression itself, except if it is not eligible
+   for being introduced by the CSE pass into a variable according to `is_eligible_computation_`
+   (often because it's a load node or a function call node for instance), in which case it will
+   return the union of the recursive calls on its children, as long as the other predicate
+   `can_contain_computations` evaluates to true to let the algorithm recurse deeper.
+   With typical predicates, on the expression ((w+x)+(y+z)) it will return only the expression
+   itself. But on the expression Load[i1+i2] it might return only (i1+i2), as the full Load node
+   might not be eligible.
+
+   This class uses an internal cache of results, so that if one queries it several times on the
+   same statement or expression, it will just retrieve the result from its internal cache.
+   That avoids some systematic recomputations, which would otherwise happen as the CSE pass first
+   analyses the program at the top level (asking for the computations done by the root), and then
+   dives deeper and deeper into the program, asking for the computations done by the children of
+   the root, which were necessarily already obtained when computing the computations done by the
+   root (as the computations done by the root are by definition built from the computations done
+   by the child nodes).
+
+   The somewhat tricky aspect of the implementation is the interaction between this caching of
+   results and the fact that the VisitStmt()/VisitExpr() methods of an analyzer (a StmtExprVisitor)
+   are void methods which can't return anything, and instead need to accumulate a result into a
+   member variable, which is called `table_of_computations_` here.
+
+   In particular, as the specialized methods (all the VisitStmt_() and VisitExpr_() methods) just
+   call VisitStmt()/VisitExpr() on all the child nodes within the same instance,
+   `table_of_computations_` will necessarily be shared by all the children of a given node, unless
+   we override each of these specialized methods to change this behaviour.
+   That requires care when writing into the cache.
+*/
+
+/*!
+ * \brief Does the union of two tables of computations.
+ * \param table_main One of the two tables. The union will be written into it.
+ * \param table_aux The other table, which won't change.
+ * \note Writes directly into the first argument A for efficiency: as the union of A and B
+ *       necessarily contains A, this avoids copying A.
+ */
+void UnionOf2TablesOfComputations(TableOfComputations& table_main,
+                                  const TableOfComputations& table_aux) {
+  // Adds each element of the second table to the first one
+  for (const auto& current : table_aux) {
+    table_main[current.first] += current.second;
+  }
+}
+
+/*!
+ * \brief Does the union of three tables of computations.
+ * \param table1 One of the three tables, which won't change.
+ * \param table2 One of the three tables, which won't change.
+ * \param table3 One of the three tables, which won't change.
+ * \note We don't need (at least not yet) a function working for N tables, even if this
+ *       function for 3 tables seems at first glance redundant with the one for 2 tables defined
+ *       just above. The reason is that in order to do the union for N tables, we need to know how
+ *       to do it for two. That's because we would compute for N tables using the associativity
+ *       of the union: T1 U T2 U T3 ... U Tn = ((T1 U T2) U T3) ... U Tn
+ *       Therefore, we need one for 2 tables anyway. And as the one for N (N>=3) would only be used
+ *       (at least for now) for N=3, there is at the moment no need for such a generic union over
+ *       N tables.
+ */
+TableOfComputations UnionOf3TablesOfComputations(const TableOfComputations& table1,
+                                                 const TableOfComputations& table2,
+                                                 const TableOfComputations& table3) {
+  TableOfComputations result = table1;  // Copy needed as the union of 2 writes into its first arg
+  UnionOf2TablesOfComputations(result, table2);
+  UnionOf2TablesOfComputations(result, table3);
+
+  return result;
+}
+
+/*!
+ * \brief Does the intersection of two tables of computations.
+ * \param table1 One of the two tables, which won't change.
+ * \param table2 The other table, which also won't change.
+ */
+TableOfComputations IntersectionOf2TablesOfComputations(const TableOfComputations& table1,
+                                                        const TableOfComputations& table2) {
+  TableOfComputations result;
+  for (const auto& current : table1) {
+    auto it = table2.find(current.first);
+    if (it != table2.end()) {
+      result[current.first] = current.second + it->second;
+    }
+  }
+  return result;
+}
+
+/*!
+ * \brief Does the intersection of three tables of computations.
+ * \param table1 One of the three tables, which won't change.
+ * \param table2 One of the three tables, which won't change.
+ * \param table3 One of the three tables, which won't change.
+ * \note We don't need (at least not yet) a function working for N tables, even if this

Review comment:
       Don't need to repeat this explanation
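The union and intersection helpers in the hunk, and how they can combine for the `if` example in the class comment, can be sketched outside TVM. This is a minimal sketch under stated assumptions: `ToyTable`, `UnionInto`, `Intersect`, `ThreeChildrenRule` and `CheckExamples` are hypothetical names, computations are keyed by their text instead of by `PrimExpr`, and the three-children rule (keep what at least two of condition/then/else compute, as a union of pairwise intersections) is an assumption that reproduces the "(x+y) seen twice" result; the actual rule used by the PR is not shown in this hunk.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for TableOfComputations: each computation is keyed by
// its textual form (e.g. "x+y") instead of by a PrimExpr.
using ToyTable = std::unordered_map<std::string, std::size_t>;

// Union written into the first table: counts of shared keys are added.
void UnionInto(ToyTable& table_main, const ToyTable& table_aux) {
  for (const auto& entry : table_aux) {
    table_main[entry.first] += entry.second;
  }
}

// Intersection: keeps only the keys present in both tables, summing their counts.
ToyTable Intersect(const ToyTable& table1, const ToyTable& table2) {
  ToyTable result;
  for (const auto& entry : table1) {
    auto it = table2.find(entry.first);
    if (it != table2.end()) {
      result[entry.first] = entry.second + it->second;
    }
  }
  return result;
}

// ASSUMED rule for a node with three children (condition, then, else): keep a
// computation done by at least two children, i.e. the union of the pairwise
// intersections. The 3-way union is just two successive 2-way unions.
ToyTable ThreeChildrenRule(const ToyTable& cond, const ToyTable& then_case,
                           const ToyTable& else_case) {
  ToyTable result = Intersect(then_case, else_case);
  UnionInto(result, Intersect(cond, then_case));
  UnionInto(result, Intersect(cond, else_case));
  return result;
}

// Usage on the two statements discussed in the class comment.
void CheckExamples() {
  // [let a = x+y in Mem[i1+i2] = a+b]: plain union of the children's tables.
  ToyTable let_stmt = {{"x+y", 1}};
  UnionInto(let_stmt, {{"i1+i2", 1}});
  UnionInto(let_stmt, {{"a+b", 1}});
  assert(let_stmt.size() == 3);

  // [if (x+y>z) then a = x+y else a = b-x], assuming the comparison itself is
  // not eligible, so the analysis recurses into it and finds (x+y).
  ToyTable if_stmt = ThreeChildrenRule({{"x+y", 1}}, {{"x+y", 1}}, {{"b-x", 1}});
  assert(if_stmt.at("x+y") == 2);     // seen twice
  assert(if_stmt.count("b-x") == 0);  // only in the else branch
}
```

Like `UnionOf2TablesOfComputations`, `UnionInto` mutates its first argument to avoid a copy, so a caller that must keep its inputs intact copies first, as `UnionOf3TablesOfComputations` does.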

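The interaction between the result cache and a shared accumulator such as `table_of_computations_`, described in the class comment, can be illustrated with a toy visitor. This is a hedged sketch, not the PR's implementation: `ToyNode`, `ToyComputationsDoneBy`, its `Get()`/`visits()` members and `CheckLoadExample` are hypothetical, the cache is keyed by node address, and the `can_contain_computations` predicate is omitted. It shows one way to keep siblings' results out of a node's cache entry: set the accumulator aside before visiting a node, cache what that node alone produced, then merge it back.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical toy node standing in for a TIR expression or statement.
struct ToyNode {
  std::string text;                      // textual form, e.g. "i1+i2"
  bool eligible;                         // plays the role of is_eligible_computation
  std::vector<const ToyNode*> children;  // sub-expressions or sub-statements
};

using ToyTable = std::unordered_map<std::string, std::size_t>;

// Memoized analysis whose Visit() returns nothing and accumulates into a member
// table shared by all the children of a node.
class ToyComputationsDoneBy {
 public:
  ToyTable Get(const ToyNode* node) {
    table_.clear();
    Visit(node);
    return table_;
  }
  std::size_t visits() const { return visits_; }  // nodes actually analysed

 private:
  void Visit(const ToyNode* node) {
    auto cached = cache_.find(node);
    if (cached != cache_.end()) {
      // Cache hit: merge the stored result into the shared accumulator.
      for (const auto& e : cached->second) table_[e.first] += e.second;
      return;
    }
    ++visits_;
    // Set aside what the siblings already accumulated, so that table_ holds
    // only this node's computations when it gets written into the cache.
    ToyTable siblings = std::move(table_);
    table_.clear();
    if (node->eligible) {
      table_[node->text] += 1;  // report the node itself, no recursion
    } else {
      for (const ToyNode* child : node->children) Visit(child);
    }
    cache_[node] = table_;
    // Put the siblings' results back and add this node's contribution.
    for (const auto& e : table_) siblings[e.first] += e.second;
    table_ = std::move(siblings);
  }

  ToyTable table_;
  std::unordered_map<const ToyNode*, ToyTable> cache_;
  std::size_t visits_ = 0;
};

// Usage: Load[i1+i2], where the Load node itself is not eligible.
void CheckLoadExample() {
  ToyNode i1{"i1", false, {}};
  ToyNode i2{"i2", false, {}};
  ToyNode sum{"i1+i2", true, {&i1, &i2}};
  ToyNode load{"Load[i1+i2]", false, {&sum}};
  ToyComputationsDoneBy analysis;
  assert(analysis.Get(&load).at("i1+i2") == 1);
  std::size_t before = analysis.visits();
  assert(analysis.Get(&sum).at("i1+i2") == 1);  // served from the cache
  assert(analysis.visits() == before);
}
```

Without the set-aside step, a node visited after a sibling would cache the sibling's computations too; with it, each cache entry only describes its own node, even when the same child appears twice.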



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.