This is an automated email from the ASF dual-hosted git repository.

sslyu pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 82586fb9ea [Unity] Extend EliminateCommonSubexpr to operate on relax::Expr (#15815)
82586fb9ea is described below

commit 82586fb9ea3156828d076f3b60d895d297b54729
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Sep 27 13:37:14 2023 -0500

    [Unity] Extend EliminateCommonSubexpr to operate on relax::Expr (#15815)
    
    Prior to this commit, the `EliminateCommonSubexpr` utility could only
    apply to `relax::Function` instances.  This commit extends the allowed
    usage to apply to any `relax::Expr` that contains variable bindings.
    This is only included as an internal utility within the C++
    implementation, and is not currently exposed for external use.
---
 src/relax/transform/eliminate_common_subexpr.cc | 43 +++++++++++++------------
 src/relax/transform/utils.h                     | 14 ++++++++
 2 files changed, 37 insertions(+), 20 deletions(-)

diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc
index 8bbb05f327..fa90d41933 100644
--- a/src/relax/transform/eliminate_common_subexpr.cc
+++ b/src/relax/transform/eliminate_common_subexpr.cc
@@ -28,6 +28,8 @@
 #include <tvm/relax/transform.h>
 #include <tvm/relax/utils.h>
 
+#include "utils.h"
+
 namespace tvm {
 namespace relax {
 
@@ -74,6 +76,12 @@ class ImpurityDetector : public ExprVisitor {
 
 class SubexprCounter : public ExprVisitor {
  public:
+  static std::unordered_map<Expr, int, StructuralHash, StructuralEqual> 
Count(const Expr& expr) {
+    SubexprCounter visitor;
+    visitor(expr);
+    return visitor.count_map_;
+  }
+
   // overriding VisitExpr ensures we do this for every subexpression
   void VisitExpr(const Expr& e) override {
     // Cases we ignore because we will not substitute them:
@@ -106,25 +114,17 @@ class SubexprCounter : public ExprVisitor {
   // we are not going to do replacements inside struct info to avoid binding 
lots of reused shapes
   void VisitExprDepStructInfoField(const StructInfo& struct_info) override {}
 
-  std::unordered_map<Expr, int, StructuralHash, StructuralEqual> Count(const 
Function& func) {
-    VisitExpr(func->body);
-    return count_map_;
-  }
-
  private:
   std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map_;
   ImpurityDetector impurity_detector_;
 };
 
-// forward declaration
-Function EliminateCommonSubexpr(const Function&, bool call_only);
-
 class CommonSubexprEliminator : public ExprMutator {
  public:
   explicit CommonSubexprEliminator(
-      const std::unordered_map<Expr, int, StructuralHash, StructuralEqual>& 
count_map,
+      std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map,
       bool call_only = false)
-      : count_map_(count_map), call_only_(call_only) {}
+      : count_map_(std::move(count_map)), call_only_(call_only) {}
 
   // overriding here ensures we visit every subexpression
   Expr VisitExpr(const Expr& e) override {
@@ -151,9 +151,15 @@ class CommonSubexprEliminator : public ExprMutator {
     return struct_info;
   }
 
-  Expr VisitExpr_(const FunctionNode* func) override {
-    // do full CSE within the function
-    return EliminateCommonSubexpr(GetRef<Function>(func), call_only_);
+  Expr VisitExpr_(const FunctionNode* op) override {
+    Function func = GetRef<Function>(op);
+
+    auto cache = SubexprCounter::Count(op->body);
+    std::swap(cache, count_map_);
+    Expr output = ExprMutator::VisitExpr_(op);
+    std::swap(cache, count_map_);
+
+    return output;
   }
 
   void VisitBinding_(const VarBindingNode* binding) override {
@@ -203,17 +209,14 @@ class CommonSubexprEliminator : public ExprMutator {
     return VisitExpr(bound_value);
   }
 
-  const std::unordered_map<Expr, int, StructuralHash, StructuralEqual>& 
count_map_;
+  std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map_;
   std::unordered_map<Expr, Var, StructuralHash, StructuralEqual> replacements_;
   bool call_only_{false};
 };
 
-Function EliminateCommonSubexpr(const Function& func, bool call_only) {
-  SubexprCounter counter;
-  auto count_map = counter.Count(func);
-  CommonSubexprEliminator eliminator(count_map, call_only);
-  return Function(func->params, eliminator.VisitExpr(func->body), 
func->ret_struct_info,
-                  func->is_pure, func->attrs, func->span);
+Expr EliminateCommonSubexpr(const Expr& expr, bool call_only) {
+  CommonSubexprEliminator mutator(SubexprCounter::Count(expr), call_only);
+  return mutator(expr);
 }
 
 namespace transform {
diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h
index a51c71d788..6e44f07aa6 100644
--- a/src/relax/transform/utils.h
+++ b/src/relax/transform/utils.h
@@ -374,6 +374,20 @@ inline String GetCodegenName(const std::string& composite_name) {
   return composite_name.substr(0, delim_pos);
 }
 
+/*! \brief Eliminate common subexpressions
+ *
+ * Utility for simplifying relax expressions by removing common
+ * subexpressions.
+ *
+ * \param expr The expression to be updated
+ *
+ * \param call_only If true, only eliminate relax::Call nodes.  If
+ * false, eliminate any common subexpressions.
+ *
+ * \return The updated expression
+ */
+Expr EliminateCommonSubexpr(const Expr& expr, bool call_only = false);
+
 }  // namespace relax
 }  // namespace tvm
 

Reply via email to