This is an automated email from the ASF dual-hosted git repository.
sslyu pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 82586fb9ea [Unity] Extend EliminateCommonSubexpr to operate on
relax::Expr (#15815)
82586fb9ea is described below
commit 82586fb9ea3156828d076f3b60d895d297b54729
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Sep 27 13:37:14 2023 -0500
[Unity] Extend EliminateCommonSubexpr to operate on relax::Expr (#15815)
Prior to this commit, the `EliminateCommonSubexpr` utility could only
be applied to `relax::Function` instances. This commit extends it to
accept any `relax::Expr` that contains variable bindings.
The extended utility is only available internally within the C++
implementation, and is not currently exposed for external use.
---
src/relax/transform/eliminate_common_subexpr.cc | 43 +++++++++++++------------
src/relax/transform/utils.h | 14 ++++++++
2 files changed, 37 insertions(+), 20 deletions(-)
diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc
index 8bbb05f327..fa90d41933 100644
--- a/src/relax/transform/eliminate_common_subexpr.cc
+++ b/src/relax/transform/eliminate_common_subexpr.cc
@@ -28,6 +28,8 @@
#include <tvm/relax/transform.h>
#include <tvm/relax/utils.h>
+#include "utils.h"
+
namespace tvm {
namespace relax {
@@ -74,6 +76,12 @@ class ImpurityDetector : public ExprVisitor {
class SubexprCounter : public ExprVisitor {
public:
+ static std::unordered_map<Expr, int, StructuralHash, StructuralEqual> Count(const Expr& expr) {
+ SubexprCounter visitor;
+ visitor(expr);
+ return visitor.count_map_;
+ }
+
// overriding VisitExpr ensures we do this for every subexpression
void VisitExpr(const Expr& e) override {
// Cases we ignore because we will not substitute them:
@@ -106,25 +114,17 @@ class SubexprCounter : public ExprVisitor {
// we are not going to do replacements inside struct info to avoid binding lots of reused shapes
void VisitExprDepStructInfoField(const StructInfo& struct_info) override {}
- std::unordered_map<Expr, int, StructuralHash, StructuralEqual> Count(const Function& func) {
- VisitExpr(func->body);
- return count_map_;
- }
-
private:
std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map_;
ImpurityDetector impurity_detector_;
};
-// forward declaration
-Function EliminateCommonSubexpr(const Function&, bool call_only);
-
class CommonSubexprEliminator : public ExprMutator {
public:
explicit CommonSubexprEliminator(
- const std::unordered_map<Expr, int, StructuralHash, StructuralEqual>& count_map,
+ std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map,
bool call_only = false)
- : count_map_(count_map), call_only_(call_only) {}
+ : count_map_(std::move(count_map)), call_only_(call_only) {}
// overriding here ensures we visit every subexpression
Expr VisitExpr(const Expr& e) override {
@@ -151,9 +151,15 @@ class CommonSubexprEliminator : public ExprMutator {
return struct_info;
}
- Expr VisitExpr_(const FunctionNode* func) override {
- // do full CSE within the function
- return EliminateCommonSubexpr(GetRef<Function>(func), call_only_);
+ Expr VisitExpr_(const FunctionNode* op) override {
+ Function func = GetRef<Function>(op);
+
+ auto cache = SubexprCounter::Count(op->body);
+ std::swap(cache, count_map_);
+ Expr output = ExprMutator::VisitExpr_(op);
+ std::swap(cache, count_map_);
+
+ return output;
}
void VisitBinding_(const VarBindingNode* binding) override {
@@ -203,17 +209,14 @@ class CommonSubexprEliminator : public ExprMutator {
return VisitExpr(bound_value);
}
- const std::unordered_map<Expr, int, StructuralHash, StructuralEqual>& count_map_;
+ std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map_;
std::unordered_map<Expr, Var, StructuralHash, StructuralEqual> replacements_;
bool call_only_{false};
};
-Function EliminateCommonSubexpr(const Function& func, bool call_only) {
- SubexprCounter counter;
- auto count_map = counter.Count(func);
- CommonSubexprEliminator eliminator(count_map, call_only);
- return Function(func->params, eliminator.VisitExpr(func->body), func->ret_struct_info,
- func->is_pure, func->attrs, func->span);
+Expr EliminateCommonSubexpr(const Expr& expr, bool call_only) {
+ CommonSubexprEliminator mutator(SubexprCounter::Count(expr), call_only);
+ return mutator(expr);
}
namespace transform {
diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h
index a51c71d788..6e44f07aa6 100644
--- a/src/relax/transform/utils.h
+++ b/src/relax/transform/utils.h
@@ -374,6 +374,20 @@ inline String GetCodegenName(const std::string& composite_name) {
return composite_name.substr(0, delim_pos);
}
+/*! \brief Eliminate common subexpressions
+ *
+ * Utility for simplifying relax expressions by removing common
+ * subexpressions.
+ *
+ * \param expr The expression to be updated
+ *
+ * \param call_only If true, only eliminate relax::Call nodes. If
+ * false, eliminate any common subexpressions.
+ *
+ * \return The updated expression
+ */
+Expr EliminateCommonSubexpr(const Expr& expr, bool call_only = false);
+
} // namespace relax
} // namespace tvm