slyubomirsky commented on code in PR #14361: URL: https://github.com/apache/tvm/pull/14361#discussion_r1144153247
########## src/relax/transform/eliminate_common_subexpr.cc: ########## @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/transform/eliminate_common_subexpr.cc + * \brief Eliminate common subexpression pass. + * + * Currently it removes common subexpressions within a DataflowBlock. + */ +#include <tvm/relax/expr_functor.h> +#include <tvm/relax/transform.h> + +namespace tvm { +namespace relax { + +class SubexprCounter : public ExprVisitor { + public: + // overriding VisitExpr ensures we do this for every subexpression + void VisitExpr(const Expr& e) override { + // Cases we ignore because we will not substitute them: + // 1. Vars of all kinds + // 2. Op nodes (nothing we can do) + // 3.
Scalar constants (not much benefit from binding to a var) + if (!(e->IsInstance<VarNode>() || e->IsInstance<DataflowVarNode>() || + e->IsInstance<GlobalVarNode>() || e->IsInstance<tvm::OpNode>() || + (e.as<ConstantNode>() && (e.as<ConstantNode>()->is_scalar())))) { + int count = 0; + if (count_map_.count(e)) { + count = count_map_.at(e); + } + count_map_[e] = count + 1; + } + ExprVisitor::VisitExpr(e); + } + + // do not visit inner functions: we will do CSE within those + void VisitExpr_(const FunctionNode* func) override {} + + // we are not going to do replacements inside struct info to avoid binding lots of reused shapes + void VisitExprDepStructInfoField(const StructInfo& struct_info) override {} + + std::unordered_map<Expr, int, StructuralHash, StructuralEqual> Count( + const DataflowBlock& df_block) { + for (auto binding : df_block->bindings) { + VisitBinding(binding); + } + return count_map_; + } + + private: + std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map_; +}; + +// forward declaration +DataflowBlock EliminateCommonSubexpr(const DataflowBlock&); + +class CommonSubexprEliminator : public ExprMutator { + public: + explicit CommonSubexprEliminator( + const std::unordered_map<Expr, int, StructuralHash, StructuralEqual>& count_map) + : count_map_(count_map) {} + + // overriding here ensures we visit every subexpression + Expr VisitExpr(const Expr& e) override { + if (count_map_.count(e) && count_map_.at(e) > 1) { + // if we already have a mapping for it, get it + if (replacements_.count(e)) { + return replacements_.at(e); + } + // Otherwise, insert a new binding for the current expression.
+ // Visit before emitting to do inner replacements + Expr new_e = ExprMutator::VisitExpr(e); + Var v = builder_->Emit(new_e); + replacements_[e] = v; + return v; + } + return ExprMutator::VisitExpr(e); + } + + // we are not going to do replacements inside struct info to avoid binding lots of reused shapes + StructInfo VisitExprDepStructInfoField(const StructInfo& struct_info) override { + return struct_info; + } + + Expr VisitExpr_(const FunctionNode* func) override { + // for an inner function, we will do CSE on its body + Expr new_body = ExprMutator::VisitExpr(func->body); + if (new_body.same_as(func->body)) { + return GetRef<Expr>(func); + } + return Function(func->params, new_body, func->ret_struct_info, func->attrs, func->span); + } + + // this should happen only for the inner function case + Expr VisitExpr_(const SeqExprNode* seq) override { + bool all_unchanged = true; + Array<BindingBlock> new_blocks; + // apply CSE within dataflow blocks only + for (auto block : seq->blocks) { + if (const DataflowBlockNode* df_block = block.as<DataflowBlockNode>()) { + auto new_df_block = EliminateCommonSubexpr(GetRef<DataflowBlock>(df_block)); + if (!new_df_block.same_as(block)) { + new_blocks.push_back(new_df_block); + all_unchanged = false; + continue; + } + } + new_blocks.push_back(block); + } + + if (all_unchanged) { + return GetRef<Expr>(seq); + } + // do not visit the body + return SeqExpr(new_blocks, seq->body, seq->span); + } Review Comment: Update: Based on the Unity Community Meeting discussion, it doesn't sound like there is much appetite for imposing phase orderings like this, so I would be interested instead if there is a clean way to deal with local functions in dataflow block passes (generalizing the approach shown here, for example). I am sure that other dataflow block passes don't handle the local function case and might exhibit strange bugs if given a program with local functions -- This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
