This is an automated email from the ASF dual-hosted git repository.
sslyu pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 233ac7aee0 [Unity] Avoid trivial var-to-var bindings in CanonicalizeBindings (#15840)
233ac7aee0 is described below
commit 233ac7aee09e65e42c18596b9877de8c0476b148
Author: Eric Lunderberg <[email protected]>
AuthorDate: Mon Oct 2 17:09:05 2023 -0500
[Unity] Avoid trivial var-to-var bindings in CanonicalizeBindings (#15840)
Prior to this commit, the `relax.transform.CanonicalizeBindings`
transform would detect trivial bindings `var_y = var_x`, and replace
later usage of `var_y` with `var_x`. However, the trivial binding
`var_y = var_x` would be left in the canonicalized function.
This commit updates the `CanonicalizeBindings` transform to remove
trivial bindings. This is not intended as a full dead-code
elimination, as that is better handled as a separate pass, but is
instead intended to avoid introduction of dead code during
canonicalization.
---
src/relax/transform/canonicalize_bindings.cc | 49 ++++++++++------------
.../relax/test_transform_canonicalize_bindings.py | 41 ++----------------
2 files changed, 26 insertions(+), 64 deletions(-)
diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc
index ea5a612e1a..d8e3a9ba98 100644
--- a/src/relax/transform/canonicalize_bindings.cc
+++ b/src/relax/transform/canonicalize_bindings.cc
@@ -38,16 +38,6 @@ class BindingCanonicalizer : public ExprMutator {
using ExprMutator::VisitExpr_;
- Expr VisitExpr_(const VarNode* op) override {
- // remap first
- Var v = Downcast<Var>(ExprMutator::VisitExpr_(op));
- if (!CanCanonicalizeVar(v)) {
- return Downcast<Expr>(v);
- }
- // visit again in case we need to do a substitution in the value
- return ExprMutator::VisitExpr_(LookupBinding(v).as<VarNode>());
- }
-
Expr VisitExpr_(const TupleGetItemNode* tuple_get_item) override {
if (auto tuple_var = tuple_get_item->tuple.as<Var>()) {
if (auto tuple_value = LookupBinding(tuple_var.value())) {
@@ -71,12 +61,14 @@ class BindingCanonicalizer : public ExprMutator {
Expr new_value = this->VisitExpr(binding->value);
Var new_var = this->VisitVarDef(binding->var);
- if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) {
+ if (auto opt_var = new_value.as<Var>();
+ opt_var && CanCanonicalizeVar(new_var, opt_var.value())) {
+ var_remap_[new_var->vid] = opt_var.value();
+ } else if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) {
this->builder_->EmitNormalized(GetRef<VarBinding>(binding));
- return;
+ } else {
+ this->builder_->EmitNormalized(VarBinding(new_var, new_value));
}
-
- this->builder_->EmitNormalized(VarBinding(new_var, new_value));
}
void VisitBinding_(const MatchCastNode* binding) override {
@@ -84,9 +76,19 @@ class BindingCanonicalizer : public ExprMutator {
// we can canonicalize to a var binding
Expr new_value = this->VisitExpr(binding->value);
- // if the LHS and RHS have the same struct info, we canonicalize to a var binding instead
- if (StructuralEqual()(binding->struct_info, GetStructInfo(new_value))) {
- builder_->EmitNormalized(VarBinding(binding->var, new_value));
+ bool has_same_struct_info = StructuralEqual()(binding->struct_info, GetStructInfo(new_value));
+
+ if (has_same_struct_info) {
+ if (auto parent = new_value.as<Var>();
+ parent && CanCanonicalizeVar(binding->var, parent.value())) {
+ // LHS and RHS have the same struct info, and occur in a
+ // context where the RHS can replace the LHS.
+ var_remap_[binding->var->vid] = parent.value();
+ } else {
+ // LHS and RHS have the same struct info, but the RHS is not a
+ // drop-in replacement for the LHS.
+ builder_->EmitNormalized(VarBinding(binding->var, new_value));
+ }
} else if (new_value.same_as(binding->value)) {
builder_->EmitNormalized(GetRef<MatchCast>(binding));
} else {
@@ -104,24 +106,17 @@ class BindingCanonicalizer : public ExprMutator {
return !(both_present || neither_present) || (both_present && !check_eq(obj1, obj2));
}
- bool CanCanonicalizeVar(Var v) {
- Optional<Expr> value = LookupBinding(v);
- // can replace only if the value is also a var
- if (!value || !value.as<VarNode>()) {
- return false;
- }
- Var parent_var = Downcast<Var>(value);
-
+ bool CanCanonicalizeVar(Var var, Var parent_var) {
// Cases when we conservatively do not unify:
// 1. checked_type_ or shape_ of the child differs from that of the parent
// In this case, we could be overriding user annotations.
// 2. If the child is a Var and the parent is a DataflowVar.
// That could result in a DataflowVar leaving the current DataflowBlock.
- bool annotations_differ = AnnotationsDiffer(v->struct_info_, parent_var->struct_info_,
+ bool annotations_differ = AnnotationsDiffer(var->struct_info_, parent_var->struct_info_,
[&](const ObjectRef& lhs, const ObjectRef& rhs) {
return tvm::StructuralEqual()(lhs, rhs);
});
- bool var_to_dataflow = (!v.as<DataflowVarNode>() && parent_var.as<DataflowVarNode>());
+ bool var_to_dataflow = (!var.as<DataflowVarNode>() && parent_var.as<DataflowVarNode>());
return !annotations_differ && !var_to_dataflow;
}
};
diff --git a/tests/python/relax/test_transform_canonicalize_bindings.py b/tests/python/relax/test_transform_canonicalize_bindings.py
index 91396ccb13..52bf5a6e43 100644
--- a/tests/python/relax/test_transform_canonicalize_bindings.py
+++ b/tests/python/relax/test_transform_canonicalize_bindings.py
@@ -36,17 +36,10 @@ def test_simple_assignments():
o = p
return o
- # a little annoying to have these unused bindings around
- # but they can be eliminated in a separate pass
@tvm.script.ir_module
class Expected:
@R.function
def main(x: R.Tensor):
- y = x
- z = x
- q = x
- p = x
- o = x
return x
new_mod = relax.transform.CanonicalizeBindings()(TestChainAssignments)
@@ -68,19 +61,16 @@ def test_dataflow_block():
R.output(n)
return n
- # a little annoying to have these unused bindings around
- # but they can be eliminated in a separate pass
@tvm.script.ir_module
class Expected:
@R.function
def main(x: R.Tensor):
with R.dataflow():
y = R.const(1)
- z = y
- o = y
- p = y
- m = y
- # we can't get rid of n because it leaves the block
+ # We can't get rid of n because it leaves the block.
+ # CanonicalizeBindings does not do a full dead-code
+ # elimination, and only does local analysis of trivial
+ # bindings that it may produce.
n = y
R.output(n)
return n
@@ -108,15 +98,6 @@ def test_assign_to_output_indataflow_block():
class Expected:
@R.function
def main(x: R.Tensor):
- with R.dataflow():
- y = x
- z = x
- o = x
- p = x
- m = x
- # we can't get rid of n because it leaves the block
- n = x
- R.output(n)
return x
new_mod = relax.transform.CanonicalizeBindings()(TestDataflowAssignments)
@@ -137,8 +118,6 @@ def test_ops():
class Expected:
@R.function
def main(x: R.Tensor, y: R.Tensor):
- w = y
- q = x
z = R.add(y, x)
return R.add(x, z)
@@ -161,7 +140,6 @@ def test_casting():
class Expected:
@R.function
def main(x: R.Tensor) -> R.Object:
- y = x
# Cannot unify because the cast indicates user intent
z: R.Object = x
return z
@@ -185,11 +163,9 @@ def test_match_cast():
class Expected:
@R.function
def main(x: R.Tensor):
- q = x
# can't get rid of z because its shape_ is different from x's
m, n = T.int64(), T.int64()
z = R.match_cast(x, R.Tensor((m, n)))
- w = z
return z
new_mod = relax.transform.CanonicalizeBindings()(TestMatchCast)
@@ -213,11 +189,6 @@ def test_same_shape():
class Expected:
@R.function
def main(x: R.Tensor(("m", "n"), "float32")):
- m, n = T.int64(), T.int64()
- y = x
- # canonicalized into a var binding
- z = x
- w = x
q = R.add(x, x)
return R.add(q, x)
@@ -242,10 +213,8 @@ def test_change_shape():
class Expected:
@R.function
def main(x: R.Tensor(("m", "n"))):
- y = x
o, p = T.int64(), T.int64()
z = R.match_cast(x, R.Tensor((o, p)))
- w = z
# the shape_ field on q will need to be updated
q = R.add(z, x)
return R.add(q, z)
@@ -270,8 +239,6 @@ def test_unwrap_tuple():
@R.function
def main(x: R.Tensor, y: R.Tensor):
tuple_var = (x, y)
- w = x
- q = y
z = R.add(x, y)
return R.add(y, z)