This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch unity-staging in repository https://gitbox.apache.org/repos/asf/tvm.git
commit b3cf43af34e1d90b3fc841c2c2a12d36073a6159 Author: Jiawei Liu <[email protected]> AuthorDate: Sat Feb 18 18:13:57 2023 -0600 [Unity] Statement rewriter for DataflowBlock (#14043) This PR implements a few APIs to quickly perform statement-level mutation: `add`/`remove_unused`/`remove_all_unused`/`replace_all_uses`. It also implements `remove_all_unused` to remove dead statements inside `DataflowBlock`. --- include/tvm/relax/analysis.h | 24 +++ include/tvm/relax/binding_rewrite.h | 115 ++++++++++ include/tvm/relax/utils.h | 1 + python/tvm/relax/analysis/analysis.py | 23 +- python/tvm/relax/binding_rewrite.py | 155 +++++++++++++ src/relax/ir/binding_rewrite.cc | 324 ++++++++++++++++++++++++++++ tests/python/relax/test_analysis.py | 118 +++++++++- tests/python/relax/test_binding_rewrite.py | 334 +++++++++++++++++++++++++++++ 8 files changed, 1092 insertions(+), 2 deletions(-) diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index 32e1582134..b9866577e9 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -341,6 +341,14 @@ TVM_DLL Map<Var, Expr> AnalyzeVar2Value(const Expr& expr); */ TVM_DLL Map<Var, Expr> AnalyzeVar2Value(const DataflowBlock& dfb); +/*! + * \brief Return a mapping from variable name to its Bindings. + * + * \param fn The function to be analyzed. + * \return A mapping from variable name to its Bindings. + */ +TVM_DLL Map<String, Array<Binding>> NameToBinding(const Function& fn); + /*! * \brief Get the use-def chain of variables inside a dataflow block. * @@ -349,6 +357,22 @@ TVM_DLL Map<Var, Expr> AnalyzeVar2Value(const DataflowBlock& dfb); */ TVM_DLL Map<Var, Array<Var>> DataflowBlockUseDef(const DataflowBlock& dfb); +/*! + * \brief Get the use-def chain of variables inside a function. + * + * \param fn The function to be analyzed. + * \return A map from variable definitions to a set of uses and variables needed by return value. 
+ */ +std::pair<Map<Var, Array<Var>>, Array<Var>> FunctionUseDef(const Function& fn); + +/*! + * \brief Remove unused statements inside DataflowBlocks. + * + * \param fn The function to remove unused statements. + * \return The function that contains no unused statements in DataflowBlock. + */ +TVM_DLL Function RemoveAllUnused(const Function fn); + /*! * \brief Annotate Op Pattern Kind for PrimFunc, which is used in relax FuseOps. * diff --git a/include/tvm/relax/binding_rewrite.h b/include/tvm/relax/binding_rewrite.h new file mode 100644 index 0000000000..a4b534965a --- /dev/null +++ b/include/tvm/relax/binding_rewrite.h @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/binding_rewrite.h + * \brief An IR rewriter to easily add/remove/replace bindings (statements). + */ + +#ifndef TVM_RELAX_BINDING_REWRITE_H_ + +#include <tvm/relax/analysis.h> +#include <tvm/relax/expr.h> +#include <tvm/relax/utils.h> + +#include <map> +#include <set> +#include <type_traits> +#include <utility> +#include <vector> + +namespace tvm { +namespace relax { + +/*! \brief Statement rewriter for relax.DataflowBlock. */ +class DataflowBlockRewriteNode : public Object { + public: + /*! 
\brief Replace all uses of old_var with new_var. */ + void ReplaceAllUses(Var old_var, Var new_var); + /*! \brief Insert a Binding statement. */ + void Add(Binding binding); + /*! \brief Insert an expression as VarBinding with variable name. */ + void Add(String var_name, Expr expr, bool is_dfvar = false) { + auto var = is_dfvar ? DataflowVar(var_name, GetStructInfo(expr)) // + : Var(var_name, GetStructInfo(expr)); + Add(VarBinding(std::move(var), std::move(expr))); + } + /*! \brief Insert an expression as VarBinding with automatic variable name. */ + void Add(Expr expr, bool is_dfvar = false) { + Add(name_table_.GetUniqueName("tmp"), expr, is_dfvar); + } + /*! \brief Remove the definition statement of an unused variable. */ + void RemoveUnused(Var unused, bool allow_undef = false); + /*! \brief Remove the definition statements of all unused variables. */ + void RemoveAllUnused(); + + /*! \brief The rewritten dataflow block. */ + DataflowBlock MutatedDataflowBlock() { return dfb_.value(); } + /*! \brief The rewritten function. */ + Function MutatedFunc() { return root_fn_.value(); } + /*! \brief The rewritten IRModule. */ + IRModule MutateIRModule(IRModule irmod); + + /*! \brief Visit attributes. */ + void VisitAttrs(AttrVisitor* v) { + v->Visit("dfb", &dfb_); + v->Visit("root_fn", &root_fn_); + } + + static constexpr const char* _type_key = "relax.DataflowBlockRewrite"; + TVM_DECLARE_FINAL_OBJECT_INFO(DataflowBlockRewriteNode, Object); + + protected: + friend class DataflowBlockRewrite; + + Optional<DataflowBlock> dfb_; //!< The rewritten dataflow block. + Optional<Function> root_fn_; //!< The rewritten function. + const FunctionNode* original_fn_ptr_; //!< Pointer to the original function. + Map<Var, Array<Var>> to_users_; //!< Map from variable to its users. + Array<Var> fn_outputs_; //!< Variables required by function outputs. + + private: + NameTable name_table_; //!< Name table for tracking and generating unique names. +}; + +/*! 
+ * \brief A statement rewriter for relax.DataflowBlock. + * \sa DataflowBlockRewriteNode + */ +class DataflowBlockRewrite : public ObjectRef { + public: + TVM_DLL explicit DataflowBlockRewrite(DataflowBlock dfb, Function root_fn); + + /*! + * \brief mutable accessor. + * \return mutable access pointer. + */ + DataflowBlockRewriteNode* operator->() { + ICHECK(get() != nullptr); + return static_cast<DataflowBlockRewriteNode*>(get_mutable()); + } + + TVM_DEFINE_OBJECT_REF_METHODS(DataflowBlockRewrite, ObjectRef, DataflowBlockRewriteNode); +}; + +} // namespace relax +} // namespace tvm + +#define TVM_RELAX_BINDING_REWRITE_H_ +#endif // TVM_RELAX_BINDING_REWRITE_H_ diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h index 1457a16427..c1d984a21a 100644 --- a/include/tvm/relax/utils.h +++ b/include/tvm/relax/utils.h @@ -25,6 +25,7 @@ #define TVM_RELAX_UTILS_H_ #include <tvm/ir/module.h> +#include <tvm/relax/expr.h> #include <tvm/runtime/logging.h> #include <algorithm> diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py index 45c5b6f962..ffcdaceb40 100644 --- a/python/tvm/relax/analysis/analysis.py +++ b/python/tvm/relax/analysis/analysis.py @@ -28,7 +28,7 @@ from tvm import tir from tvm import IRModule from tvm.relax.ty import Type from tvm.relax.struct_info import StructInfo, FuncStructInfo -from tvm.relax.expr import DataflowBlock, Var, Expr, Function, Call +from tvm.relax.expr import DataflowBlock, Var, Expr, Function, Call, Binding from . import _ffi_api @@ -244,6 +244,27 @@ def udchain(dfb: DataflowBlock) -> Dict[Var, List[Var]]: return _ffi_api.udchain(dfb) # type: ignore +def name_to_binding(func: Function) -> Dict[str, List[Binding]]: + """Return a map from variable name to its bindings.""" + return _ffi_api.name_to_binding(func) # type: ignore + + +def remove_all_unused(func: Function) -> Function: + """Remove all unused variables from the function. 
+ + Parameters + ---------- + func : Function + The input function to be analyzed. + + Returns + ------- + Function + The function with unused variables removed. + """ + return _ffi_api.remove_all_unused(func) # type: ignore + + def well_formed(mod: IRModule, check_struct_info: bool = True) -> bool: """Check if the IRModule is well formed. diff --git a/python/tvm/relax/binding_rewrite.py b/python/tvm/relax/binding_rewrite.py new file mode 100644 index 0000000000..a9f6d878ad --- /dev/null +++ b/python/tvm/relax/binding_rewrite.py @@ -0,0 +1,155 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return, invalid-name +"""Developer API of add/remove/replace bindings in Relax.""" + +from typing import Optional + +import tvm +import tvm._ffi +from tvm.runtime import Object +from . import Binding, DataflowBlock, Expr, Function, Var +from . import _ffi_api + + +@tvm._ffi.register_object("relax.DataflowBlockRewrite") +class DataflowBlockRewrite(Object): + """ + A binding/statement-level dataflow block rewriter. + + Notes + ----- + Due to the immutable and copy-on-write nature of TVM AST nodes, the rewriting is not done in + place. Instead, a new DataflowBlock is created and returned with mutated_dfb. 
Similarly, its new + root Function is created and returned by mutated_root_fn. To apply this change for an IRModule, + use mutate_irmodule which rewrites the old function that registered in the constructor. + """ + + def __init__(self, dfb: DataflowBlock, root_fn: Function): + """ + Construct a rewriter with the DataflowBlock to rewrite and its root function. + + Parameters + ---------- + dfb : DataflowBlock + The DataflowBlock to rewrite. + root_fn : Function + The root function of the DataflowBlock. + """ + self.func_name = root_fn.__name__ if hasattr(root_fn, "__name__") else None + self.__init_handle_by_constructor__( + _ffi_api.DataflowBlockRewrite, dfb, root_fn # type: ignore + ) + + def replace_all_uses(self, old_var: Var, new_var: Var) -> None: + """ + Replace all uses of old_var with new_var. + + Parameters + ---------- + old_var : Var + The old variable to replace. + new_var : Var + The new variable to replace with. + """ + _ffi_api.dfb_rewrite_replace_all_uses(self, old_var, new_var) # type: ignore + + def add_binding(self, binding: Binding) -> None: + return _ffi_api.dfb_rewrite_add_binding(self, binding) # type: ignore + + def add(self, expr: Expr, name: Optional[str] = None, is_dfvar: bool = False) -> None: + """ + Add a new statement to the DataflowBlock with an automatically generated variable name. + + Parameters + ---------- + expr : Expr + The expression to add. + name : Optional[str], optional + Variable name, by default None + is_dfvar : bool, optional + The variable type, by default False + + Notes + ----- + If the variable name is not given, it will be automatically generated in a form of + "tmp${COUNTER}". The variable type will be DataflowVar if is_dfvar is True, otherwise + it will be Var. Being Var means the variables are output variables of the DataflowBlock. + While being DataflowVar means the variables are internal variables of the DataflowBlock. 
+ """ + _ffi_api.dfb_rewrite_add(self, expr, name, is_dfvar) # type: ignore + + def remove_unused(self, var: Var, allow_undef=False) -> None: + """ + Remove a statement by its variable definition if and only if it is unused. + + Parameters + ---------- + var : Var + The unused variable definition. + allow_undef : bool, optional + Whether to allow var being undefined variable, by default False + + Raises + ------ + TVMError if the variable is used or undefined (allow_undef=False). + """ + _ffi_api.dfb_rewrite_remove_unused(self, var, allow_undef) # type: ignore + + def remove_all_unused(self) -> None: + """ + Remove all unused variables. + + Notes + ----- + This could remove unused variables in other DataflowBlocks as well. + """ + _ffi_api.dfb_rewrite_remove_all_unused(self) # type: ignore + + def mutated_dfb(self) -> DataflowBlock: + """ + Returns the mutated DataflowBlock. + """ + return self.dfb + + def mutated_root_fn(self) -> Function: + """ + Returns the mutated root function. + """ + ret = self.root_fn + if self.func_name: + ret.__name__ = self.func_name + return ret + + def mutate_irmodule(self, irmodule: tvm.IRModule) -> tvm.IRModule: + """ + Return an updated IRModule by replacing the old function with the mutated root function. + + Parameters + ---------- + irmodule : tvm.IRModule + The base IRModule to update. + + Returns + ------- + tvm.IRModule + The updated IRModule. + """ + ret = _ffi_api.dfb_rewrite_mutate_irmodule(self, irmodule) # type: ignore + if hasattr(irmodule, "__name__"): + ret.__name__ = irmodule.__name__ + return ret diff --git a/src/relax/ir/binding_rewrite.cc b/src/relax/ir/binding_rewrite.cc new file mode 100644 index 0000000000..dd9fac9fdc --- /dev/null +++ b/src/relax/ir/binding_rewrite.cc @@ -0,0 +1,324 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/ir/binding_rewrite.cc + * \brief Implementation of binding rewriters. + */ + +#include <tvm/relax/binding_rewrite.h> +#include <tvm/relax/block_builder.h> +#include <tvm/relax/expr.h> +#include <tvm/relax/expr_functor.h> + +#include <functional> +#include <iterator> + +namespace tvm { +namespace relax { + +TVM_REGISTER_NODE_TYPE(DataflowBlockRewriteNode); +DataflowBlockRewrite::DataflowBlockRewrite(DataflowBlock dfb, Function root_fn) { + auto n = make_object<DataflowBlockRewriteNode>(); + n->dfb_ = dfb; + n->root_fn_ = root_fn; + n->original_fn_ptr_ = root_fn.get(); + auto p = FunctionUseDef(root_fn); + n->to_users_ = std::move(p.first); + n->fn_outputs_ = std::move(p.second); + n->name_table_ = NameTable(n->to_users_.begin(), n->to_users_.end(), + [](const auto& p) { return p.first->name_hint(); }); + + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.DataflowBlockRewrite") + .set_body_typed([](DataflowBlock dfb, Function root_fn) { + return DataflowBlockRewrite(dfb, root_fn); + }); + +void DataflowBlockRewriteNode::ReplaceAllUses(Var old_var, Var new_var) { + class ReplaceAllUsePass : public ExprMutator { + Var old_var, new_var; + const DataflowBlockNode* const to_catch; + + public: + const DataflowBlockNode* caught = nullptr; + + ReplaceAllUsePass(Var old_var, Var new_var, const DataflowBlockNode* to_catch) + : old_var(old_var), 
new_var(new_var), to_catch(to_catch) {} + + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const VarNode* op) override { + return (op == old_var.get()) ? new_var : GetRef<Expr>(op); + } + + Expr VisitExpr_(const DataflowVarNode* op) override { + return (op == old_var.get()) ? new_var : GetRef<Expr>(op); + } + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* op) override { + BindingBlock res = ExprMutator::VisitBindingBlock_(op); + if (op == to_catch) caught = static_cast<const DataflowBlockNode*>(res.get()); + return res; + } + }; + + ICHECK(to_users_.find(old_var) != to_users_.end()) << "Cannot find " << old_var; + ICHECK(to_users_.find(new_var) != to_users_.end()) << "Cannot find " << new_var; + + // replace uses inside the DataflowBlock. + ReplaceAllUsePass replacer(old_var, new_var, dfb_.get()); + root_fn_ = Downcast<Function>(replacer.VisitExpr_(root_fn_.get())); + dfb_ = GetRef<DataflowBlock>(replacer.caught); + + // update udchain + // old_var -> old_var users | changed to {} + // new_var -> {?} | changed to old_var users + for (Var user : to_users_[old_var]) { + auto new_var_uses = to_users_[new_var]; + if (new_var_uses.end() == std::find(new_var_uses.begin(), new_var_uses.end(), user)) { + new_var_uses.push_back(user); + } + } + + to_users_.Set(old_var, {}); + + auto it_old_output = std::find(fn_outputs_.begin(), fn_outputs_.end(), old_var); + if (it_old_output != fn_outputs_.end()) { + fn_outputs_.Set(std::distance(fn_outputs_.begin(), it_old_output), new_var); + } +} + +TVM_REGISTER_GLOBAL("relax.dfb_rewrite_replace_all_uses") + .set_body_typed([](DataflowBlockRewrite rwt, Var old_var, Var new_var) { + rwt->ReplaceAllUses(old_var, new_var); + }); + +class UpdateDFB : public ExprMutator { + private: + DataflowBlock old_dfb, new_dfb; + + public: + UpdateDFB(DataflowBlock old_dfb, DataflowBlock new_dfb) + : old_dfb(std::move(old_dfb)), new_dfb(std::move(new_dfb)) {} + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* op) override { + 
return old_dfb.get() == op ? new_dfb : old_dfb; + } +}; + +void DataflowBlockRewriteNode::Add(Binding binding) { + auto p = [binding] { + if (auto vb = binding.as<VarBindingNode>()) { + return std::make_pair(vb->var, vb->value); + } else if (auto mc = binding.as<MatchCastNode>()) { + return std::make_pair(mc->var, mc->value); + } + LOG(FATAL) << "Unsupported binding type"; + return std::make_pair(Var{}, Expr{}); + }(); + Var var = p.first; + Expr val = p.second; + + ICHECK(0 == to_users_.count(var)) << var << " has been defined so cannot be added."; + + // Add this VarBinding statement after the definition of uses. + std::set<const VarNode*> used_vars = [val] { + class UsedVars : public ExprVisitor { + public: + std::set<const VarNode*> used_vars; + void VisitExpr_(const VarNode* op) override { used_vars.insert(op); } + void VisitExpr_(const DataflowVarNode* op) override { used_vars.insert(op); } + } uvar{}; + uvar.VisitExpr(val); + return std::move(uvar.used_vars); + }(); + + size_t line_last_req_def = 0; + for (size_t i = 0; i < dfb_.value()->bindings.size(); ++i) { + auto line = dfb_.value()->bindings[i]; + if (used_vars.find(line->var.get()) != used_vars.cend()) line_last_req_def = i; + } + + auto old_dfb = dfb_.value(); + + dfb_ = [old_dfb, binding, line_last_req_def, this] { + auto new_dfb = dfb_.value(); + new_dfb.CopyOnWrite()->bindings.insert(dfb_.value()->bindings.begin() + 1 + line_last_req_def, + binding); + return new_dfb; + }(); + + auto updater = UpdateDFB(old_dfb, dfb_.value()); + root_fn_ = Downcast<Function>(updater.VisitExpr_(root_fn_.get())); + + for (const VarNode* v : used_vars) to_users_.Get(GetRef<Var>(v)).value().push_back(var); +} + +TVM_REGISTER_GLOBAL("relax.dfb_rewrite_add_binding") + .set_body_typed([](DataflowBlockRewrite rwt, Binding vb) { rwt->Add(vb); }); + +TVM_REGISTER_GLOBAL("relax.dfb_rewrite_add") + .set_body_typed([](DataflowBlockRewrite rwt, Expr expr, Optional<String> name, bool is_dfvar) { + if (name.get()) { + 
rwt->Add(name.value(), expr, is_dfvar); + } else { + rwt->Add(expr, is_dfvar); + } + }); + +class RemoveUnusedVars : public ExprMutator { + public: + std::set<Var> unused_vars; + Optional<DataflowBlock> caught_rewrite = NullOpt; + + RemoveUnusedVars(Map<Var, Array<Var>> users, Array<Var> fn_outputs) + : unused_vars([&] { + std::vector<Var> unused; + + // iterative dataflow algorithm. + size_t prev_size; + do { + prev_size = unused.size(); + + std::vector<Var> used; + used.reserve(users.size()); + for (const auto& kv : users) { + // var -> [users...] + // var is unused iff + // user -> empty + // var is not output var + if (kv.second.empty() && // kv.first is not used by fn outputs. + fn_outputs.end() == std::find(fn_outputs.begin(), fn_outputs.end(), kv.first)) { + unused.push_back(kv.first); + } else { + used.push_back(kv.first); + } + } + + for (size_t i = prev_size; i < unused.size(); ++i) { + users.erase(unused[i]); + // remove def site. + for (const auto& used_var : used) { + ICHECK(users.count(used_var)); + Array<Var> var_users = users[used_var]; + // remove the unused var from the use site. + auto it = std::find(var_users.begin(), var_users.end(), unused[i]); + if (it != var_users.end()) { + var_users.erase(it); + users.Set(used_var, std::move(var_users)); + } + } + } + } while (prev_size != unused.size()); // changed? => continue. 
+ + return std::set<Var>(unused.begin(), unused.end()); + }()) {} + + RemoveUnusedVars(std::pair<Map<Var, Array<Var>>, Array<Var>> users_and_outputs) + : RemoveUnusedVars(std::move(users_and_outputs.first), std::move(users_and_outputs.second)) {} + RemoveUnusedVars(Function fn) : RemoveUnusedVars(FunctionUseDef(fn)) {} + RemoveUnusedVars(std::set<Var> unused_vars) : unused_vars(std::move(unused_vars)) {} + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) { + auto prev_dfb = GetRef<DataflowBlock>(block); + builder_->BeginDataflowBlock(); + for (Binding binding : block->bindings) { + if (!unused_vars.count(binding->var)) { + VisitBinding(binding); + } + } + auto new_dfb = builder_->EndBlock(); + if (caught_rewrite == prev_dfb) caught_rewrite = Downcast<DataflowBlock>(new_dfb); + return std::move(new_dfb); + } +}; + +void DataflowBlockRewriteNode::RemoveUnused(Var unused, bool allow_undef) { + // first need to check if this var is used. + if (0 == to_users_.count(unused)) { // no def. + if (allow_undef) return; + LOG(FATAL) << unused << " undefined. Set allow_undef=True to allow 'removing' undefined var"; + } + + ICHECK(to_users_[unused].empty()) + << unused << " is used by " << to_users_[unused].size() << " vars"; + + auto old_dfb = dfb_.value(); + + RemoveUnusedVars remover({unused}); + dfb_ = Downcast<DataflowBlock>(remover.VisitBindingBlock_(old_dfb.get())); + + auto updater = UpdateDFB(old_dfb, dfb_.value()); + root_fn_ = Downcast<Function>(updater.VisitExpr_(root_fn_.get())); + + to_users_.erase(unused); // update use-def chain. +} + +TVM_REGISTER_GLOBAL("relax.dfb_rewrite_remove_unused") + .set_body_typed([](DataflowBlockRewrite rwt, Var unused, bool allow_undef) { + rwt->RemoveUnused(unused, allow_undef); + }); + +void DataflowBlockRewriteNode::RemoveAllUnused() { + RemoveUnusedVars remover(to_users_, fn_outputs_); + remover.caught_rewrite = dfb_.value(); + + // this could also clean unused variables in other DataflowBlock. 
+ root_fn_ = Downcast<Function>(remover.VisitExpr_(root_fn_.get())); + + // DataflowBlock could be None. + dfb_ = remover.caught_rewrite.value(); + + // clean up use-def chain. + for (const auto& unused : remover.unused_vars) to_users_.erase(unused); +} + +TVM_REGISTER_GLOBAL("relax.dfb_rewrite_remove_all_unused") + .set_body_typed([](DataflowBlockRewrite rwt) { rwt->RemoveAllUnused(); }); + +Function RemoveAllUnused(Function fn) { + RemoveUnusedVars remover(fn); + return Downcast<Function>(remover.VisitExpr_(fn.get())); +} + +TVM_REGISTER_GLOBAL("relax.analysis.remove_all_unused").set_body_typed(RemoveAllUnused); + +IRModule DataflowBlockRewriteNode::MutateIRModule(IRModule irmod) { + BlockBuilder builder = BlockBuilder::Create(irmod); + + for (auto& p : irmod->functions) { + if (original_fn_ptr_ == p.second.get()) { + builder->UpdateFunction(p.first, root_fn_.value()); + break; + } + } + + return builder->GetContextIRModule(); +} + +TVM_REGISTER_GLOBAL("relax.dfb_rewrite_mutate_irmodule") + .set_body_typed([](DataflowBlockRewrite rwt, IRModule irmod) { + return rwt->MutateIRModule(irmod); + }); + +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py index 43558a52be..e939d2b208 100644 --- a/tests/python/relax/test_analysis.py +++ b/tests/python/relax/test_analysis.py @@ -21,7 +21,7 @@ import tvm import tvm.testing from tvm import tir from tvm import relax as rx -from tvm.relax.analysis import has_reshape_pattern, udchain +from tvm.relax.analysis import has_reshape_pattern, udchain, remove_all_unused, name_to_binding from tvm.script import relax as R, tir as T @@ -46,6 +46,122 @@ def test_use_def(): assert set(udc[gv0]) == set() +def test_chained_remove_all_unused(): + @tvm.script.ir_module + class IdentityUnused: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + unused0 = R.call_tir("my_sigmoid", (x,), R.Tensor((32, 32), 
dtype="float32")) + unused1 = R.call_tir("my_sigmoid", (unused0,), R.Tensor((32, 32), dtype="float32")) + R.output(lv0) + return lv0 + + optimized = remove_all_unused(IdentityUnused["main"]) + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + return lv0 + + tvm.ir.assert_structural_equal(optimized, GroundTruth["main"]) + + +def test_binding_block_remove_all_unused(): + @tvm.script.ir_module + class IdentityUnused: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + unused0 = R.call_tir("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32")) + unused1 = R.call_tir("my_sigmoid", (unused0,), R.Tensor((32, 32), dtype="float32")) + R.output(lv0) + z = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32"))) + return z + + optimized = remove_all_unused(IdentityUnused["main"]) + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + z = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32"))) + return z + + tvm.ir.assert_structural_equal(optimized, GroundTruth["main"]) + + +def test_binding_block_fake_unused_remove_all_unused(): + @tvm.script.ir_module + class IdentityUnused: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + z = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32"))) + return lv0 + + optimized = remove_all_unused(IdentityUnused["main"]) + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + # This might bring side effect so cannot be removed. 
+ z = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32"))) + return lv0 + + tvm.ir.assert_structural_equal(optimized, GroundTruth["main"]) + + +def test_edge_binding_block_fake_unused_remove_all_unused(): + @tvm.script.ir_module + class IdentityUnused: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"): + z = R.call_packed("vm.builtin.copy", x, sinfo_args=(R.Tensor((32, 32), "float32"))) + return x + + optimized = remove_all_unused(IdentityUnused["main"]) + tvm.ir.assert_structural_equal(optimized, IdentityUnused["main"]) + + +def test_name_to_binding_var_shadowing(): + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + lv1 = lv0 + R.output(lv1) + + with R.dataflow(): + lv0 = lv1 # shadowing + lv2 = lv0 + R.output(lv2) + return lv2 + + n2binding = name_to_binding(main) + + assert "lv0" in n2binding + assert "lv1" in n2binding + assert "lv2" in n2binding + + assert len(n2binding["lv0"]) == 2 + + def test_reshape_pattern_reshape(): @T.prim_func def reshape( diff --git a/tests/python/relax/test_binding_rewrite.py b/tests/python/relax/test_binding_rewrite.py new file mode 100644 index 0000000000..1b424b9792 --- /dev/null +++ b/tests/python/relax/test_binding_rewrite.py @@ -0,0 +1,334 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import tvm +import tvm.testing +from tvm._ffi.base import TVMError +from tvm.relax.analysis import name_to_binding +from tvm.relax.binding_rewrite import DataflowBlockRewrite +from tvm.relax.expr import DataflowVar, Var +from tvm.script import relax as R + + +@tvm.script.ir_module +class Identity: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + return lv0 + + +def assert_immutability(rwt, original_dfb, original_root_fn): + assert rwt.mutated_dfb() != original_dfb + assert rwt.mutated_root_fn() != original_root_fn + assert rwt.mutated_root_fn().body.blocks[0] != original_dfb + assert rwt.mutated_root_fn().body.blocks[0] == rwt.mutated_dfb() + + +def test_null_construct(): + root_fn = Identity["main"] + dfb = root_fn.body.blocks[0] + + DataflowBlockRewrite(dfb, root_fn) + + +def test_simple_add(): + root_fn = Identity["main"] + dfb = root_fn.body.blocks[0] + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.add(name="tmp", expr=Identity["main"].params[0], is_dfvar=True) + + assert_immutability(rwt, dfb, root_fn) + + # check "tmp" added + assert "tmp" in name_to_binding(rwt.mutated_root_fn()) + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + tmp: R.Tensor((32, 32), "float32") = x + R.output(lv0) + return lv0 + + tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"]) + + +def test_simple_auto_add_var(): + root_fn = Identity["main"] + dfb =
root_fn.body.blocks[0] + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.add(root_fn.params[0], is_dfvar=False) + + assert isinstance(rwt.mutated_dfb().bindings[-1].var, Var) + + assert_immutability(rwt, dfb, root_fn) + + +def test_simple_auto_add_dfvar(): + root_fn = Identity["main"] + dfb = root_fn.body.blocks[0] + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.add(root_fn.params[0], is_dfvar=True) + + assert isinstance(rwt.mutated_dfb().bindings[-1].var, DataflowVar) + + # immutability + assert_immutability(rwt, dfb, root_fn) + + +def test_simple_remove_unused(): + @tvm.script.ir_module + class IdentityUnused: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + unused = lv0 + R.output(lv0) + return lv0 + + root_fn = IdentityUnused["main"] + dfb = root_fn.body.blocks[0] + + n2binding = name_to_binding(IdentityUnused["main"]) + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.remove_unused(n2binding["unused"][0].var) + + assert_immutability(rwt, dfb, root_fn) + + # check "unused" removed + assert "unused" not in name_to_binding(rwt.mutated_root_fn()) + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + return lv0 + + tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"]) + + +def test_remove_unused_undef(): + root_fn = Identity["main"] + dfb = root_fn.body.blocks[0] + + with pytest.raises(TVMError): + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.remove_unused(Var("whatever")) + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.remove_unused(Var("whatever"), allow_undef=True) + + assert root_fn == rwt.mutated_root_fn() + + +def test_simple_rm_all_unused(): + @tvm.script.ir_module + class IdentityUnused: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + unused0 = lv0 + unused1 = lv0 + R.output(lv0) + return lv0 + + root_fn
= IdentityUnused["main"] + dfb = root_fn.body.blocks[0] + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.remove_all_unused() + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + return lv0 + + tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"]) + + +@tvm.script.ir_module +class DeadDFBlock: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"): + with R.dataflow(): + lv0 = x + R.output(lv0) + return x + + +def test_empty_dfb_after_removal(): + root_fn = DeadDFBlock["main"] + dfb = root_fn.body.blocks[0] + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.remove_unused(DeadDFBlock["main"].body.blocks[0].bindings[0].var) + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"): + return x + + tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"]) + + +def test_empty_dfb_after_all_removal(): + dfb = DeadDFBlock["main"].body.blocks[0] + root_fn = DeadDFBlock["main"] + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.remove_all_unused() + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"): + return x + + tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"]) + + +def test_chained_rm_all_unused(): + @tvm.script.ir_module + class IdentityChainedUnused: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + unused0 = R.call_tir("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32")) + unused1 = R.call_tir("my_sigmoid", (unused0,), R.Tensor((32, 32), dtype="float32")) + R.output(lv0) + return lv0 + + root_fn = IdentityChainedUnused["main"] + dfb = root_fn.body.blocks[0] + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.remove_all_unused() + +
@tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + return lv0 + + tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"]) + + +def test_simple_replace_all_uses(): + @tvm.script.ir_module + class Lv0To1: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"): + # lv0 => lv1 + # / \ + # lv2 lv3 + # \ / + # lv4 + with R.dataflow(): + lv0: R.Tensor((32, 32), "float32") = R.call_tir( + "my_relu", (x,), R.Tensor((32, 32), dtype="float32") + ) + lv1: R.Tensor((32, 32), "float32") = R.call_tir( + "my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32") + ) + lv2: R.Tensor((32, 32), "float32") = R.call_tir( + "my_add", (x, lv0), R.Tensor((32, 32), dtype="float32") + ) + lv3: R.Tensor((32, 32), "float32") = R.call_tir( + "my_mul", (x, lv0), R.Tensor((32, 32), dtype="float32") + ) + lv4: R.Tensor((32, 32), "float32") = R.call_tir( + "my_whatever", (lv2, lv3), R.Tensor((32, 32), dtype="float32") + ) + R.output(lv4) + return lv4 + + root_fn = Lv0To1["main"] + dfb = root_fn.body.blocks[0] + + n2binding = name_to_binding(root_fn) + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.replace_all_uses(n2binding["lv0"][0].var, n2binding["lv1"][0].var) + rwt.remove_unused(n2binding["lv0"][0].var) + + assert_immutability(rwt, dfb, root_fn) + + n2binding_after = name_to_binding(rwt.mutated_root_fn()) + assert "lv0" not in n2binding_after + + +def test_simple_module_update(): + @tvm.script.ir_module + class Identity: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + return lv0 + + root_fn = Identity["main"] + dfb = root_fn.body.blocks[0] + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.add(name="tmp", expr=root_fn.params[0], is_dfvar=True) + + new_ir = rwt.mutate_irmodule(Identity) + + # immutability + assert new_ir != Identity + assert 2 == 
len(new_ir["main"].body.blocks[0].bindings) + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + tmp: R.Tensor((32, 32), "float32") = x + R.output(lv0) + return lv0 + + tvm.ir.assert_structural_equal(new_ir, GroundTruth) + + +if __name__ == "__main__": + tvm.testing.main()
