masahi commented on a change in pull request #9482:
URL: https://github.com/apache/tvm/pull/9482#discussion_r790201656
##########
File path: src/tir/transforms/common_subexpr_elim_tools.cc
##########
@@ -0,0 +1,836 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+* \file common_subexpr_elim_tools.cc
+* \brief Implementation of analysis tools and utility functions used
+         by the Common Subexpression Elimination (CSE) pass.
+*/
+
+#include "common_subexpr_elim_tools.h"
+
+#include <tvm/ir/transform.h> // For the class Pass and the class PassContext
+#include <tvm/runtime/container/string.h>
+#include <tvm/tir/analysis.h> // For the ExprDeepEqual analysis
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h> // For the class PrimFunc
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h> // For the declaration of the pass
+
+#include <algorithm> // For std::find_if
+#include <unordered_map> // For the hashtable datatype
+#include <vector>
+
+#include "../analysis/check_contains.h" // For the CheckContains analysis
+
+namespace tvm {
+namespace tir {
+
+// cache_ is a static variable of the class ComputationsDoneBy, and C++ requires such a static
+// attribute to be defined here, otherwise it causes a linking error.
+CacheOfComputations ComputationsDoneBy::cache_;
+
+/* ********************************** Class ComputationsDoneBy **********************************
+*********************************************************************************************** */
+
+/* This utility class of the CSE pass offers a way of knowing the eligible computations done by a
+   statement or expression. A "computation" here is a syntactical entity, represented by a PrimExpr.
+   This analysis returns a hashtable associating PrimExpr (a computation done) to a number (which
+   is the number of times that this computation is being seen).
+   This analysis is used by the CSE pass in order to find potential candidates for being introduced
+   into new variables (after having merged semantically equivalent computations).
+
+   This analysis is parametrized by two predicates : `is_eligible_computation` and
+   `can_contain_computations`.
+   The first one helps to select only "eligible" computations, and the second one helps to only
+   select computations that are located at appropriate location (i.e., it tells in which nodes the
+   analysis can recurse). The user of the class must define these notions of "eligible computation"
+   and of "nodes that can contain eligible computations" for his own use case.
+
+   - On a statement, this analysis often returns the union of all the computations that appear in
+     its child nodes (i.e., the union of the results of the recursive calls).
+     For instance, on the input statement [let a = x+y in Mem[i1+i2] = a+b] it will report (x+y)
+     seen once, (i1+i2) seen once, and (a+b) also seen once when used with typical predicates.
+     On some nodes, it will return something more complicated that uses the intersection of the
+     computations done by the children nodes.
+     For instance, on the input statement [if (x+y>z) then a = x+y else a = b-x] it will return
+     (x+y) seen twice but it won't report b-x as it is seen only in the else branch.
+
+   - On an expression, this analysis returns the expression itself, except if it is not eligible
+     for being introduced by the CSE pass into a variable according to `is_eligible_computation_`
+     (often because it's a load node or a function call node for instance), in which case it will
+     return the union of the recursive calls on its children, as long as the other predicate
+     `can_contain_computations` evaluates to true to let the algorithm recurse deeper.
+     With typical predicates, on the expression ((w+x)+(y+z)) it will return only the expression
+     itself. But on the expression Load[i1+i2] it might return only (i1+i2) as the full Load node
+     might not be eligible.
+
+   This class uses an internal cache of results, so that if one queries it several times on the
+   same statement or expression, it will just retrieve the result from its internal cache.
+   That avoids some systematic recomputations, which would otherwise happen as the CSE pass first
+   analyses the program at the toplevel (asking for the computations done by the root), and then
+   dives deeper and deeper into the program, asking for the computations done by the children of
+   the root, which were necessarily previously obtained when computing the computations done by the
+   root (as the computations done by the root are by definition the union of the computations done
+   by the children nodes).
+
+   The somewhat difficult aspect of the implementation is the interaction between this caching of
+   results, and the fact that the VisitStmt()/VisitExpr() of an analyzer (a StmtExprVisitor) are
+   void methods which can't return anything, and instead need to accumulate a result into a member
+   variable, which is called `table_of_computations_` here.
+
+   In particular, as the specialized methods (all the VisitStmt_() and VisitExpr_() methods), just
+   call VisitStmt()/VisitExpr() on all the children nodes within the same instance, if we don't
+   want to override each of these specialized methods to change this behaviour, then
+   `table_of_computations_` will necessarily be shared by all the children of a given node.
+   That requires to be careful when trying to write into the cache.
+*/
+
+/*!
+ * \brief Does the union of two tables of computations.
+ * \param table_main One of the two tables. The union will be written into it.
+ * \param table_aux The other table, which won't change.
+ * \note Does it directly in the first argument A for efficiency, as the union of A and B
+ *       necessarily gives something which contains A, so we avoid its copy.
+ */
+void UnionOf2TablesOfComputations(TableOfComputations& table_main,
+                                  const TableOfComputations& table_aux) {
+  // Adds each element of the second table to the first one
+  for (const auto& current : table_aux) {
+    table_main[current.first] += current.second;
+  }
+}
+
+/*!
+ * \brief Does the union of three tables of computations.
+ * \param table1 One of the three tables, which won't change.
+ * \param table2 One of the three tables, which won't change.
+ * \param table3 One of the three tables, which won't change.
+ * \note We don't need (at least yet) to have a function working for N tables, even if this
+ *       function for 3 tables seems at first glance redundant with the one for 2 tables defined
+ *       just above. The reason is that in order to do the union for N tables, we need to know how
+ *       to do it for two. That's because we would compute for N tables using the associativity
+ *       of the union : T1 U T2 U T3 ... U Tn = ((T1 U T2) U T3) ... U Tn
+ *       Therefore, we need one for 2 tables anyway. And as the one for N (N>=3) would only be used
+ *       (at least for now) for N=3, there is at the moment no need for such a generic union over
+ *       N tables.
+ */
+TableOfComputations UnionOf3TablesOfComputations(const TableOfComputations& table1,
+    const TableOfComputations& table2, const TableOfComputations& table3) {
+  TableOfComputations result = table1; // Copy needed as the union of 2 writes into its first arg
+  UnionOf2TablesOfComputations(result, table2);
+  UnionOf2TablesOfComputations(result, table3);
+
+  return result;
+}
+
+/*!
+ * \brief Does the intersection of two tables of computations.
+ * \param table1 One of the two tables, which won't change.
+ * \param table2 The other table, which also won't change.
+ */
+TableOfComputations IntersectionOf2TablesOfComputations(const TableOfComputations& table1,
+                                                        const TableOfComputations& table2) {
+  TableOfComputations result;
+  for (const auto& current : table1) {
+    auto it = table2.find(current.first);
+    if (it != table2.end()) {
+      result[current.first] = current.second + it->second;
+    }
+  }
+  return result;
+}
+
+/*!
+ * \brief Does the intersection of three tables of computations.
+ * \param table1 One of the three tables, which won't change.
+ * \param table2 One of the three tables, which won't change.
+ * \param table3 One of the three tables, which won't change.
+ * \note We don't need (at least yet) to have a function working for N tables, even if this

Review comment:
   Don't need to repeat this explanation
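
For readers following the thread, the count semantics of the two-table union and intersection quoted in the diff can be tried out in isolation. The following is only a minimal sketch: `ToyTable`, `ToyUnion` and `ToyIntersection` are hypothetical names that do not appear in the PR, and the keys are plain strings rather than PrimExpr so that the sketch has no TVM dependency.

```cpp
#include <string>
#include <unordered_map>

// Hypothetical stand-in for TableOfComputations. The table in the PR is keyed by
// PrimExpr; strings are an assumption made here to keep the sketch standalone.
using ToyTable = std::unordered_map<std::string, int>;

// Same logic as the quoted UnionOf2TablesOfComputations: the result is written
// into `dst`, and counts for a computation present in both tables are summed.
void ToyUnion(ToyTable& dst, const ToyTable& aux) {
  for (const auto& entry : aux) {
    dst[entry.first] += entry.second;
  }
}

// Same logic as the quoted IntersectionOf2TablesOfComputations: only keys found
// in both tables are kept, and their two counts are added together.
ToyTable ToyIntersection(const ToyTable& a, const ToyTable& b) {
  ToyTable result;
  for (const auto& entry : a) {
    auto it = b.find(entry.first);
    if (it != b.end()) {
      result[entry.first] = entry.second + it->second;
    }
  }
  return result;
}
```

With a condition table {x+y: 1} and a then-branch table {x+y: 1}, `ToyIntersection` yields {x+y: 2}, which is consistent with the "(x+y) seen twice" wording in the class comment, while intersecting the then-branch with an else-branch table {b-x: 1} yields an empty table. How the pass actually combines the three children of an if node is not visible in this excerpt, so this is one possible reading rather than a description of the PR's behaviour.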
