This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new befd95662c [Unity] Statement rewriter for DataflowBlock (#14043)
befd95662c is described below
commit befd95662c94eba3ee8cce5f6b75238ecb1a0bc4
Author: Jiawei Liu <[email protected]>
AuthorDate: Sat Feb 18 18:13:57 2023 -0600
[Unity] Statement rewriter for DataflowBlock (#14043)
---
include/tvm/relax/analysis.h | 24 +++
include/tvm/relax/binding_rewrite.h | 115 ++++++++++
include/tvm/relax/utils.h | 1 +
python/tvm/relax/analysis/analysis.py | 23 +-
python/tvm/relax/binding_rewrite.py | 155 +++++++++++++
src/relax/ir/binding_rewrite.cc | 324 ++++++++++++++++++++++++++++
tests/python/relax/test_analysis.py | 118 +++++++++-
tests/python/relax/test_binding_rewrite.py | 334 +++++++++++++++++++++++++++++
8 files changed, 1092 insertions(+), 2 deletions(-)
diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h
index 32e1582134..b9866577e9 100644
--- a/include/tvm/relax/analysis.h
+++ b/include/tvm/relax/analysis.h
@@ -341,6 +341,14 @@ TVM_DLL Map<Var, Expr> AnalyzeVar2Value(const Expr& expr);
*/
TVM_DLL Map<Var, Expr> AnalyzeVar2Value(const DataflowBlock& dfb);
+/*!
+ * \brief Return a mapping from variable name to its Bindings.
+ *
+ * \param fn The function to be analyzed.
+ * \return A mapping from variable name to its Bindings.
+ */
+TVM_DLL Map<String, Array<Binding>> NameToBinding(const Function& fn);
+
/*!
* \brief Get the use-def chain of variables inside a dataflow block.
*
@@ -349,6 +357,22 @@ TVM_DLL Map<Var, Expr> AnalyzeVar2Value(const DataflowBlock& dfb);
*/
TVM_DLL Map<Var, Array<Var>> DataflowBlockUseDef(const DataflowBlock& dfb);
+/*!
+ * \brief Get the use-def chain of variables inside a function.
+ *
+ * \param fn The function to be analyzed.
+ * \return A pair of (1) a map from each variable definition to the variables that use it and
+ *         (2) the variables needed by the return value.
+ */
+std::pair<Map<Var, Array<Var>>, Array<Var>> FunctionUseDef(const Function& fn);
+
+/*!
+ * \brief Remove unused statements inside DataflowBlocks.
+ *
+ * \param fn The function to remove unused statements.
+ * \return The function that contains no unused statements in DataflowBlock.
+ */
+TVM_DLL Function RemoveAllUnused(const Function fn);
+
/*!
* \brief Annotate Op Pattern Kind for PrimFunc, which is used in relax FuseOps.
*
diff --git a/include/tvm/relax/binding_rewrite.h b/include/tvm/relax/binding_rewrite.h
new file mode 100644
index 0000000000..a4b534965a
--- /dev/null
+++ b/include/tvm/relax/binding_rewrite.h
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/binding_rewrite.h
+ * \brief An IR rewriter to easily add/remove/replace bindings (statements).
+ */
+
+#ifndef TVM_RELAX_BINDING_REWRITE_H_
+// Define the guard macro right after the check so that any re-inclusion that happens while this
+// header is still being processed is skipped as well.
+#define TVM_RELAX_BINDING_REWRITE_H_
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/utils.h>
+
+#include <map>
+#include <set>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief Statement rewriter for relax.DataflowBlock.
+ *
+ * The node stores the dataflow block being edited, the function it was registered with, a table
+ * from each variable to its users, and the variables required by the function outputs. The
+ * editing methods (ReplaceAllUses / Add / RemoveUnused / RemoveAllUnused) reassign the stored
+ * block and function; the results are read back through MutatedDataflowBlock() and MutatedFunc().
+ */
+class DataflowBlockRewriteNode : public Object {
+ public:
+  /*! \brief Replace all uses of old_var with new_var. */
+  void ReplaceAllUses(Var old_var, Var new_var);
+  /*! \brief Insert a Binding statement. */
+  void Add(Binding binding);
+  /*! \brief Insert an expression as VarBinding with variable name. */
+  void Add(String var_name, Expr expr, bool is_dfvar = false) {
+    // is_dfvar selects between a DataflowVar and a plain Var; both take the struct info of expr.
+    auto var = is_dfvar ? DataflowVar(var_name, GetStructInfo(expr))  //
+                        : Var(var_name, GetStructInfo(expr));
+    Add(VarBinding(std::move(var), std::move(expr)));
+  }
+  /*! \brief Insert an expression as VarBinding with automatic variable name. */
+  void Add(Expr expr, bool is_dfvar = false) {
+    Add(name_table_.GetUniqueName("tmp"), expr, is_dfvar);
+  }
+  /*! \brief Remove the definition statement of an unused variable. */
+  void RemoveUnused(Var unused, bool allow_undef = false);
+  /*! \brief Remove the definition statements of all unused variables. */
+  void RemoveAllUnused();
+
+  /*! \brief The rewritten dataflow block. */
+  DataflowBlock MutatedDataflowBlock() { return dfb_.value(); }
+  /*! \brief The rewritten function. */
+  Function MutatedFunc() { return root_fn_.value(); }
+  /*! \brief The rewritten IRModule. */
+  IRModule MutateIRModule(IRModule irmod);
+
+  /*! \brief Visit attributes. */
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("dfb", &dfb_);
+    v->Visit("root_fn", &root_fn_);
+  }
+
+  static constexpr const char* _type_key = "relax.DataflowBlockRewrite";
+  TVM_DECLARE_FINAL_OBJECT_INFO(DataflowBlockRewriteNode, Object);
+
+ protected:
+  friend class DataflowBlockRewrite;
+
+  Optional<DataflowBlock> dfb_;  //!< The rewritten dataflow block.
+  Optional<Function> root_fn_;   //!< The rewritten function.
+  // Initialised to nullptr so a node that did not go through the DataflowBlockRewrite
+  // constructor never carries an indeterminate pointer.
+  const FunctionNode* original_fn_ptr_{nullptr};  //!< Pointer to the original function.
+  Map<Var, Array<Var>> to_users_;                 //!< Map from variable to its users.
+  Array<Var> fn_outputs_;                         //!< Variables required by function outputs.
+
+ private:
+  NameTable name_table_;  //!< Name table for tracking and generating unique names.
+};
+
+/*!
+ * \brief A statement rewriter for relax.DataflowBlock.
+ *
+ * Reference type for DataflowBlockRewriteNode. It is constructed from the dataflow block to
+ * rewrite together with a function (root_fn); the non-const operator-> below gives mutable
+ * access to the node so that its rewriting methods can be called through the reference.
+ * \sa DataflowBlockRewriteNode
+ */
+class DataflowBlockRewrite : public ObjectRef {
+ public:
+  TVM_DLL explicit DataflowBlockRewrite(DataflowBlock dfb, Function root_fn);
+
+  /*!
+   * \brief mutable accessor.
+   * \return mutable access pointer. The underlying node is checked to be non-null first.
+   */
+  DataflowBlockRewriteNode* operator->() {
+    ICHECK(get() != nullptr);
+    return static_cast<DataflowBlockRewriteNode*>(get_mutable());
+  }
+
+  TVM_DEFINE_OBJECT_REF_METHODS(DataflowBlockRewrite, ObjectRef, DataflowBlockRewriteNode);
+};
+
+} // namespace relax
+} // namespace tvm
+
+#define TVM_RELAX_BINDING_REWRITE_H_
+#endif // TVM_RELAX_BINDING_REWRITE_H_
diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h
index 1457a16427..c1d984a21a 100644
--- a/include/tvm/relax/utils.h
+++ b/include/tvm/relax/utils.h
@@ -25,6 +25,7 @@
#define TVM_RELAX_UTILS_H_
#include <tvm/ir/module.h>
+#include <tvm/relax/expr.h>
#include <tvm/runtime/logging.h>
#include <algorithm>
diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py
index 45c5b6f962..ffcdaceb40 100644
--- a/python/tvm/relax/analysis/analysis.py
+++ b/python/tvm/relax/analysis/analysis.py
@@ -28,7 +28,7 @@ from tvm import tir
from tvm import IRModule
from tvm.relax.ty import Type
from tvm.relax.struct_info import StructInfo, FuncStructInfo
-from tvm.relax.expr import DataflowBlock, Var, Expr, Function, Call
+from tvm.relax.expr import DataflowBlock, Var, Expr, Function, Call, Binding
from . import _ffi_api
@@ -244,6 +244,27 @@ def udchain(dfb: DataflowBlock) -> Dict[Var, List[Var]]:
return _ffi_api.udchain(dfb) # type: ignore
+def name_to_binding(func: Function) -> Dict[str, List[Binding]]:
+ """Return a map from variable name to its bindings."""
+ return _ffi_api.name_to_binding(func) # type: ignore
+
+
+def remove_all_unused(func: Function) -> Function:
+ """Remove all unused variables from the function.
+
+ Parameters
+ ----------
+ func : Function
+ The input function to be analyzed.
+
+ Returns
+ -------
+ Function
+ The function with unused variables removed.
+ """
+ return _ffi_api.remove_all_unused(func) # type: ignore
+
+
def well_formed(mod: IRModule, check_struct_info: bool = True) -> bool:
"""Check if the IRModule is well formed.
diff --git a/python/tvm/relax/binding_rewrite.py b/python/tvm/relax/binding_rewrite.py
new file mode 100644
index 0000000000..a9f6d878ad
--- /dev/null
+++ b/python/tvm/relax/binding_rewrite.py
@@ -0,0 +1,155 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=no-else-return, invalid-name
+"""Developer API of add/remove/replace bindings in Relax."""
+
+from typing import Optional
+
+import tvm
+import tvm._ffi
+from tvm.runtime import Object
+from . import Binding, DataflowBlock, Expr, Function, Var
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("relax.DataflowBlockRewrite")
+class DataflowBlockRewrite(Object):
+ """
+ A binding/statement-level dataflow block rewriter.
+
+ Notes
+ -----
+ Due to the immutable and copy-on-write nature of TVM AST nodes, the
rewriting is not done in
+ place. Instead, a new DataflowBlock is created and returned with
mutated_dfb. Similarly, its new
+ root Function is created and returned by mutated_root_fn. To apply this
change for an IRModule,
+ use mutate_irmodule which rewrites the old function that registered in the
constructor.
+ """
+
+ def __init__(self, dfb: DataflowBlock, root_fn: Function):
+ """
+ Construct a rewriter with the DataflowBlock to rewrite and its root
function.
+
+ Parameters
+ ----------
+ dfb : DataflowBlock
+ The DataflowBlock to rewrite.
+ root_fn : Function
+ The root function of the DataflowBlock.
+ """
+ self.func_name = root_fn.__name__ if hasattr(root_fn, "__name__") else
None
+ self.__init_handle_by_constructor__(
+ _ffi_api.DataflowBlockRewrite, dfb, root_fn # type: ignore
+ )
+
+ def replace_all_uses(self, old_var: Var, new_var: Var) -> None:
+ """
+ Replace all uses of old_var with new_var.
+
+ Parameters
+ ----------
+ old_var : Var
+ The old variable to replace.
+ new_var : Var
+ The new variable to replace with.
+ """
+ _ffi_api.dfb_rewrite_replace_all_uses(self, old_var, new_var) # type:
ignore
+
+ def add_binding(self, binding: Binding) -> None:
+ return _ffi_api.dfb_rewrite_add_binding(self, binding) # type: ignore
+
+ def add(self, expr: Expr, name: Optional[str] = None, is_dfvar: bool =
False) -> None:
+ """
+ Add a new statement to the DataflowBlock with an automatically
generated variable name.
+
+ Parameters
+ ----------
+ expr : Expr
+ The expression to add.
+ name : Optional[str], optional
+ Variable name, by default None
+ is_dfvar : bool, optional
+ The variable type, by default False
+
+ Notes
+ -----
+ If the variable name is not given, it will be automatically generated
in a form of
+ "tmp${COUNTER}". The variable type will be DataflowVar if is_dfvar is
True, otherwise
+ it will be Var. Being Var means the variables are output variables of
the DataflowBlock.
+ While being DataflowVar means the variables are internal variables of
the DataflowBlock.
+ """
+ _ffi_api.dfb_rewrite_add(self, expr, name, is_dfvar) # type: ignore
+
+ def remove_unused(self, var: Var, allow_undef=False) -> None:
+ """
+ Remove a statement by its variable definition if and only if it is
unused.
+
+ Parameters
+ ----------
+ var : Var
+ The unused variable definition.
+ allow_undef : bool, optional
+ Whether to allow var being undefined variable, by default False
+
+ Raises
+ ------
+ TVMError if the variable is used or undefined (allow_undef=False).
+ """
+ _ffi_api.dfb_rewrite_remove_unused(self, var, allow_undef) # type:
ignore
+
+ def remove_all_unused(self) -> None:
+ """
+ Remove all unused variables.
+
+ Notes
+ -----
+ This could remove unused variables in other DataflowBlocks as well.
+ """
+ _ffi_api.dfb_rewrite_remove_all_unused(self) # type: ignore
+
+ def mutated_dfb(self) -> DataflowBlock:
+ """
+ Returns the mutated DataflowBlock.
+ """
+ return self.dfb
+
+ def mutated_root_fn(self) -> Function:
+ """
+ Returns the mutated root function.
+ """
+ ret = self.root_fn
+ if self.func_name:
+ ret.__name__ = self.func_name
+ return ret
+
+ def mutate_irmodule(self, irmodule: tvm.IRModule) -> tvm.IRModule:
+ """
+ Return an updated IRModule by replacing the old function with the
mutated root function.
+
+ Parameters
+ ----------
+ irmodule : tvm.IRModule
+ The base IRModule to update.
+
+ Returns
+ -------
+ tvm.IRModule
+ The updated IRModule.
+ """
+ ret = _ffi_api.dfb_rewrite_mutate_irmodule(self, irmodule) # type:
ignore
+ if hasattr(irmodule, "__name__"):
+ ret.__name__ = irmodule.__name__
+ return ret
diff --git a/src/relax/ir/binding_rewrite.cc b/src/relax/ir/binding_rewrite.cc
new file mode 100644
index 0000000000..dd9fac9fdc
--- /dev/null
+++ b/src/relax/ir/binding_rewrite.cc
@@ -0,0 +1,324 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/ir/binding_rewrite.cc
+ * \brief Implementation of binding rewriters.
+ */
+
+#include <tvm/relax/binding_rewrite.h>
+#include <tvm/relax/block_builder.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+
+#include <functional>
+#include <iterator>
+
+namespace tvm {
+namespace relax {
+
+TVM_REGISTER_NODE_TYPE(DataflowBlockRewriteNode);
+
+// Build the rewriter node: remember the block and function, record the address of the function
+// as the "original" one, and derive the use-def table, the function outputs and the name table
+// (seeded with the name hints of all variables in the use-def table) from FunctionUseDef.
+DataflowBlockRewrite::DataflowBlockRewrite(DataflowBlock dfb, Function root_fn) {
+  auto node = make_object<DataflowBlockRewriteNode>();
+  node->dfb_ = dfb;
+  node->root_fn_ = root_fn;
+  node->original_fn_ptr_ = root_fn.get();
+
+  auto use_def = FunctionUseDef(root_fn);
+  node->to_users_ = std::move(use_def.first);
+  node->fn_outputs_ = std::move(use_def.second);
+  node->name_table_ = NameTable(node->to_users_.begin(), node->to_users_.end(),
+                                [](const auto& kv) { return kv.first->name_hint(); });
+
+  data_ = std::move(node);
+}
+
+TVM_REGISTER_GLOBAL("relax.DataflowBlockRewrite")
+    .set_body_typed([](DataflowBlock dfb, Function root_fn) {
+      return DataflowBlockRewrite(dfb, root_fn);
+    });
+
+// Replace every use of old_var in the root function by new_var, then bring the rewritten
+// dataflow block, the use-def table and the function outputs in line with the replacement.
+// Both variables must already be present in the use-def table.
+void DataflowBlockRewriteNode::ReplaceAllUses(Var old_var, Var new_var) {
+  // Mutator that substitutes the variable and records the rewritten version of `to_catch`.
+  class ReplaceAllUsePass : public ExprMutator {
+    Var old_var, new_var;
+    const DataflowBlockNode* const to_catch;
+
+   public:
+    const DataflowBlockNode* caught = nullptr;
+
+    ReplaceAllUsePass(Var old_var, Var new_var, const DataflowBlockNode* to_catch)
+        : old_var(old_var), new_var(new_var), to_catch(to_catch) {}
+
+    using ExprMutator::VisitExpr_;
+
+    Expr VisitExpr_(const VarNode* op) override {
+      return (op == old_var.get()) ? new_var : GetRef<Expr>(op);
+    }
+
+    Expr VisitExpr_(const DataflowVarNode* op) override {
+      return (op == old_var.get()) ? new_var : GetRef<Expr>(op);
+    }
+
+    BindingBlock VisitBindingBlock_(const DataflowBlockNode* op) override {
+      BindingBlock res = ExprMutator::VisitBindingBlock_(op);
+      if (op == to_catch) caught = static_cast<const DataflowBlockNode*>(res.get());
+      return res;
+    }
+  };
+
+  ICHECK(to_users_.find(old_var) != to_users_.end()) << "Cannot find " << old_var;
+  ICHECK(to_users_.find(new_var) != to_users_.end()) << "Cannot find " << new_var;
+
+  // replace uses inside the DataflowBlock.
+  ReplaceAllUsePass replacer(old_var, new_var, dfb_.get());
+  root_fn_ = Downcast<Function>(replacer.VisitExpr_(root_fn_.get()));
+  dfb_ = GetRef<DataflowBlock>(replacer.caught);
+
+  // update udchain
+  //  old_var -> old_var users | changed to {}
+  //  new_var -> {?}           | changed to {?} + old_var users
+  // Array is copy-on-write, so the merged list is built in a local and then stored back with
+  // Set(); pushing into a copy fetched from the map would leave the map untouched.
+  Array<Var> old_var_users = to_users_[old_var];
+  Array<Var> new_var_users = to_users_[new_var];
+  for (const Var& user : old_var_users) {
+    if (new_var_users.end() == std::find(new_var_users.begin(), new_var_users.end(), user)) {
+      new_var_users.push_back(user);
+    }
+  }
+
+  // Clear old_var first so that the merged list survives when old_var and new_var coincide.
+  to_users_.Set(old_var, {});
+  to_users_.Set(new_var, new_var_users);
+
+  // If old_var was required by the function outputs, new_var takes its place.
+  auto it_old_output = std::find(fn_outputs_.begin(), fn_outputs_.end(), old_var);
+  if (it_old_output != fn_outputs_.end()) {
+    fn_outputs_.Set(std::distance(fn_outputs_.begin(), it_old_output), new_var);
+  }
+}
+
+TVM_REGISTER_GLOBAL("relax.dfb_rewrite_replace_all_uses")
+    .set_body_typed([](DataflowBlockRewrite rwt, Var old_var, Var new_var) {
+      rwt->ReplaceAllUses(old_var, new_var);
+    });
+
+// Mutator that swaps one specific DataflowBlock (old_dfb) for its rewritten version (new_dfb)
+// and leaves every other dataflow block it visits unchanged.
+class UpdateDFB : public ExprMutator {
+ private:
+  DataflowBlock old_dfb, new_dfb;
+
+ public:
+  UpdateDFB(DataflowBlock old_dfb, DataflowBlock new_dfb)
+      : old_dfb(std::move(old_dfb)), new_dfb(std::move(new_dfb)) {}
+
+  BindingBlock VisitBindingBlock_(const DataflowBlockNode* op) override {
+    if (old_dfb.get() == op) return new_dfb;
+    // Not the targeted block: hand back the visited block itself. Returning old_dfb here would
+    // overwrite every other dataflow block of the function with a copy of the old block.
+    return GetRef<DataflowBlock>(op);
+  }
+};
+
+// Insert `binding` into the dataflow block after the last statement of the block that defines
+// a variable used by the bound value, then update the root function and the use-def table.
+// Only VarBinding and MatchCast are supported, and the bound variable must not be defined yet.
+void DataflowBlockRewriteNode::Add(Binding binding) {
+  auto p = [binding] {
+    if (auto vb = binding.as<VarBindingNode>()) {
+      return std::make_pair(vb->var, vb->value);
+    } else if (auto mc = binding.as<MatchCastNode>()) {
+      return std::make_pair(mc->var, mc->value);
+    }
+    LOG(FATAL) << "Unsupported binding type";
+    return std::make_pair(Var{}, Expr{});
+  }();
+  Var var = p.first;
+  Expr val = p.second;
+
+  ICHECK(0 == to_users_.count(var)) << var << " has been defined so cannot be added.";
+
+  // Add this VarBinding statement after the definition of uses.
+  std::set<const VarNode*> used_vars = [val] {
+    class UsedVars : public ExprVisitor {
+     public:
+      std::set<const VarNode*> used_vars;
+      void VisitExpr_(const VarNode* op) override { used_vars.insert(op); }
+      void VisitExpr_(const DataflowVarNode* op) override { used_vars.insert(op); }
+    } uvar{};
+    uvar.VisitExpr(val);
+    return std::move(uvar.used_vars);
+  }();
+
+  // Index of the last statement in the block that defines one of the used variables. It stays 0
+  // when none of them is defined in this block, so the new binding then goes after the first
+  // statement.
+  size_t line_last_req_def = 0;
+  for (size_t i = 0; i < dfb_.value()->bindings.size(); ++i) {
+    auto line = dfb_.value()->bindings[i];
+    if (used_vars.find(line->var.get()) != used_vars.cend()) line_last_req_def = i;
+  }
+
+  auto old_dfb = dfb_.value();
+
+  // Clamp the position so that inserting into an empty block does not step past the end.
+  const size_t insert_at = std::min<size_t>(1 + line_last_req_def, old_dfb->bindings.size());
+  DataflowBlock new_dfb = old_dfb;
+  DataflowBlockNode* new_dfb_node = new_dfb.CopyOnWrite();
+  new_dfb_node->bindings.insert(new_dfb_node->bindings.begin() + insert_at, binding);
+  dfb_ = new_dfb;
+
+  auto updater = UpdateDFB(old_dfb, dfb_.value());
+  root_fn_ = Downcast<Function>(updater.VisitExpr_(root_fn_.get()));
+
+  // Record the new variable as a user of everything its value reads. Array is copy-on-write,
+  // so the extended list has to be stored back with Set(); pushing into the temporary returned
+  // by Get() would be lost.
+  for (const VarNode* v : used_vars) {
+    Var used = GetRef<Var>(v);
+    Array<Var> users = to_users_.Get(used).value();
+    users.push_back(var);
+    to_users_.Set(used, users);
+  }
+  // Register the new definition (no users yet) so that it is known to the "already defined"
+  // check above and to the Remove* methods.
+  to_users_.Set(var, Array<Var>());
+}
+
+TVM_REGISTER_GLOBAL("relax.dfb_rewrite_add_binding")
+    .set_body_typed([](DataflowBlockRewrite rwt, Binding vb) { rwt->Add(vb); });
+
+TVM_REGISTER_GLOBAL("relax.dfb_rewrite_add")
+    .set_body_typed([](DataflowBlockRewrite rwt, Expr expr, Optional<String> name, bool is_dfvar) {
+      if (name.get()) {
+        rwt->Add(name.value(), expr, is_dfvar);
+      } else {
+        rwt->Add(expr, is_dfvar);
+      }
+    });
+
+// Mutator that drops the bindings of unused variables from the DataflowBlocks it visits.
+// The set of unused variables is either given directly or computed from a use-def table and the
+// list of variables required by the function outputs. When caught_rewrite is seeded with a
+// block, it is replaced by the rewritten version of that block once the block has been visited.
+class RemoveUnusedVars : public ExprMutator {
+ public:
+  std::set<Var> unused_vars;
+  Optional<DataflowBlock> caught_rewrite = NullOpt;
+
+  RemoveUnusedVars(Map<Var, Array<Var>> users, Array<Var> fn_outputs)
+      : unused_vars([&] {
+          std::vector<Var> unused;
+
+          // iterative dataflow algorithm: removing one unused variable can leave the variables
+          // it read without users, so repeat until no new unused variable shows up.
+          size_t prev_size;
+          do {
+            prev_size = unused.size();
+
+            std::vector<Var> used;
+            used.reserve(users.size());
+            for (const auto& kv : users) {
+              // var -> [users...]
+              // var is unused iff
+              //   user -> empty
+              //   var is not output var
+              if (kv.second.empty() &&  // kv.first is not used by fn outputs.
+                  fn_outputs.end() == std::find(fn_outputs.begin(), fn_outputs.end(), kv.first)) {
+                unused.push_back(kv.first);
+              } else {
+                used.push_back(kv.first);
+              }
+            }
+
+            // Only the variables found in this round (indices >= prev_size) are processed.
+            for (size_t i = prev_size; i < unused.size(); ++i) {
+              users.erase(unused[i]);
+              // remove def site.
+              for (const auto& used_var : used) {
+                ICHECK(users.count(used_var));
+                Array<Var> var_users = users[used_var];
+                // remove the unused var from the use site.
+                auto it = std::find(var_users.begin(), var_users.end(), unused[i]);
+                if (it != var_users.end()) {
+                  var_users.erase(it);
+                  users.Set(used_var, std::move(var_users));
+                }
+              }
+            }
+          } while (prev_size != unused.size());  // changed? => continue.
+
+          return std::set<Var>(unused.begin(), unused.end());
+        }()) {}
+
+  RemoveUnusedVars(std::pair<Map<Var, Array<Var>>, Array<Var>> users_and_outputs)
+      : RemoveUnusedVars(std::move(users_and_outputs.first),
+                         std::move(users_and_outputs.second)) {}
+  RemoveUnusedVars(Function fn) : RemoveUnusedVars(FunctionUseDef(fn)) {}
+  RemoveUnusedVars(std::set<Var> unused_vars) : unused_vars(std::move(unused_vars)) {}
+
+  // Rebuild the dataflow block from the bindings whose variable is not in unused_vars.
+  // Marked override so a signature mismatch with ExprMutator is caught at compile time.
+  BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) override {
+    auto prev_dfb = GetRef<DataflowBlock>(block);
+    builder_->BeginDataflowBlock();
+    for (Binding binding : block->bindings) {
+      if (!unused_vars.count(binding->var)) {
+        VisitBinding(binding);
+      }
+    }
+    auto new_dfb = builder_->EndBlock();
+    if (caught_rewrite == prev_dfb) caught_rewrite = Downcast<DataflowBlock>(new_dfb);
+    // Plain return of the local: std::move here would only inhibit copy elision.
+    return new_dfb;
+  }
+};
+
+// Remove the binding that defines `unused` from the dataflow block.
+// Fails if the variable is unknown (unless allow_undef is set), still has users, or is required
+// by the function outputs.
+void DataflowBlockRewriteNode::RemoveUnused(Var unused, bool allow_undef) {
+  // first need to check if this var is used.
+  if (0 == to_users_.count(unused)) {  // no def.
+    if (allow_undef) return;
+    LOG(FATAL) << unused << " undefined. Set allow_undef=True to allow 'removing' undefined var";
+  }
+
+  ICHECK(to_users_[unused].empty())
+      << unused << " is used by " << to_users_[unused].size() << " vars";
+  // A variable needed by the function outputs counts as used even when no binding reads it
+  // (RemoveUnusedVars applies the same rule when it computes the unused set).
+  ICHECK(fn_outputs_.end() == std::find(fn_outputs_.begin(), fn_outputs_.end(), unused))
+      << unused << " is required by the function outputs";
+
+  auto old_dfb = dfb_.value();
+
+  RemoveUnusedVars remover({unused});
+  dfb_ = Downcast<DataflowBlock>(remover.VisitBindingBlock_(old_dfb.get()));
+
+  auto updater = UpdateDFB(old_dfb, dfb_.value());
+  root_fn_ = Downcast<Function>(updater.VisitExpr_(root_fn_.get()));
+
+  to_users_.erase(unused);  // update use-def chain.
+
+  // The removed variable no longer uses anything: drop it from the user lists of the variables
+  // it read, so they can become removable in turn. Updates are collected first because the map
+  // must not be modified while it is being iterated.
+  std::vector<std::pair<Var, Array<Var>>> updated;
+  for (const auto& kv : to_users_) {
+    Array<Var> users = kv.second;
+    auto it = std::find(users.begin(), users.end(), unused);
+    if (it != users.end()) {
+      users.erase(it);
+      updated.emplace_back(kv.first, std::move(users));
+    }
+  }
+  for (const auto& kv : updated) to_users_.Set(kv.first, kv.second);
+}
+
+TVM_REGISTER_GLOBAL("relax.dfb_rewrite_remove_unused")
+    .set_body_typed([](DataflowBlockRewrite rwt, Var unused, bool allow_undef) {
+      rwt->RemoveUnused(unused, allow_undef);
+    });
+
+// Remove the bindings of all unused variables from the root function, refresh the stored
+// dataflow block, and clean the use-def table accordingly.
+void DataflowBlockRewriteNode::RemoveAllUnused() {
+  RemoveUnusedVars remover(to_users_, fn_outputs_);
+  remover.caught_rewrite = dfb_.value();
+
+  // this could also clean unused variables in other DataflowBlock.
+  root_fn_ = Downcast<Function>(remover.VisitExpr_(root_fn_.get()));
+
+  // caught_rewrite was seeded with the current block above, so it always holds a value here:
+  // the rewritten block if the block was visited, the current block otherwise.
+  dfb_ = remover.caught_rewrite.value();
+
+  // clean up use-def chain.
+  for (const auto& unused : remover.unused_vars) to_users_.erase(unused);
+
+  // The remover works on its own copy of the table, so the removed variables also have to be
+  // dropped from the user lists that remain here. Updates are collected first because the map
+  // must not be modified while it is being iterated.
+  std::vector<std::pair<Var, Array<Var>>> updated;
+  for (const auto& kv : to_users_) {
+    Array<Var> live_users;
+    for (const Var& user : kv.second) {
+      if (!remover.unused_vars.count(user)) live_users.push_back(user);
+    }
+    if (live_users.size() != kv.second.size()) updated.emplace_back(kv.first, live_users);
+  }
+  for (const auto& kv : updated) to_users_.Set(kv.first, kv.second);
+}
+
+TVM_REGISTER_GLOBAL("relax.dfb_rewrite_remove_all_unused")
+    .set_body_typed([](DataflowBlockRewrite rwt) { rwt->RemoveAllUnused(); });
+
+// Return a copy of `fn` in which the bindings of unused variables have been dropped from its
+// DataflowBlocks. The unused set is derived from the use-def information of `fn` itself.
+Function RemoveAllUnused(Function fn) {
+  RemoveUnusedVars eliminator(fn);
+  Expr rewritten = eliminator.VisitExpr_(fn.get());
+  return Downcast<Function>(std::move(rewritten));
+}
+
+TVM_REGISTER_GLOBAL("relax.analysis.remove_all_unused").set_body_typed(RemoveAllUnused);
+
+// Look up the function that this rewriter was constructed with (matched by node address) in
+// `irmod` and update it to the rewritten root function through a BlockBuilder created on the
+// module. At most one function is updated; the builder's context module is returned.
+IRModule DataflowBlockRewriteNode::MutateIRModule(IRModule irmod) {
+  BlockBuilder builder = BlockBuilder::Create(irmod);
+
+  for (const auto& kv : irmod->functions) {
+    if (kv.second.get() != original_fn_ptr_) continue;
+    builder->UpdateFunction(kv.first, root_fn_.value());
+    break;
+  }
+
+  return builder->GetContextIRModule();
+}
+
+TVM_REGISTER_GLOBAL("relax.dfb_rewrite_mutate_irmodule")
+    .set_body_typed([](DataflowBlockRewrite rwt, IRModule irmod) {
+      return rwt->MutateIRModule(irmod);
+    });
+
+} // namespace relax
+} // namespace tvm
diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py
index 43558a52be..e939d2b208 100644
--- a/tests/python/relax/test_analysis.py
+++ b/tests/python/relax/test_analysis.py
@@ -21,7 +21,7 @@ import tvm
import tvm.testing
from tvm import tir
from tvm import relax as rx
-from tvm.relax.analysis import has_reshape_pattern, udchain
+from tvm.relax.analysis import has_reshape_pattern, udchain,
remove_all_unused, name_to_binding
from tvm.script import relax as R, tir as T
@@ -46,6 +46,122 @@ def test_use_def():
assert set(udc[gv0]) == set()
+def test_chained_remove_all_unused():
+ @tvm.script.ir_module
+ class IdentityUnused:
+ @R.function
+ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+ with R.dataflow():
+ lv0 = x
+ unused0 = R.call_tir("my_sigmoid", (x,), R.Tensor((32, 32),
dtype="float32"))
+ unused1 = R.call_tir("my_sigmoid", (unused0,), R.Tensor((32,
32), dtype="float32"))
+ R.output(lv0)
+ return lv0
+
+ optimized = remove_all_unused(IdentityUnused["main"])
+
+ @tvm.script.ir_module
+ class GroundTruth:
+ @R.function
+ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+ with R.dataflow():
+ lv0 = x
+ R.output(lv0)
+ return lv0
+
+ tvm.ir.assert_structural_equal(optimized, GroundTruth["main"])
+
+
+def test_binding_block_remove_all_unused():
+ @tvm.script.ir_module
+ class IdentityUnused:
+ @R.function
+ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+ with R.dataflow():
+ lv0 = x
+ unused0 = R.call_tir("my_sigmoid", (x,), R.Tensor((32, 32),
dtype="float32"))
+ unused1 = R.call_tir("my_sigmoid", (unused0,), R.Tensor((32,
32), dtype="float32"))
+ R.output(lv0)
+ z = R.call_packed("vm.builtin.copy", lv0,
sinfo_args=(R.Tensor((32, 32), "float32")))
+ return z
+
+ optimized = remove_all_unused(IdentityUnused["main"])
+
+ @tvm.script.ir_module
+ class GroundTruth:
+ @R.function
+ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+ with R.dataflow():
+ lv0 = x
+ R.output(lv0)
+ z = R.call_packed("vm.builtin.copy", lv0,
sinfo_args=(R.Tensor((32, 32), "float32")))
+ return z
+
+ tvm.ir.assert_structural_equal(optimized, GroundTruth["main"])
+
+
+def test_binding_block_fake_unused_remove_all_unused():
+ @tvm.script.ir_module
+ class IdentityUnused:
+ @R.function
+ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+ with R.dataflow():
+ lv0 = x
+ R.output(lv0)
+ z = R.call_packed("vm.builtin.copy", lv0,
sinfo_args=(R.Tensor((32, 32), "float32")))
+ return lv0
+
+ optimized = remove_all_unused(IdentityUnused["main"])
+
+ @tvm.script.ir_module
+ class GroundTruth:
+ @R.function
+ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+ with R.dataflow():
+ lv0 = x
+ R.output(lv0)
+ # This might bring side effect so cannot be removed.
+ z = R.call_packed("vm.builtin.copy", lv0,
sinfo_args=(R.Tensor((32, 32), "float32")))
+ return lv0
+
+ tvm.ir.assert_structural_equal(optimized, GroundTruth["main"])
+
+
+def test_edge_binding_block_fake_unused_remove_all_unused():
+ @tvm.script.ir_module
+ class IdentityUnused:
+ @R.function
+ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32),
"float32"):
+ z = R.call_packed("vm.builtin.copy", x, sinfo_args=(R.Tensor((32,
32), "float32")))
+ return x
+
+ optimized = remove_all_unused(IdentityUnused["main"])
+ tvm.ir.assert_structural_equal(optimized, IdentityUnused["main"])
+
+
+def test_name_to_binding_var_shadowing():
+ @R.function
+ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+ with R.dataflow():
+ lv0 = x
+ lv1 = lv0
+ R.output(lv1)
+
+ with R.dataflow():
+ lv0 = lv1 # shadowing
+ lv2 = lv0
+ R.output(lv2)
+ return lv2
+
+ n2binding = name_to_binding(main)
+
+ assert "lv0" in n2binding
+ assert "lv1" in n2binding
+ assert "lv2" in n2binding
+
+ assert len(n2binding["lv0"]) == 2
+
+
def test_reshape_pattern_reshape():
@T.prim_func
def reshape(
diff --git a/tests/python/relax/test_binding_rewrite.py b/tests/python/relax/test_binding_rewrite.py
new file mode 100644
index 0000000000..1b424b9792
--- /dev/null
+++ b/tests/python/relax/test_binding_rewrite.py
@@ -0,0 +1,334 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+import tvm
+import tvm.testing
+from tvm._ffi.base import TVMError
+from tvm.relax.analysis import name_to_binding
+from tvm.relax.binding_rewrite import DataflowBlockRewrite
+from tvm.relax.expr import DataflowVar, Var
+from tvm.script import relax as R
+
+
+@tvm.script.ir_module
+class Identity:
+ @R.function
+ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+ with R.dataflow():
+ lv0 = x
+ R.output(lv0)
+ return lv0
+
+
+def assert_immutability(rwt, original_dfb, original_root_fn):
+ assert rwt.mutated_dfb() != original_dfb
+ assert rwt.mutated_root_fn() != original_root_fn
+ assert rwt.mutated_root_fn().body.blocks[0] != original_dfb
+ assert rwt.mutated_root_fn().body.blocks[0] == rwt.mutated_dfb()
+
+
+def test_null_construct():
+ root_fn = Identity["main"]
+ dfb = root_fn.body.blocks[0]
+
+ DataflowBlockRewrite(dfb, root_fn)
+
+
+def test_simple_add():
+ root_fn = Identity["main"]
+ dfb = root_fn.body.blocks[0]
+
+ rwt = DataflowBlockRewrite(dfb, root_fn)
+ rwt.add(name="tmp", expr=Identity["main"].params[0], is_dfvar=True)
+
+ assert_immutability(rwt, dfb, root_fn)
+
+ # check "tmp" added
+ assert "tmp" in name_to_binding(rwt.mutated_root_fn())
+
+ @tvm.script.ir_module
+ class GroundTruth:
+ @R.function
+ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+ with R.dataflow():
+ lv0 = x
+ tmp: R.Tensor((32, 32), "float32") = x
+ R.output(lv0)
+ return lv0
+
+ tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"])
+
+
+def test_simple_auto_add_var():
+ root_fn = Identity["main"]
+ dfb = root_fn.body.blocks[0]
+
+ rwt = DataflowBlockRewrite(dfb, root_fn)
+ rwt.add(root_fn.params[0], is_dfvar=False)
+
+ assert isinstance(rwt.mutated_dfb().bindings[-1].var, Var)
+
+ assert_immutability(rwt, dfb, root_fn)
+
+
+def test_simple_auto_add_dfvar():
+ root_fn = Identity["main"]
+ dfb = root_fn.body.blocks[0]
+
+ rwt = DataflowBlockRewrite(dfb, root_fn)
+ rwt.add(root_fn.params[0], is_dfvar=True)
+
+ assert isinstance(rwt.mutated_dfb().bindings[-1].var, DataflowVar)
+
+ # immutability
+ assert_immutability(rwt, dfb, root_fn)
+
+
+def test_simple_remove_unused():
+ @tvm.script.ir_module
+ class IdentityUnused:
+ @R.function
+ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+ with R.dataflow():
+ lv0 = x
+ unused = lv0
+ R.output(lv0)
+ return lv0
+
+ root_fn = IdentityUnused["main"]
+ dfb = root_fn.body.blocks[0]
+
+ n2binding = name_to_binding(IdentityUnused["main"])
+
+ rwt = DataflowBlockRewrite(dfb, root_fn)
+ rwt.remove_unused(n2binding["unused"][0].var)
+
+ assert_immutability(rwt, dfb, root_fn)
+
+ # check "unused" removed
+ assert "unused" not in name_to_binding(rwt.mutated_root_fn())
+
+ @tvm.script.ir_module
+ class GroundTruth:
+ @R.function
+ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+ with R.dataflow():
+ lv0 = x
+ R.output(lv0)
+ return lv0
+
+ tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"])
+
+
+def test_remove_unused_undef():
+ root_fn = Identity["main"]
+ dfb = root_fn.body.blocks[0]
+
+ with pytest.raises(TVMError):
+ rwt = DataflowBlockRewrite(dfb, root_fn)
+ rwt.remove_unused(Var("whatever"))
+
+ rwt = DataflowBlockRewrite(dfb, root_fn)
+ rwt.remove_unused(Var("whatever"), allow_undef=True)
+
+ assert root_fn == rwt.mutated_root_fn()
+
+
+def test_simple_rm_all_unused():
+ @tvm.script.ir_module
+ class IdentityUnused:
+ @R.function
+ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+ with R.dataflow():
+ lv0 = x
+ unused0 = lv0
+ unused1 = lv0
+ R.output(lv0)
+ return lv0
+
+ root_fn = IdentityUnused["main"]
+ dfb = root_fn.body.blocks[0]
+
+ rwt = DataflowBlockRewrite(dfb, root_fn)
+ rwt.remove_all_unused()
+
+ @tvm.script.ir_module
+ class GroundTruth:
+ @R.function
+ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+ with R.dataflow():
+ lv0 = x
+ R.output(lv0)
+ return lv0
+
+ tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"])
+
+
+@tvm.script.ir_module
+class DeadDFBlock:
+ @R.function
+ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"):
+ with R.dataflow():
+ lv0 = x
+ R.output(lv0)
+ return x
+
+
+def test_empty_dfb_after_removal():
+ root_fn = DeadDFBlock["main"]
+ dfb = root_fn.body.blocks[0]
+
+ rwt = DataflowBlockRewrite(dfb, root_fn)
+ rwt.remove_unused(DeadDFBlock["main"].body.blocks[0].bindings[0].var)
+
+ @tvm.script.ir_module
+ class GroundTruth:
+ @R.function
+ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"):
+ return x
+
+ tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"])
+
+
+def test_empty_dfb_after_all_removal():
+ dfb = DeadDFBlock["main"].body.blocks[0]
+ root_fn = DeadDFBlock["main"]
+
+ rwt = DataflowBlockRewrite(dfb, root_fn)
+ rwt.remove_all_unused()
+
+ @tvm.script.ir_module
+ class GroundTruth:
+ @R.function
+ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"):
+ return x
+
+ tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"])
+
+
+def test_chained_rm_all_unused():
+ @tvm.script.ir_module
+ class IdentityChainedUnused:
+ @R.function
+ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+ with R.dataflow():
+ lv0 = x
+ unused0 = R.call_tir("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32"))
+ unused1 = R.call_tir("my_sigmoid", (unused0,), R.Tensor((32, 32), dtype="float32"))
+ R.output(lv0)
+ return lv0
+
+ root_fn = IdentityChainedUnused["main"]
+ dfb = root_fn.body.blocks[0]
+
+ rwt = DataflowBlockRewrite(dfb, root_fn)
+ rwt.remove_all_unused()
+
+ @tvm.script.ir_module
+ class GroundTruth:
+ @R.function
+ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+ with R.dataflow():
+ lv0 = x
+ R.output(lv0)
+ return lv0
+
+ tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"])
+
+
+def test_simple_replace_all_uses():
+ @tvm.script.ir_module
+ class Lv0To1:
+ @R.function
+ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"):
+ # lv0 => lv1
+ # / \
+ # lv2 lv3
+ # \ /
+ # lv4
+ with R.dataflow():
+ lv0: R.Tensor((32, 32), "float32") = R.call_tir(
+ "my_relu", (x,), R.Tensor((32, 32), dtype="float32")
+ )
+ lv1: R.Tensor((32, 32), "float32") = R.call_tir(
+ "my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32")
+ )
+ lv2: R.Tensor((32, 32), "float32") = R.call_tir(
+ "my_add", (x, lv0), R.Tensor((32, 32), dtype="float32")
+ )
+ lv3: R.Tensor((32, 32), "float32") = R.call_tir(
+ "my_mul", (x, lv0), R.Tensor((32, 32), dtype="float32")
+ )
+ lv4: R.Tensor((32, 32), "float32") = R.call_tir(
+ "my_whatever", (lv2, lv3), R.Tensor((32, 32), dtype="float32")
+ )
+ R.output(lv4)
+ return lv4
+
+ root_fn = Lv0To1["main"]
+ dfb = root_fn.body.blocks[0]
+
+ n2binding = name_to_binding(root_fn)
+
+ rwt = DataflowBlockRewrite(dfb, root_fn)
+ rwt.replace_all_uses(n2binding["lv0"][0].var, n2binding["lv1"][0].var)
+ rwt.remove_unused(n2binding["lv0"][0].var)
+
+ assert_immutability(rwt, dfb, root_fn)
+
+ n2binding_after = name_to_binding(rwt.mutated_root_fn())
+ assert "lv0" not in n2binding_after
+
+
+def test_simple_module_update():
+ @tvm.script.ir_module
+ class Identity:
+ @R.function
+ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+ with R.dataflow():
+ lv0 = x
+ R.output(lv0)
+ return lv0
+
+ root_fn = Identity["main"]
+ dfb = root_fn.body.blocks[0]
+
+ rwt = DataflowBlockRewrite(dfb, root_fn)
+ rwt.add(name="tmp", expr=root_fn.params[0], is_dfvar=True)
+
+ new_ir = rwt.mutate_irmodule(Identity)
+
+ # immutability
+ assert new_ir != Identity
+ assert 2 == len(new_ir["main"].body.blocks[0].bindings)
+
+ @tvm.script.ir_module
+ class GroundTruth:
+ @R.function
+ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+ with R.dataflow():
+ lv0 = x
+ tmp: R.Tensor((32, 32), "float32") = x
+ R.output(lv0)
+ return lv0
+
+ tvm.ir.assert_structural_equal(new_ir, GroundTruth)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()