ganler commented on code in PR #14262:
URL: https://github.com/apache/tvm/pull/14262#discussion_r1141399888


##########
src/relax/transform/dead_code_elimination.cc:
##########
@@ -105,16 +104,82 @@ IRModule RemoveUnusedFunctions(IRModule mod_, 
Array<runtime::String> entry_funcs
   return mod_;
 }
 
-}  // namespace relax
+class DeadCodeEliminator : public ExprMutator {
+ private:
+  Expr VisitExpr_(const VarNode* op) final {
+    ICHECK(!used_vars_.empty());
+    used_vars_.back().insert(GetRef<Var>(op));
+    return GetRef<Expr>(op);
+  }
+
+  Expr VisitExpr_(const DataflowVarNode* op) final {
+    ICHECK(!used_vars_.empty());
+    used_vars_.back().insert(GetRef<Var>(op));
+    return GetRef<Expr>(op);
+  }
+
+  void VisitBinding_(const VarBindingNode* binding) { 
this->VisitExpr(binding->value); }
+
+  void VisitBinding_(const MatchCastNode* binding) {
+    this->VisitExpr(binding->value);
+    this->VisitAndCheckStructInfoFieldUnchanged(binding->struct_info);
+  }
+
+  BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) final {
+    // reverse scan the data flow block to find the used vars
+    used_vars_.push_back({});
+
+    std::vector<Binding> new_bindings;
+    for (auto rit = block->bindings.rbegin(); rit != block->bindings.rend(); 
rit++) {
+      const Var& var = (*rit)->var;
+      // only keep the used bindings or non dataflow var bindings
+      if (used_vars_.back().count(var) || !var->IsInstance<DataflowVarNode>()) 
{
+        new_bindings.push_back(*rit);
+        // collect the used vars
+        this->VisitBinding((*rit));
+      }
+    }
+
+    used_vars_.pop_back();
+    // reverse the bindings
+    std::reverse(new_bindings.begin(), new_bindings.end());
+    if (new_bindings.size() == block->bindings.size()) {
+      return GetRef<BindingBlock>(block);
+    } else {
+      auto n = make_object<DataflowBlockNode>(*block);
+      n->bindings = std::move(new_bindings);
+      return BindingBlock(n);
+    }
+  }
+
+  BindingBlock VisitBindingBlock_(const BindingBlockNode* block) final {
+    return GetRef<BindingBlock>(block);
+  }
+
+  std::vector<std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>> 
used_vars_{{}};
+};
+
+IRModule DeadCodeElimination(const IRModule& mod, Array<runtime::String> 
entry_functions) {
+  DeadCodeEliminator eliminator;
+  for (const auto& gv : mod->GetGlobalVars()) {
+    auto func = mod->Lookup(gv);
+    if (func->IsInstance<FunctionNode>()) {
+      mod->Update(gv, Downcast<Function>(eliminator.VisitExpr(func)));
+    }
+  }
+  return RemoveUnusedFunctions(mod, entry_functions);

Review Comment:
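   Maybe do `dce_fn` -> `dce_var` -> `dce_fn`? That could help when an unused
function has a very complicated structure -- the current code still scans every
function (including unused ones) before any unused function is removed. The
trailing `dce_fn` would still be needed, since removing dead bindings can make
more functions unused.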



##########
src/relax/transform/dead_code_elimination.cc:
##########
@@ -105,16 +104,82 @@ IRModule RemoveUnusedFunctions(IRModule mod_, 
Array<runtime::String> entry_funcs
   return mod_;
 }
 
-}  // namespace relax
+class DeadCodeEliminator : public ExprMutator {
+ private:
+  Expr VisitExpr_(const VarNode* op) final {
+    ICHECK(!used_vars_.empty());
+    used_vars_.back().insert(GetRef<Var>(op));
+    return GetRef<Expr>(op);
+  }
+
+  Expr VisitExpr_(const DataflowVarNode* op) final {
+    ICHECK(!used_vars_.empty());
+    used_vars_.back().insert(GetRef<Var>(op));
+    return GetRef<Expr>(op);
+  }
+
+  void VisitBinding_(const VarBindingNode* binding) { 
this->VisitExpr(binding->value); }
+
+  void VisitBinding_(const MatchCastNode* binding) {
+    this->VisitExpr(binding->value);
+    this->VisitAndCheckStructInfoFieldUnchanged(binding->struct_info);
+  }
+
+  BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) final {
+    // reverse scan the data flow block to find the used vars
+    used_vars_.push_back({});
+
+    std::vector<Binding> new_bindings;
+    for (auto rit = block->bindings.rbegin(); rit != block->bindings.rend(); 
rit++) {
+      const Var& var = (*rit)->var;
+      // only keep the used bindings or non dataflow var bindings
+      if (used_vars_.back().count(var) || !var->IsInstance<DataflowVarNode>()) 
{
+        new_bindings.push_back(*rit);
+        // collect the used vars
+        this->VisitBinding((*rit));
+      }
+    }
+
+    used_vars_.pop_back();
+    // reverse the bindings
+    std::reverse(new_bindings.begin(), new_bindings.end());
+    if (new_bindings.size() == block->bindings.size()) {
+      return GetRef<BindingBlock>(block);
+    } else {
+      auto n = make_object<DataflowBlockNode>(*block);
+      n->bindings = std::move(new_bindings);
+      return BindingBlock(n);
+    }
+  }
+
+  BindingBlock VisitBindingBlock_(const BindingBlockNode* block) final {
+    return GetRef<BindingBlock>(block);
+  }
+
+  std::vector<std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>> 
used_vars_{{}};

Review Comment:
   nit: consider `std::stack` here, since `used_vars_` is only used as a stack
(`push_back` / `back` / `pop_back`).



##########
src/relax/transform/dead_code_elimination.cc:
##########
@@ -105,16 +104,82 @@ IRModule RemoveUnusedFunctions(IRModule mod_, 
Array<runtime::String> entry_funcs
   return mod_;
 }
 
-}  // namespace relax
+class DeadCodeEliminator : public ExprMutator {
+ private:
+  Expr VisitExpr_(const VarNode* op) final {
+    ICHECK(!used_vars_.empty());
+    used_vars_.back().insert(GetRef<Var>(op));
+    return GetRef<Expr>(op);
+  }
+
+  Expr VisitExpr_(const DataflowVarNode* op) final {
+    ICHECK(!used_vars_.empty());
+    used_vars_.back().insert(GetRef<Var>(op));
+    return GetRef<Expr>(op);
+  }
+
+  void VisitBinding_(const VarBindingNode* binding) { 
this->VisitExpr(binding->value); }
+
+  void VisitBinding_(const MatchCastNode* binding) {
+    this->VisitExpr(binding->value);
+    this->VisitAndCheckStructInfoFieldUnchanged(binding->struct_info);
+  }
+
+  BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) final {
+    // reverse scan the data flow block to find the used vars
+    used_vars_.push_back({});
+
+    std::vector<Binding> new_bindings;
+    for (auto rit = block->bindings.rbegin(); rit != block->bindings.rend(); 
rit++) {
+      const Var& var = (*rit)->var;
+      // only keep the used bindings or non dataflow var bindings
+      if (used_vars_.back().count(var) || !var->IsInstance<DataflowVarNode>()) 
{
+        new_bindings.push_back(*rit);
+        // collect the used vars
+        this->VisitBinding((*rit));
+      }
+    }
+
+    used_vars_.pop_back();
+    // reverse the bindings
+    std::reverse(new_bindings.begin(), new_bindings.end());
+    if (new_bindings.size() == block->bindings.size()) {
+      return GetRef<BindingBlock>(block);
+    } else {
+      auto n = make_object<DataflowBlockNode>(*block);
+      n->bindings = std::move(new_bindings);
+      return BindingBlock(n);
+    }
+  }
+
+  BindingBlock VisitBindingBlock_(const BindingBlockNode* block) final {
+    return GetRef<BindingBlock>(block);
+  }
+
+  std::vector<std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>> 
used_vars_{{}};

Review Comment:
   Curious why it is initialized with a vector containing one empty set (i.e.,
`{{}}.size() == 1`)? If that is intended, then since `used_vars_` behaves as a
stack and every `pop_back` post-dominates its matching `push_back`,
`ICHECK(!used_vars_.empty());` can never fail, right?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to