Lunderberg commented on code in PR #16599:
URL: https://github.com/apache/tvm/pull/16599#discussion_r1499538408
##########
src/relax/transform/eliminate_common_subexpr.cc:
##########
@@ -20,223 +20,162 @@
/*!
* \file tvm/relax/transform/eliminate_common_subexpr.cc
- * \brief Eliminrate common subexpression pass.
+ * \brief Eliminate common subexpression pass.
*
* Currently it removes common subexpressions within a Function.
*/
+#include <tvm/relax/analysis.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/transform.h>
#include <tvm/relax/utils.h>
-#include "utils.h"
+#include "../../support/utils.h"
namespace tvm {
namespace relax {
-
-// Checks if a given expression contains an impure subexpression
-// Caches the results of checks to avoid revisiting subexpressions
-class ImpurityDetector : public ExprVisitor {
- public:
- bool Detect(const Expr& expr) {
- impure_found_ = false;
- VisitExpr(expr);
- return impure_found_;
+namespace {
+/* \brief Lookup key for subexpression replacements
+ *
+ * The lookup key must contain the expression being bound, along with
+ * the struct info used for a match cast, if applicable. Using
+ * `MatchCast` with StructuralEqual and StructuralHash would be almost
+ * correct, but acts as a point of definition for symbolic variables
+ * within the output struct info. As a result, it would erroneously
+ * de-duplicate `R.match_cast(A, R.Tensor([m,n]))` and
+ * `R.match_cast(A, R.Tensor([p,q]))`, even though they define
+ * different symbolic variables.
+ */
+struct ReplacementKey {
+ tvm::relax::Expr bound_value;
+ tvm::Optional<tvm::relax::StructInfo> match_cast = tvm::NullOpt;
+
+ explicit ReplacementKey(const tvm::relax::Binding& binding)
+ : bound_value(GetBoundValue(binding)) {
+ if (const auto* ptr = binding.as<tvm::relax::MatchCastNode>()) {
+ match_cast = ptr->struct_info;
+ }
}
- void VisitExpr(const Expr& expr) {
- // already checked: do not revisit
- if (purity_map_.count(expr)) {
- impure_found_ = impure_found_ || !purity_map_.at(expr);
- return;
- }
+ friend bool operator==(const ReplacementKey& a, const ReplacementKey& b) {
+ tvm::StructuralEqual eq;
+ return eq(a.bound_value, b.bound_value) && eq(a.match_cast, b.match_cast);
+ }
+};
- // in principle, we could stop checking once we find an impurity,
- // but not doing so lets us fully populate the cache
+} // namespace
+} // namespace relax
+} // namespace tvm
- // store the previous state so we could assess the purity of this
subexpression alone
- bool prev_state = impure_found_;
- impure_found_ = false;
- ExprVisitor::VisitExpr(expr);
- // if impure_found_ remains false, then the expression is pure
- purity_map_[expr] = !impure_found_;
- impure_found_ = prev_state || impure_found_;
+/* \brief Definition of std::hash<ReplacementKey>
+ *
+ * Specialization of std::hash must occur outside of tvm::relax
+ * namespace, and before its usage in the constructor of
+ * `CommonSubexprEliminator`.
+ */
+template <>
+struct std::hash<tvm::relax::ReplacementKey> {
+ std::size_t operator()(const tvm::relax::ReplacementKey& key) const {
+ tvm::StructuralHash hasher;
+ return tvm::support::HashCombine(hasher(key.bound_value),
hasher(key.match_cast));
}
+};
- void VisitExpr_(const CallNode* call) {
- // the only possible impurities can come from call nodes
- bool is_impure = IsImpureCall(GetRef<Call>(call));
- impure_found_ = impure_found_ || is_impure;
- ExprVisitor::VisitExpr_(call);
- }
+namespace tvm {
+namespace relax {
- private:
- bool impure_found_ = false;
- std::unordered_map<Expr, bool, StructuralHash, StructuralEqual> purity_map_;
-};
+namespace {
-class SubexprCounter : public ExprVisitor {
+class CommonSubexprEliminator : public ExprMutator {
public:
- static std::unordered_map<Expr, int, StructuralHash, StructuralEqual>
Count(const Expr& expr) {
- SubexprCounter visitor;
- visitor(expr);
- return visitor.count_map_;
+ explicit CommonSubexprEliminator(bool call_only = false) :
call_only_(call_only) {}
+
+ BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) override {
+ auto cache_exprs = expr_replacements_;
+ auto cache_vars = var_remap_;
+ auto output = ExprMutator::VisitBindingBlock_(block);
+ expr_replacements_ = cache_exprs;
+ var_remap_ = cache_vars;
+ return output;
}
- // overriding VisitExpr ensures we do this for every subexpression
- void VisitExpr(const Expr& e) override {
- // Cases we ignore because we will not substitute them:
- // 1. Vars of all kinds
- // 2. Op nodes (nothing we can do)
- // 3. PrimValue nodes (not much benefit from binding to a var)
- // 4. StringImm nodes (not much benefit from binding to a var)
- // 5. Scalar constants (not much benefit from binding to a var)
- // 6. Shape expressions (exist to hold several PrimValue objects)
- // 7. DataType nodes (no need to modify dtype nodes)
- if (!(e->IsInstance<VarNode>() || e->IsInstance<DataflowVarNode>() ||
- e->IsInstance<GlobalVarNode>() || e->IsInstance<tvm::OpNode>() ||
- e->IsInstance<PrimValueNode>() || e->IsInstance<StringImmNode>() ||
- e->IsInstance<ShapeExprNode>() || e->IsInstance<ExternFuncNode>() ||
- e->IsInstance<ConstantNode>() || e->IsInstance<DataTypeImmNode>())) {
- // also if e has an impure subexpression, we will not deduplicate it
- if (!impurity_detector_.Detect(e)) {
- int count = 0;
- if (count_map_.count(e)) {
- count = count_map_.at(e);
- }
- count_map_[e] = count + 1;
+ void VisitBinding(const Binding& binding) override {
+ Expr bound_value = VisitExpr(GetBoundValue(binding));
+
+ Binding output_binding = [&]() -> Binding {
+ if (binding.as<VarBindingNode>()) {
+ return VarBinding(binding->var, bound_value);
+ } else if (auto match_cast = binding.as<MatchCastNode>()) {
+ return MatchCast(binding->var, bound_value, match_cast->struct_info);
+ } else {
+ LOG(FATAL) << "Binding must be either VarBinding or MatchCast, "
+ << "but was " << binding->GetTypeKey();
}
- }
+ }();
- // Only visit the interior of objects that we might still keep
- // around. Otherwise, double-counting these would lead to extra
- // variable bindings.
- //
- // Before:
- // y = f(a+b)
- // z = f(a+b)
- //
- // Expected:
- // y = f(a+b) // De-duped from (y==z)
- // z = y
- //
- // Erroneous output:
- // c = a+b // Incorrect, a+b only has a single usage.
- // y = f(c) // De-duped from
- // z = y
- //
- if (auto it = count_map_.find(e); it == count_map_.end() || it->second <
2) {
- ExprVisitor::VisitExpr(e);
- }
- }
+ ReplacementKey lookup_key(output_binding);
- // do not visit inner functions: we will do CSE within those
- void VisitExpr_(const FunctionNode* func) override {}
+ if (call_only_ && !bound_value->IsInstance<relax::CallNode>()) {
+ VLOG(1) << "Since call_only_ is true, it is forbidden to de-duplicate "
<< bound_value;
- // we are not going to do replacements inside struct info to avoid binding
lots of reused shapes
- void VisitExprDepStructInfoField(const StructInfo& struct_info) override {}
+ } else if (ContainsImpureCall(bound_value)) {
+ VLOG(1) << "Since the expression is impure, cannot de-duplicate " <<
bound_value;
- private:
- std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map_;
- ImpurityDetector impurity_detector_;
-};
+ } else if (auto it = expr_replacements_.find(lookup_key); it !=
expr_replacements_.end()) {
+ VLOG(1) << "Value " << bound_value << " has previously been bound as "
<< it->second
+ << ". The duplicate binding of this value to " << binding->var
+ << " will be replaced with a trivial binding, "
+ << "and occurrences of " << binding->var << " will be replaced
with " << it->second;
+ output_binding = VarBinding(binding->var, it->second);
Review Comment:
This is to handle cases where the variable remapping would not be allowed,
but the variable binding is still valid. If the de-duplication results in
remapping a non-dataflow variable to a dataflow variable, that remapping is
only valid within the dataflow block. Outside of that dataflow block, the
variable remaps are reset, and any remaining usages would still need a
definition.
I've added the following unit test to test this behavior.
```python
def test_do_not_eliminate_impure():
@I.ir_module
class Before:
@R.function(pure=False)
def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3),
dtype="float32")):
# it's a repeated subexpression but it would be wrong to
deduplicate it
p1 = R.print(format="Message")
p2 = R.print(format="Message")
a1 = R.assert_op(R.const(False), format="Always fails")
lv0 = R.add(x, y)
lv1 = R.add(x, y)
gv = R.multiply(lv0, lv1)
a2 = R.assert_op(R.const(False), format="Always fails")
return gv
@I.ir_module
class Expected:
@R.function(pure=False)
def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3),
dtype="float32")):
p1 = R.print(format="Message")
p2 = R.print(format="Message")
a1 = R.assert_op(R.const(False), format="Always fails")
lv0 = R.add(x, y)
lv1 = lv0
gv = R.multiply(lv0, lv0)
a2 = R.assert_op(R.const(False), format="Always fails")
return gv
verify(Before, Expected)
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]