This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity-staging
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit b3cf43af34e1d90b3fc841c2c2a12d36073a6159
Author: Jiawei Liu <[email protected]>
AuthorDate: Sat Feb 18 18:13:57 2023 -0600

    [Unity] Statement rewriter for DataflowBlock (#14043)
    
    This PR implements a few APIs to quickly perform statement-level mutation:
    `add`/`remove_unused`/`remove_all_unused`/`replace_all_uses`.
    It also implements a function-level `remove_all_unused` analysis utility
    that removes dead statements inside `DataflowBlock`s.
---
 include/tvm/relax/analysis.h               |  24 +++
 include/tvm/relax/binding_rewrite.h        | 115 ++++++++++
 include/tvm/relax/utils.h                  |   1 +
 python/tvm/relax/analysis/analysis.py      |  23 +-
 python/tvm/relax/binding_rewrite.py        | 155 +++++++++++++
 src/relax/ir/binding_rewrite.cc            | 324 ++++++++++++++++++++++++++++
 tests/python/relax/test_analysis.py        | 118 +++++++++-
 tests/python/relax/test_binding_rewrite.py | 334 +++++++++++++++++++++++++++++
 8 files changed, 1092 insertions(+), 2 deletions(-)

diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h
index 32e1582134..b9866577e9 100644
--- a/include/tvm/relax/analysis.h
+++ b/include/tvm/relax/analysis.h
@@ -341,6 +341,14 @@ TVM_DLL Map<Var, Expr> AnalyzeVar2Value(const Expr& expr);
  */
 TVM_DLL Map<Var, Expr> AnalyzeVar2Value(const DataflowBlock& dfb);
 
+/*!
+ * \brief Return a mapping from variable name to its Bindings.
+ *
+ * The mapped value is an array because one name may be bound more than once
+ * (for example when a later block re-binds a name that was already used).
+ *
+ * \param fn The function to be analyzed.
+ * \return A mapping from variable name to its Bindings.
+ */
+TVM_DLL Map<String, Array<Binding>> NameToBinding(const Function& fn);
+
 /*!
  * \brief Get the use-def chain of variables inside a dataflow block.
  *
@@ -349,6 +357,22 @@ TVM_DLL Map<Var, Expr> AnalyzeVar2Value(const 
DataflowBlock& dfb);
  */
 TVM_DLL Map<Var, Array<Var>> DataflowBlockUseDef(const DataflowBlock& dfb);
 
+/*!
+ * \brief Get the use-def chain of variables inside a function.
+ *
+ * \param fn The function to be analyzed.
+ * \return A pair: first, a map from each variable definition to the variables that use it;
+ *         second, the variables needed by the return value.
+ */
+std::pair<Map<Var, Array<Var>>, Array<Var>> FunctionUseDef(const Function& fn);
+
+/*!
+ * \brief Remove unused statements inside DataflowBlocks.
+ *
+ * Bindings that live outside a DataflowBlock are kept even when the bound variable is
+ * unused, because they might have side effects.
+ *
+ * \param fn The function to remove unused statements.
+ * \return The function that contains no unused statements in DataflowBlock.
+ */
+TVM_DLL Function RemoveAllUnused(const Function fn);
+
 /*!
  * \brief Annotate Op Pattern Kind for PrimFunc, which is used in relax 
FuseOps.
  *
diff --git a/include/tvm/relax/binding_rewrite.h 
b/include/tvm/relax/binding_rewrite.h
new file mode 100644
index 0000000000..a4b534965a
--- /dev/null
+++ b/include/tvm/relax/binding_rewrite.h
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/binding_rewrite.h
+ * \brief An IR rewriter to easily add/remove/replace bindings (statements).
+ */
+
+#ifndef TVM_RELAX_BINDING_REWRITE_H_
+#define TVM_RELAX_BINDING_REWRITE_H_
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/utils.h>
+
+#include <map>
+#include <set>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief Statement rewriter for relax.DataflowBlock.
+ *
+ * Holds the current version of the DataflowBlock being rewritten and of its enclosing
+ * function, together with the user map and the output variables that are computed from
+ * the function when the rewriter is constructed.
+ */
+class DataflowBlockRewriteNode : public Object {
+ public:
+  /*!
+   * \brief Replace all uses of old_var with new_var.
+   * \param old_var The variable whose uses are rewritten.
+   * \param new_var The variable that takes its place.
+   */
+  void ReplaceAllUses(Var old_var, Var new_var);
+  /*!
+   * \brief Insert a Binding statement.
+   * \param binding The binding to insert into the dataflow block.
+   */
+  void Add(Binding binding);
+  /*!
+   * \brief Insert an expression as VarBinding with variable name.
+   * \param var_name The name of the variable bound to the expression.
+   * \param expr The expression to bind; its struct info is given to the new variable.
+   * \param is_dfvar Create a DataflowVar when true, a Var otherwise.
+   */
+  void Add(String var_name, Expr expr, bool is_dfvar = false) {
+    auto var = is_dfvar ? DataflowVar(var_name, GetStructInfo(expr))  //
+                        : Var(var_name, GetStructInfo(expr));
+    Add(VarBinding(std::move(var), std::move(expr)));
+  }
+  /*!
+   * \brief Insert an expression as VarBinding with automatic variable name.
+   * \param expr The expression to bind.
+   * \param is_dfvar Create a DataflowVar when true, a Var otherwise.
+   */
+  void Add(Expr expr, bool is_dfvar = false) {
+    // The name is requested from the name table with "tmp" as the prefix.
+    Add(name_table_.GetUniqueName("tmp"), expr, is_dfvar);
+  }
+  /*!
+   * \brief Remove the definition statement of an unused variable.
+   * \param unused The variable whose definition is removed.
+   * \param allow_undef Whether a variable without a known definition is silently accepted.
+   */
+  void RemoveUnused(Var unused, bool allow_undef = false);
+  /*! \brief Remove the definition statements of all unused variables. */
+  void RemoveAllUnused();
+
+  /*! \brief The rewritten dataflow block. */
+  DataflowBlock MutatedDataflowBlock() { return dfb_.value(); }
+  /*! \brief The rewritten function. */
+  Function MutatedFunc() { return root_fn_.value(); }
+  /*!
+   * \brief The rewritten IRModule.
+   * \param irmod The module that holds the function the rewriter was constructed with.
+   */
+  IRModule MutateIRModule(IRModule irmod);
+
+  /*! \brief Visit attributes. */
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("dfb", &dfb_);
+    v->Visit("root_fn", &root_fn_);
+  }
+
+  static constexpr const char* _type_key = "relax.DataflowBlockRewrite";
+  TVM_DECLARE_FINAL_OBJECT_INFO(DataflowBlockRewriteNode, Object);
+
+ protected:
+  friend class DataflowBlockRewrite;
+
+  Optional<DataflowBlock> dfb_;          //!< The rewritten dataflow block.
+  Optional<Function> root_fn_;           //!< The rewritten function.
+  const FunctionNode* original_fn_ptr_;  //!< Pointer to the original function.
+  Map<Var, Array<Var>> to_users_;        //!< Map from variable to its users.
+  Array<Var> fn_outputs_;                //!< Variables required by function outputs.
+
+ private:
+  NameTable name_table_;  //!< Name table for tracking and generating unique names.
+};
+
+/*!
+ * \brief A statement rewriter for relax.DataflowBlock.
+ * \sa DataflowBlockRewriteNode
+ */
+class DataflowBlockRewrite : public ObjectRef {
+ public:
+  /*!
+   * \brief Construct a rewriter.
+   * \param dfb The DataflowBlock to rewrite.
+   * \param root_fn The root function of the DataflowBlock.
+   */
+  TVM_DLL explicit DataflowBlockRewrite(DataflowBlock dfb, Function root_fn);
+
+  /*!
+   * \brief mutable accessor.
+   * \return mutable access pointer.
+   */
+  DataflowBlockRewriteNode* operator->() {
+    ICHECK(get() != nullptr);
+    return static_cast<DataflowBlockRewriteNode*>(get_mutable());
+  }
+
+  TVM_DEFINE_OBJECT_REF_METHODS(DataflowBlockRewrite, ObjectRef, DataflowBlockRewriteNode);
+};
+
+}  // namespace relax
+}  // namespace tvm
+
+#define TVM_RELAX_BINDING_REWRITE_H_
+#endif  // TVM_RELAX_BINDING_REWRITE_H_
diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h
index 1457a16427..c1d984a21a 100644
--- a/include/tvm/relax/utils.h
+++ b/include/tvm/relax/utils.h
@@ -25,6 +25,7 @@
 #define TVM_RELAX_UTILS_H_
 
 #include <tvm/ir/module.h>
+#include <tvm/relax/expr.h>
 #include <tvm/runtime/logging.h>
 
 #include <algorithm>
diff --git a/python/tvm/relax/analysis/analysis.py 
b/python/tvm/relax/analysis/analysis.py
index 45c5b6f962..ffcdaceb40 100644
--- a/python/tvm/relax/analysis/analysis.py
+++ b/python/tvm/relax/analysis/analysis.py
@@ -28,7 +28,7 @@ from tvm import tir
 from tvm import IRModule
 from tvm.relax.ty import Type
 from tvm.relax.struct_info import StructInfo, FuncStructInfo
-from tvm.relax.expr import DataflowBlock, Var, Expr, Function, Call
+from tvm.relax.expr import DataflowBlock, Var, Expr, Function, Call, Binding
 from . import _ffi_api
 
 
@@ -244,6 +244,27 @@ def udchain(dfb: DataflowBlock) -> Dict[Var, List[Var]]:
     return _ffi_api.udchain(dfb)  # type: ignore
 
 
+def name_to_binding(func: Function) -> Dict[str, List[Binding]]:
+    """Return a map from variable name to its bindings.
+
+    Parameters
+    ----------
+    func : Function
+        The function to be analyzed.
+
+    Returns
+    -------
+    Dict[str, List[Binding]]
+        For every variable name, the list of bindings that define a variable of that name.
+    """
+    return _ffi_api.name_to_binding(func)  # type: ignore
+
+
+def remove_all_unused(func: Function) -> Function:
+    """Remove all unused variables inside the DataflowBlocks of the function.
+
+    Parameters
+    ----------
+    func : Function
+        The input function to be analyzed.
+
+    Returns
+    -------
+    Function
+        The function with unused variables removed.
+    """
+    return _ffi_api.remove_all_unused(func)  # type: ignore
+
+
 def well_formed(mod: IRModule, check_struct_info: bool = True) -> bool:
     """Check if the IRModule is well formed.
 
diff --git a/python/tvm/relax/binding_rewrite.py 
b/python/tvm/relax/binding_rewrite.py
new file mode 100644
index 0000000000..a9f6d878ad
--- /dev/null
+++ b/python/tvm/relax/binding_rewrite.py
@@ -0,0 +1,155 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=no-else-return, invalid-name
+"""Developer API of add/remove/replace bindings in Relax."""
+
+from typing import Optional
+
+import tvm
+import tvm._ffi
+from tvm.runtime import Object
+from . import Binding, DataflowBlock, Expr, Function, Var
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("relax.DataflowBlockRewrite")
+class DataflowBlockRewrite(Object):
+    """
+    A binding/statement-level dataflow block rewriter.
+
+    Notes
+    -----
+    Due to the immutable and copy-on-write nature of TVM AST nodes, the 
rewriting is not done in
+    place. Instead, a new DataflowBlock is created and returned with 
mutated_dfb. Similarly, its new
+    root Function is created and returned by mutated_root_fn. To apply this 
change for an IRModule,
+    use mutate_irmodule which rewrites the old function that registered in the 
constructor.
+    """
+
+    def __init__(self, dfb: DataflowBlock, root_fn: Function):
+        """
+        Construct a rewriter with the DataflowBlock to rewrite and its root 
function.
+
+        Parameters
+        ----------
+        dfb : DataflowBlock
+            The DataflowBlock to rewrite.
+        root_fn : Function
+            The root function of the DataflowBlock.
+        """
+        self.func_name = root_fn.__name__ if hasattr(root_fn, "__name__") else 
None
+        self.__init_handle_by_constructor__(
+            _ffi_api.DataflowBlockRewrite, dfb, root_fn  # type: ignore
+        )
+
+    def replace_all_uses(self, old_var: Var, new_var: Var) -> None:
+        """
+        Replace all uses of old_var with new_var.
+
+        Parameters
+        ----------
+        old_var : Var
+            The old variable to replace.
+        new_var : Var
+            The new variable to replace with.
+        """
+        _ffi_api.dfb_rewrite_replace_all_uses(self, old_var, new_var)  # type: 
ignore
+
+    def add_binding(self, binding: Binding) -> None:
+        return _ffi_api.dfb_rewrite_add_binding(self, binding)  # type: ignore
+
+    def add(self, expr: Expr, name: Optional[str] = None, is_dfvar: bool = 
False) -> None:
+        """
+        Add a new statement to the DataflowBlock with an automatically 
generated variable name.
+
+        Parameters
+        ----------
+        expr : Expr
+            The expression to add.
+        name : Optional[str], optional
+            Variable name, by default None
+        is_dfvar : bool, optional
+            The variable type, by default False
+
+        Notes
+        -----
+        If the variable name is not given, it will be automatically generated 
in a form of
+        "tmp${COUNTER}". The variable type will be DataflowVar if is_dfvar is 
True, otherwise
+        it will be Var. Being Var means the variables are output variables of 
the DataflowBlock.
+        While being DataflowVar means the variables are internal variables of 
the DataflowBlock.
+        """
+        _ffi_api.dfb_rewrite_add(self, expr, name, is_dfvar)  # type: ignore
+
+    def remove_unused(self, var: Var, allow_undef=False) -> None:
+        """
+        Remove a statement by its variable definition if and only if it is 
unused.
+
+        Parameters
+        ----------
+        var : Var
+            The unused variable definition.
+        allow_undef : bool, optional
+            Whether to allow var being undefined variable, by default False
+
+        Raises
+        ------
+        TVMError if the variable is used or undefined (allow_undef=False).
+        """
+        _ffi_api.dfb_rewrite_remove_unused(self, var, allow_undef)  # type: 
ignore
+
+    def remove_all_unused(self) -> None:
+        """
+        Remove all unused variables.
+
+        Notes
+        -----
+        This could remove unused variables in other DataflowBlocks as well.
+        """
+        _ffi_api.dfb_rewrite_remove_all_unused(self)  # type: ignore
+
+    def mutated_dfb(self) -> DataflowBlock:
+        """
+        Returns the mutated DataflowBlock.
+        """
+        return self.dfb
+
+    def mutated_root_fn(self) -> Function:
+        """
+        Returns the mutated root function.
+        """
+        ret = self.root_fn
+        if self.func_name:
+            ret.__name__ = self.func_name
+        return ret
+
+    def mutate_irmodule(self, irmodule: tvm.IRModule) -> tvm.IRModule:
+        """
+        Return an updated IRModule by replacing the old function with the 
mutated root function.
+
+        Parameters
+        ----------
+        irmodule : tvm.IRModule
+            The base IRModule to update.
+
+        Returns
+        -------
+        tvm.IRModule
+            The updated IRModule.
+        """
+        ret = _ffi_api.dfb_rewrite_mutate_irmodule(self, irmodule)  # type: 
ignore
+        if hasattr(irmodule, "__name__"):
+            ret.__name__ = irmodule.__name__
+        return ret
diff --git a/src/relax/ir/binding_rewrite.cc b/src/relax/ir/binding_rewrite.cc
new file mode 100644
index 0000000000..dd9fac9fdc
--- /dev/null
+++ b/src/relax/ir/binding_rewrite.cc
@@ -0,0 +1,324 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/ir/binding_rewrite.cc
+ * \brief Implementation of binding rewriters.
+ */
+
+#include <tvm/relax/binding_rewrite.h>
+#include <tvm/relax/block_builder.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+
+#include <functional>
+#include <iterator>
+
+namespace tvm {
+namespace relax {
+
+TVM_REGISTER_NODE_TYPE(DataflowBlockRewriteNode);
+
+/*!
+ * \brief Build a rewriter for a DataflowBlock and its root function.
+ *
+ * The use-def information of root_fn is computed once here and then maintained by the
+ * mutating methods of the node.
+ *
+ * \param dfb The DataflowBlock to rewrite.
+ * \param root_fn The root function of the DataflowBlock.
+ */
+DataflowBlockRewrite::DataflowBlockRewrite(DataflowBlock dfb, Function root_fn) {
+  auto n = make_object<DataflowBlockRewriteNode>();
+  n->dfb_ = dfb;
+  n->root_fn_ = root_fn;
+  // Raw, non-owning pointer: it is only compared by address in MutateIRModule.
+  n->original_fn_ptr_ = root_fn.get();
+  auto p = FunctionUseDef(root_fn);
+  n->to_users_ = std::move(p.first);
+  n->fn_outputs_ = std::move(p.second);
+  // Seed the name table with the name hints of all variables known to the user map.
+  n->name_table_ = NameTable(n->to_users_.begin(), n->to_users_.end(),
+                             [](const auto& kv) { return kv.first->name_hint(); });
+
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_GLOBAL("relax.DataflowBlockRewrite")
+    .set_body_typed([](DataflowBlock dfb, Function root_fn) {
+      return DataflowBlockRewrite(dfb, root_fn);
+    });
+
+/*!
+ * \brief Replace every use of old_var by new_var in the root function.
+ *
+ * Both variables must already be present in the user map. The root function and the
+ * tracked dataflow block are rebuilt; afterwards the user map and the output list are
+ * updated so that the users of old_var become users of new_var.
+ *
+ * \param old_var The variable whose uses are replaced.
+ * \param new_var The variable used instead.
+ */
+void DataflowBlockRewriteNode::ReplaceAllUses(Var old_var, Var new_var) {
+  class ReplaceAllUsePass : public ExprMutator {
+    Var old_var, new_var;
+    const DataflowBlockNode* const to_catch;
+
+   public:
+    //! The rewritten counterpart of to_catch, set once that block has been visited.
+    const DataflowBlockNode* caught = nullptr;
+
+    ReplaceAllUsePass(Var old_var, Var new_var, const DataflowBlockNode* to_catch)
+        : old_var(old_var), new_var(new_var), to_catch(to_catch) {}
+
+    using ExprMutator::VisitExpr_;
+
+    Expr VisitExpr_(const VarNode* op) override {
+      return (op == old_var.get()) ? new_var : GetRef<Expr>(op);
+    }
+
+    Expr VisitExpr_(const DataflowVarNode* op) override {
+      return (op == old_var.get()) ? new_var : GetRef<Expr>(op);
+    }
+
+    BindingBlock VisitBindingBlock_(const DataflowBlockNode* op) override {
+      BindingBlock res = ExprMutator::VisitBindingBlock_(op);
+      if (op == to_catch) caught = static_cast<const DataflowBlockNode*>(res.get());
+      return res;
+    }
+  };
+
+  ICHECK(to_users_.find(old_var) != to_users_.end()) << "Cannot find " << old_var;
+  ICHECK(to_users_.find(new_var) != to_users_.end()) << "Cannot find " << new_var;
+
+  // replace uses inside the DataflowBlock.
+  ReplaceAllUsePass replacer(old_var, new_var, dfb_.get());
+  root_fn_ = Downcast<Function>(replacer.VisitExpr_(root_fn_.get()));
+  dfb_ = GetRef<DataflowBlock>(replacer.caught);
+
+  // update udchain
+  // old_var -> old_var users | changed to {}
+  // new_var -> {?}           | changed to {?} + old_var users
+  // The array read from the map is a local handle, so the merged list must be
+  // written back with Set; otherwise the added users are lost.
+  Array<Var> new_var_uses = to_users_[new_var];
+  for (Var user : to_users_[old_var]) {
+    if (new_var_uses.end() == std::find(new_var_uses.begin(), new_var_uses.end(), user)) {
+      new_var_uses.push_back(user);
+    }
+  }
+  to_users_.Set(new_var, new_var_uses);
+
+  to_users_.Set(old_var, {});
+
+  // A function output that was old_var is now provided by new_var.
+  auto it_old_output = std::find(fn_outputs_.begin(), fn_outputs_.end(), old_var);
+  if (it_old_output != fn_outputs_.end()) {
+    fn_outputs_.Set(std::distance(fn_outputs_.begin(), it_old_output), new_var);
+  }
+}
+
+TVM_REGISTER_GLOBAL("relax.dfb_rewrite_replace_all_uses")
+    .set_body_typed([](DataflowBlockRewrite rwt, Var old_var, Var new_var) {
+      rwt->ReplaceAllUses(old_var, new_var);
+    });
+
+/*!
+ * \brief Mutator that swaps one DataflowBlock (matched by node address) for another.
+ *
+ * Every other DataflowBlock is returned unchanged.
+ */
+class UpdateDFB : public ExprMutator {
+ private:
+  DataflowBlock old_dfb, new_dfb;
+
+ public:
+  /*!
+   * \param old_dfb The block to look for.
+   * \param new_dfb The block to put in its place.
+   */
+  UpdateDFB(DataflowBlock old_dfb, DataflowBlock new_dfb)
+      : old_dfb(std::move(old_dfb)), new_dfb(std::move(new_dfb)) {}
+
+  BindingBlock VisitBindingBlock_(const DataflowBlockNode* op) override {
+    // Only the matching block is replaced; returning old_dfb for the others would
+    // overwrite unrelated DataflowBlocks of the function.
+    return old_dfb.get() == op ? new_dfb : GetRef<DataflowBlock>(op);
+  }
+};
+
+/*!
+ * \brief Insert a binding into the tracked dataflow block.
+ *
+ * Only VarBinding and MatchCast are supported. The bound variable must not already be
+ * known to the rewriter. The binding is placed right after the last binding of the block
+ * that defines a variable used by the bound value (after the first binding when there is
+ * no such definition in the block).
+ *
+ * \param binding The binding to insert.
+ */
+void DataflowBlockRewriteNode::Add(Binding binding) {
+  auto p = [binding] {
+    if (auto vb = binding.as<VarBindingNode>()) {
+      return std::make_pair(vb->var, vb->value);
+    } else if (auto mc = binding.as<MatchCastNode>()) {
+      return std::make_pair(mc->var, mc->value);
+    }
+    LOG(FATAL) << "Unsupported binding type";
+    return std::make_pair(Var{}, Expr{});
+  }();
+  Var var = p.first;
+  Expr val = p.second;
+
+  ICHECK(0 == to_users_.count(var)) << var << " has been defined so cannot be added.";
+
+  // Add this VarBinding statement after the definition of uses.
+  std::set<const VarNode*> used_vars = [val] {
+    class UsedVars : public ExprVisitor {
+     public:
+      std::set<const VarNode*> used_vars;
+      void VisitExpr_(const VarNode* op) override { used_vars.insert(op); }
+      void VisitExpr_(const DataflowVarNode* op) override { used_vars.insert(op); }
+    } uvar{};
+    uvar.VisitExpr(val);
+    return std::move(uvar.used_vars);
+  }();
+
+  auto old_dfb = dfb_.value();
+
+  size_t line_last_req_def = 0;
+  for (size_t i = 0; i < old_dfb->bindings.size(); ++i) {
+    auto line = old_dfb->bindings[i];
+    if (used_vars.find(line->var.get()) != used_vars.cend()) line_last_req_def = i;
+  }
+
+  // Insert after the last required definition; clamp so that an empty block does not
+  // get an insertion position past its end.
+  size_t insert_at = line_last_req_def + 1;
+  if (insert_at > old_dfb->bindings.size()) insert_at = old_dfb->bindings.size();
+
+  DataflowBlock new_dfb = old_dfb;
+  auto* new_node = new_dfb.CopyOnWrite();
+  new_node->bindings.insert(new_node->bindings.begin() + insert_at, binding);
+  dfb_ = new_dfb;
+
+  auto updater = UpdateDFB(old_dfb, dfb_.value());
+  root_fn_ = Downcast<Function>(updater.VisitExpr_(root_fn_.get()));
+
+  // update udchain: `var` becomes a user of every variable in its value. The array read
+  // from the map is a local handle, so it has to be stored back with Set.
+  for (const VarNode* v : used_vars) {
+    Var used = GetRef<Var>(v);
+    Array<Var> users = to_users_.Get(used).value();
+    users.push_back(var);
+    to_users_.Set(used, users);
+  }
+  // Register the new definition (no users yet) so later calls can find it.
+  to_users_.Set(var, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.dfb_rewrite_add_binding")
+    .set_body_typed([](DataflowBlockRewrite rwt, Binding vb) { rwt->Add(vb); });
+
+TVM_REGISTER_GLOBAL("relax.dfb_rewrite_add")
+    .set_body_typed([](DataflowBlockRewrite rwt, Expr expr, Optional<String> name, bool is_dfvar) {
+      // Without a name, fall back to the overload that generates one.
+      if (name.get()) {
+        rwt->Add(name.value(), expr, is_dfvar);
+      } else {
+        rwt->Add(expr, is_dfvar);
+      }
+    });
+
+/*!
+ * \brief Mutator that drops the bindings of unused variables from DataflowBlocks.
+ *
+ * The set of unused variables is either given directly or computed from a user map and
+ * the list of variables required by the function outputs.
+ *
+ * NOTE(review): the computed set is not restricted to variables bound inside a
+ * DataflowBlock, while only DataflowBlock bindings are filtered below. A variable whose
+ * only user is an unused binding outside a DataflowBlock may therefore be dropped even
+ * though that outer binding stays -- confirm against FunctionUseDef.
+ */
+class RemoveUnusedVars : public ExprMutator {
+ public:
+  //! Variables whose defining bindings are skipped when a DataflowBlock is rebuilt.
+  std::set<Var> unused_vars;
+  //! Set by the caller to a block of interest; replaced by its rewritten version on visit.
+  Optional<DataflowBlock> caught_rewrite = NullOpt;
+
+  /*!
+   * \brief Compute the unused variables by iterating until no new one is found.
+   * \param users Map from a variable to the variables using it (modified locally).
+   * \param fn_outputs Variables required by the function outputs; never unused.
+   */
+  RemoveUnusedVars(Map<Var, Array<Var>> users, Array<Var> fn_outputs)
+      : unused_vars([&] {
+          std::vector<Var> unused;
+
+          // iterative dataflow algorithm.
+          size_t prev_size;
+          do {
+            prev_size = unused.size();
+
+            std::vector<Var> used;
+            used.reserve(users.size());
+            for (const auto& kv : users) {
+              // var -> [users...]
+              // var is unused iff
+              //   user -> empty
+              //   var is not output var
+              if (kv.second.empty() &&  // kv.first is not used by fn outputs.
+                  fn_outputs.end() == std::find(fn_outputs.begin(), fn_outputs.end(), kv.first)) {
+                unused.push_back(kv.first);
+              } else {
+                used.push_back(kv.first);
+              }
+            }
+
+            // Only the variables found in this round still need to be processed.
+            for (size_t i = prev_size; i < unused.size(); ++i) {
+              users.erase(unused[i]);
+              // remove def site.
+              for (const auto& used_var : used) {
+                ICHECK(users.count(used_var));
+                Array<Var> var_users = users[used_var];
+                // remove the unused var from the use site.
+                auto it = std::find(var_users.begin(), var_users.end(), unused[i]);
+                if (it != var_users.end()) {
+                  var_users.erase(it);
+                  users.Set(used_var, std::move(var_users));
+                }
+              }
+            }
+          } while (prev_size != unused.size());  // changed? => continue.
+
+          return std::set<Var>(unused.begin(), unused.end());
+        }()) {}
+
+  /*! \brief Same as above, taking the (users, outputs) pair as one argument. */
+  RemoveUnusedVars(std::pair<Map<Var, Array<Var>>, Array<Var>> users_and_outputs)
+      : RemoveUnusedVars(std::move(users_and_outputs.first), std::move(users_and_outputs.second)) {}
+  /*! \brief Analyze fn with FunctionUseDef and compute its unused variables. */
+  RemoveUnusedVars(Function fn) : RemoveUnusedVars(FunctionUseDef(fn)) {}
+  /*! \brief Use an explicitly given set of unused variables. */
+  RemoveUnusedVars(std::set<Var> unused_vars) : unused_vars(std::move(unused_vars)) {}
+
+  /*! \brief Rebuild a DataflowBlock without the bindings of unused variables. */
+  BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) override {
+    auto prev_dfb = GetRef<DataflowBlock>(block);
+    builder_->BeginDataflowBlock();
+    for (Binding binding : block->bindings) {
+      if (!unused_vars.count(binding->var)) {
+        VisitBinding(binding);
+      }
+    }
+    auto new_dfb = builder_->EndBlock();
+    if (caught_rewrite == prev_dfb) caught_rewrite = Downcast<DataflowBlock>(new_dfb);
+    return new_dfb;
+  }
+};
+
+/*!
+ * \brief Remove the binding that defines an unused variable from the tracked block.
+ *
+ * Fails when the variable still has users or is required by the function outputs.
+ *
+ * \param unused The variable whose definition is removed.
+ * \param allow_undef When true, a variable unknown to the rewriter is ignored instead of
+ *        being reported as an error.
+ */
+void DataflowBlockRewriteNode::RemoveUnused(Var unused, bool allow_undef) {
+  // first need to check if this var is used.
+  if (0 == to_users_.count(unused)) {  // no def.
+    if (allow_undef) return;
+    LOG(FATAL) << unused << " undefined. Set allow_undef=True to allow 'removing' undefined var";
+  }
+
+  ICHECK(to_users_[unused].empty())
+      << unused << " is used by " << to_users_[unused].size() << " vars";
+  // A variable needed by the return value counts as used even with an empty user list.
+  ICHECK(fn_outputs_.end() == std::find(fn_outputs_.begin(), fn_outputs_.end(), unused))
+      << unused << " is required by the function outputs";
+
+  auto old_dfb = dfb_.value();
+
+  RemoveUnusedVars remover({unused});
+  dfb_ = Downcast<DataflowBlock>(remover.VisitBindingBlock_(old_dfb.get()));
+
+  auto updater = UpdateDFB(old_dfb, dfb_.value());
+  root_fn_ = Downcast<Function>(updater.VisitExpr_(root_fn_.get()));
+
+  // update use-def chain: forget the definition ...
+  to_users_.erase(unused);
+  // ... and the uses it contributed, so its operands can become removable themselves.
+  // Changes are collected first because the map must not be modified while iterating it.
+  std::vector<std::pair<Var, Array<Var>>> updates;
+  for (const auto& kv : to_users_) {
+    Array<Var> users = kv.second;
+    auto it = std::find(users.begin(), users.end(), unused);
+    if (it != users.end()) {
+      users.erase(it);
+      updates.emplace_back(kv.first, users);
+    }
+  }
+  for (const auto& kv : updates) to_users_.Set(kv.first, kv.second);
+}
+
+TVM_REGISTER_GLOBAL("relax.dfb_rewrite_remove_unused")
+    .set_body_typed([](DataflowBlockRewrite rwt, Var unused, bool allow_undef) {
+      rwt->RemoveUnused(unused, allow_undef);
+    });
+
+/*!
+ * \brief Remove the bindings of all variables classified as unused.
+ *
+ * The whole root function is rewritten, so DataflowBlocks other than the tracked one are
+ * cleaned too.
+ */
+void DataflowBlockRewriteNode::RemoveAllUnused() {
+  RemoveUnusedVars pruner(to_users_, fn_outputs_);
+  // Ask the mutator to hand back the rewritten version of the tracked block.
+  pruner.caught_rewrite = dfb_.value();
+
+  // this could also clean unused variables in other DataflowBlock.
+  Expr pruned_fn = pruner.VisitExpr_(root_fn_.get());
+  root_fn_ = Downcast<Function>(pruned_fn);
+
+  // DataflowBlock could be None.
+  dfb_ = pruner.caught_rewrite.value();
+
+  // clean up use-def chain.
+  for (auto it = pruner.unused_vars.begin(); it != pruner.unused_vars.end(); ++it) {
+    to_users_.erase(*it);
+  }
+}
+
+TVM_REGISTER_GLOBAL("relax.dfb_rewrite_remove_all_unused")
+    .set_body_typed([](DataflowBlockRewrite self) { self->RemoveAllUnused(); });
+
+/*!
+ * \brief Return fn with the unused bindings of its DataflowBlocks removed.
+ * \param fn The function to clean.
+ */
+Function RemoveAllUnused(Function fn) {
+  RemoveUnusedVars dead_binding_remover(fn);
+  Expr pruned = dead_binding_remover.VisitExpr_(fn.get());
+  return Downcast<Function>(pruned);
+}
+
+TVM_REGISTER_GLOBAL("relax.analysis.remove_all_unused").set_body_typed(RemoveAllUnused);
+
+/*!
+ * \brief Put the rewritten root function into a module.
+ *
+ * The function to replace is the module entry whose node address equals the function the
+ * rewriter was constructed with; when no entry matches, nothing is updated.
+ *
+ * \param irmod The module to update.
+ * \return The context module of the builder created from irmod.
+ */
+IRModule DataflowBlockRewriteNode::MutateIRModule(IRModule irmod) {
+  BlockBuilder module_builder = BlockBuilder::Create(irmod);
+
+  for (auto& entry : irmod->functions) {
+    if (entry.second.get() != original_fn_ptr_) continue;
+    module_builder->UpdateFunction(entry.first, root_fn_.value());
+    break;  // at most one entry is replaced.
+  }
+
+  return module_builder->GetContextIRModule();
+}
+
+TVM_REGISTER_GLOBAL("relax.dfb_rewrite_mutate_irmodule")
+    .set_body_typed([](DataflowBlockRewrite self, IRModule mod) {
+      return self->MutateIRModule(mod);
+    });
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/tests/python/relax/test_analysis.py 
b/tests/python/relax/test_analysis.py
index 43558a52be..e939d2b208 100644
--- a/tests/python/relax/test_analysis.py
+++ b/tests/python/relax/test_analysis.py
@@ -21,7 +21,7 @@ import tvm
 import tvm.testing
 from tvm import tir
 from tvm import relax as rx
-from tvm.relax.analysis import has_reshape_pattern, udchain
+from tvm.relax.analysis import has_reshape_pattern, udchain, 
remove_all_unused, name_to_binding
 from tvm.script import relax as R, tir as T
 
 
@@ -46,6 +46,122 @@ def test_use_def():
     assert set(udc[gv0]) == set()
 
 
+def test_chained_remove_all_unused():
+    @tvm.script.ir_module
+    class IdentityUnused:
+        @R.function
+        def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+            with R.dataflow():
+                lv0 = x
+                unused0 = R.call_tir("my_sigmoid", (x,), R.Tensor((32, 32), 
dtype="float32"))
+                unused1 = R.call_tir("my_sigmoid", (unused0,), R.Tensor((32, 
32), dtype="float32"))
+                R.output(lv0)
+            return lv0
+
+    optimized = remove_all_unused(IdentityUnused["main"])
+
+    @tvm.script.ir_module
+    class GroundTruth:
+        @R.function
+        def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+            with R.dataflow():
+                lv0 = x
+                R.output(lv0)
+            return lv0
+
+    tvm.ir.assert_structural_equal(optimized, GroundTruth["main"])
+
+
+def test_binding_block_remove_all_unused():
+    @tvm.script.ir_module
+    class IdentityUnused:
+        @R.function
+        def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+            with R.dataflow():
+                lv0 = x
+                unused0 = R.call_tir("my_sigmoid", (x,), R.Tensor((32, 32), 
dtype="float32"))
+                unused1 = R.call_tir("my_sigmoid", (unused0,), R.Tensor((32, 
32), dtype="float32"))
+                R.output(lv0)
+            z = R.call_packed("vm.builtin.copy", lv0, 
sinfo_args=(R.Tensor((32, 32), "float32")))
+            return z
+
+    optimized = remove_all_unused(IdentityUnused["main"])
+
+    @tvm.script.ir_module
+    class GroundTruth:
+        @R.function
+        def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+            with R.dataflow():
+                lv0 = x
+                R.output(lv0)
+            z = R.call_packed("vm.builtin.copy", lv0, 
sinfo_args=(R.Tensor((32, 32), "float32")))
+            return z
+
+    tvm.ir.assert_structural_equal(optimized, GroundTruth["main"])
+
+
+def test_binding_block_fake_unused_remove_all_unused():
+    @tvm.script.ir_module
+    class IdentityUnused:
+        @R.function
+        def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+            with R.dataflow():
+                lv0 = x
+                R.output(lv0)
+            z = R.call_packed("vm.builtin.copy", lv0, 
sinfo_args=(R.Tensor((32, 32), "float32")))
+            return lv0
+
+    optimized = remove_all_unused(IdentityUnused["main"])
+
+    @tvm.script.ir_module
+    class GroundTruth:
+        @R.function
+        def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+            with R.dataflow():
+                lv0 = x
+                R.output(lv0)
+            # This might bring side effect so cannot be removed.
+            z = R.call_packed("vm.builtin.copy", lv0, 
sinfo_args=(R.Tensor((32, 32), "float32")))
+            return lv0
+
+    tvm.ir.assert_structural_equal(optimized, GroundTruth["main"])
+
+
+def test_edge_binding_block_fake_unused_remove_all_unused():
+    @tvm.script.ir_module
+    class IdentityUnused:
+        @R.function
+        def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), 
"float32"):
+            z = R.call_packed("vm.builtin.copy", x, sinfo_args=(R.Tensor((32, 
32), "float32")))
+            return x
+
+    optimized = remove_all_unused(IdentityUnused["main"])
+    tvm.ir.assert_structural_equal(optimized, IdentityUnused["main"])
+
+
+def test_name_to_binding_var_shadowing():
+    """A name bound in two dataflow blocks must be reported with both of its bindings."""
+    @R.function
+    def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+        with R.dataflow():
+            lv0 = x
+            lv1 = lv0
+            R.output(lv1)
+
+        with R.dataflow():
+            lv0 = lv1  # shadowing
+            lv2 = lv0
+            R.output(lv2)
+        return lv2
+
+    n2binding = name_to_binding(main)
+
+    assert "lv0" in n2binding
+    assert "lv1" in n2binding
+    assert "lv2" in n2binding
+
+    # "lv0" is bound once in each dataflow block.
+    assert len(n2binding["lv0"]) == 2
+
+
 def test_reshape_pattern_reshape():
     @T.prim_func
     def reshape(
diff --git a/tests/python/relax/test_binding_rewrite.py 
b/tests/python/relax/test_binding_rewrite.py
new file mode 100644
index 0000000000..1b424b9792
--- /dev/null
+++ b/tests/python/relax/test_binding_rewrite.py
@@ -0,0 +1,334 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+import tvm
+import tvm.testing
+from tvm._ffi.base import TVMError
+from tvm.relax.analysis import name_to_binding
+from tvm.relax.binding_rewrite import DataflowBlockRewrite
+from tvm.relax.expr import DataflowVar, Var
+from tvm.script import relax as R
+
+
+@tvm.script.ir_module
+class Identity:
+    @R.function
+    def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+        with R.dataflow():
+            lv0 = x
+            R.output(lv0)
+        return lv0
+
+
+def assert_immutability(rwt, original_dfb, original_root_fn):
+    assert rwt.mutated_dfb() != original_dfb
+    assert rwt.mutated_root_fn() != original_root_fn
+    assert rwt.mutated_root_fn().body.blocks[0] != original_dfb
+    assert rwt.mutated_root_fn().body.blocks[0] == rwt.mutated_dfb()
+
+
+def test_null_construct():
+    root_fn = Identity["main"]
+    dfb = root_fn.body.blocks[0]
+
+    DataflowBlockRewrite(dfb, root_fn)
+
+
+def test_simple_add():
+    root_fn = Identity["main"]
+    dfb = root_fn.body.blocks[0]
+
+    rwt = DataflowBlockRewrite(dfb, root_fn)
+    rwt.add(name="tmp", expr=Identity["main"].params[0], is_dfvar=True)
+
+    assert_immutability(rwt, dfb, root_fn)
+
+    # check "tmp" added
+    assert "tmp" in name_to_binding(rwt.mutated_root_fn())
+
+    @tvm.script.ir_module
+    class GroundTruth:
+        @R.function
+        def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+            with R.dataflow():
+                lv0 = x
+                tmp: R.Tensor((32, 32), "float32") = x
+                R.output(lv0)
+            return lv0
+
+    tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"])
+
+
+def test_simple_auto_add_var():
+    root_fn = Identity["main"]
+    dfb = root_fn.body.blocks[0]
+
+    rwt = DataflowBlockRewrite(dfb, root_fn)
+    rwt.add(root_fn.params[0], is_dfvar=False)
+
+    assert isinstance(rwt.mutated_dfb().bindings[-1].var, Var)
+
+    assert_immutability(rwt, dfb, root_fn)
+
+
+def test_simple_auto_add_dfvar():
+    root_fn = Identity["main"]
+    dfb = root_fn.body.blocks[0]
+
+    rwt = DataflowBlockRewrite(dfb, root_fn)
+    rwt.add(root_fn.params[0], is_dfvar=True)
+
+    assert isinstance(rwt.mutated_dfb().bindings[-1].var, DataflowVar)
+
+    # immutability
+    assert_immutability(rwt, dfb, root_fn)
+
+
+def test_simple_remove_unused():
+    @tvm.script.ir_module
+    class IdentityUnused:
+        @R.function
+        def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+            with R.dataflow():
+                lv0 = x
+                unused = lv0
+                R.output(lv0)
+            return lv0
+
+    root_fn = IdentityUnused["main"]
+    dfb = root_fn.body.blocks[0]
+
+    n2binding = name_to_binding(IdentityUnused["main"])
+
+    rwt = DataflowBlockRewrite(dfb, root_fn)
+    rwt.remove_unused(n2binding["unused"][0].var)
+
+    assert_immutability(rwt, dfb, root_fn)
+
+    # check "unused" removed
+    assert "unused" not in name_to_binding(rwt.mutated_root_fn())
+
+    @tvm.script.ir_module
+    class GroundTruth:
+        @R.function
+        def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+            with R.dataflow():
+                lv0 = x
+                R.output(lv0)
+            return lv0
+
+    tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"])
+
+
+def test_remove_unused_undef():
+    root_fn = Identity["main"]
+    dfb = root_fn.body.blocks[0]
+
+    with pytest.raises(TVMError):
+        rwt = DataflowBlockRewrite(dfb, root_fn)
+        rwt.remove_unused(Var("whatever"))
+
+    rwt = DataflowBlockRewrite(dfb, root_fn)
+    rwt.remove_unused(Var("whatever"), allow_undef=True)
+
+    assert root_fn == rwt.mutated_root_fn()
+
+
+def test_simple_rm_all_unused():
+    @tvm.script.ir_module
+    class IdentityUnused:
+        @R.function
+        def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+            with R.dataflow():
+                lv0 = x
+                unused0 = lv0
+                unused1 = lv0
+                R.output(lv0)
+            return lv0
+
+    root_fn = IdentityUnused["main"]
+    dfb = root_fn.body.blocks[0]
+
+    rwt = DataflowBlockRewrite(dfb, root_fn)
+    rwt.remove_all_unused()
+
+    @tvm.script.ir_module
+    class GroundTruth:
+        @R.function
+        def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+            with R.dataflow():
+                lv0 = x
+                R.output(lv0)
+            return lv0
+
+    tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"])
+
+
+@tvm.script.ir_module
+class DeadDFBlock:
+    @R.function
+    def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"):
+        with R.dataflow():
+            lv0 = x
+            R.output(lv0)
+        return x
+
+
+def test_empty_dfb_after_removal():
+    root_fn = DeadDFBlock["main"]
+    dfb = root_fn.body.blocks[0]
+
+    rwt = DataflowBlockRewrite(dfb, root_fn)
+    rwt.remove_unused(DeadDFBlock["main"].body.blocks[0].bindings[0].var)
+
+    @tvm.script.ir_module
+    class GroundTruth:
+        @R.function
+        def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"):
+            return x
+
+    tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"])
+
+
+def test_empty_dfb_after_all_removal():
+    dfb = DeadDFBlock["main"].body.blocks[0]
+    root_fn = DeadDFBlock["main"]
+
+    rwt = DataflowBlockRewrite(dfb, root_fn)
+    rwt.remove_all_unused()
+
+    @tvm.script.ir_module
+    class GroundTruth:
+        @R.function
+        def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"):
+            return x
+
+    tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"])
+
+
+def test_chained_rm_all_unused():
+    @tvm.script.ir_module
+    class IdentityChainedUnused:
+        @R.function
+        def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+            with R.dataflow():
+                lv0 = x
+                unused0 = R.call_tir("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32"))
+                unused1 = R.call_tir("my_sigmoid", (unused0,), R.Tensor((32, 32), dtype="float32"))
+                R.output(lv0)
+            return lv0
+
+    root_fn = IdentityChainedUnused["main"]
+    dfb = root_fn.body.blocks[0]
+
+    rwt = DataflowBlockRewrite(dfb, root_fn)
+    rwt.remove_all_unused()
+
+    @tvm.script.ir_module
+    class GroundTruth:
+        @R.function
+        def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+            with R.dataflow():
+                lv0 = x
+                R.output(lv0)
+            return lv0
+
+    tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"])
+
+
+def test_simple_replace_all_uses():
+    @tvm.script.ir_module
+    class Lv0To1:
+        @R.function
+        def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"):
+            #   lv0 => lv1
+            #  /   \
+            # lv2  lv3
+            #  \   /
+            #   lv4
+            with R.dataflow():
+                lv0: R.Tensor((32, 32), "float32") = R.call_tir(
+                    "my_relu", (x,), R.Tensor((32, 32), dtype="float32")
+                )
+                lv1: R.Tensor((32, 32), "float32") = R.call_tir(
+                    "my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32")
+                )
+                lv2: R.Tensor((32, 32), "float32") = R.call_tir(
+                    "my_add", (x, lv0), R.Tensor((32, 32), dtype="float32")
+                )
+                lv3: R.Tensor((32, 32), "float32") = R.call_tir(
+                    "my_mul", (x, lv0), R.Tensor((32, 32), dtype="float32")
+                )
+                lv4: R.Tensor((32, 32), "float32") = R.call_tir(
+                    "my_whatever", (lv2, lv3), R.Tensor((32, 32), dtype="float32")
+                )
+                R.output(lv4)
+            return lv4
+
+    root_fn = Lv0To1["main"]
+    dfb = root_fn.body.blocks[0]
+
+    n2binding = name_to_binding(root_fn)
+
+    rwt = DataflowBlockRewrite(dfb, root_fn)
+    rwt.replace_all_uses(n2binding["lv0"][0].var, n2binding["lv1"][0].var)
+    rwt.remove_unused(n2binding["lv0"][0].var)
+
+    assert_immutability(rwt, dfb, root_fn)
+
+    n2binding_after = name_to_binding(rwt.mutated_root_fn())
+    assert "lv0" not in n2binding_after
+
+
+def test_simple_module_update():
+    @tvm.script.ir_module
+    class Identity:
+        @R.function
+        def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+            with R.dataflow():
+                lv0 = x
+                R.output(lv0)
+            return lv0
+
+    root_fn = Identity["main"]
+    dfb = root_fn.body.blocks[0]
+
+    rwt = DataflowBlockRewrite(dfb, root_fn)
+    rwt.add(name="tmp", expr=root_fn.params[0], is_dfvar=True)
+
+    new_ir = rwt.mutate_irmodule(Identity)
+
+    # immutability
+    assert new_ir != Identity
+    assert 2 == len(new_ir["main"].body.blocks[0].bindings)
+
+    @tvm.script.ir_module
+    class GroundTruth:
+        @R.function
+        def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+            with R.dataflow():
+                lv0 = x
+                tmp: R.Tensor((32, 32), "float32") = x
+                R.output(lv0)
+            return lv0
+
+    tvm.ir.assert_structural_equal(new_ir, GroundTruth)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()


Reply via email to