This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 72c38ab32f [Unity][Pass] Canonicalize Bindings (#14079)
72c38ab32f is described below

commit 72c38ab32f91e4487f1e38e88546c3ab96bc6c2c
Author: Yuchen Jin <[email protected]>
AuthorDate: Tue Feb 21 21:31:52 2023 -0800

    [Unity][Pass] Canonicalize Bindings (#14079)
    
    It may be useful for some passes to collapse chains of definitions, 
particularly after other compiler transformations that may reduce or simplify 
some expressions.
    
    This pass will take chains of definitions and replace references to later 
definitions with the original one. It works by checking `LookupBinding` for each 
var use-site and replacing the var with its definition if the definition was 
another var. Additionally, `MatchCast` bindings where the LHS and the RHS are 
guaranteed to match at compile time are canonicalized into ordinary 
`VarBinding`s.
    
    Example:
    ```python
    y = x
    z = y
    w = z
    o = w
    p = o
    ```
    Will be replaced with
    ```python
    y = x
    z = x
    w = x
    o = x
    p = x
    ```
    
    Original PR: https://github.com/tlc-pack/relax/pull/233
    
    Co-authored-by: Steven S. Lyubomirsky <[email protected]>
---
 include/tvm/relax/transform.h                      |   9 +
 python/tvm/relax/transform/transform.py            |  14 ++
 src/relax/transform/canonicalize_bindings.cc       | 135 +++++++++++++
 .../relax/test_transform_canonicalize_bindings.py  | 224 +++++++++++++++++++++
 4 files changed, 382 insertions(+)

diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 7d6c93bcde..b42fb5864e 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -127,6 +127,15 @@ TVM_DLL Pass AttachGlobalSymbol();
  */
 TVM_DLL Pass Normalize();
 
+/*!
+ * \brief Simplify a Relax module by folding var bindings and match cast nodes.
+ * May include other forms of expression simplification in the future.
+ * Best used alongside constant folding and eliminating unused bindings.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass CanonicalizeBindings();
+
 /*!
  * \brief Bind params of function of the module to constant tensors.
  *
diff --git a/python/tvm/relax/transform/transform.py 
b/python/tvm/relax/transform/transform.py
index 9fb2458dc0..c72d053290 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -80,6 +80,20 @@ def Normalize() -> tvm.ir.transform.Pass:
     return _ffi_api.Normalize()  # type: ignore
 
 
+def CanonicalizeBindings() -> tvm.ir.transform.Pass:
+    """
+    Canonicalizes variable definitions
+    (e.g., if there is y = x and z = y, it replaces uses of y and z with x).
+
+    Best combined with constant folding and the elimination of unused 
definitions.
+
+    Returns
+    -------
+    ret: tvm.ir.transform.Pass
+    """
+    return _ffi_api.CanonicalizeBindings()  # type: ignore
+
+
 def RewriteDataflowReshape() -> tvm.ir.transform.Pass:
     """Convert all reshape-like call_tir to VM reshape operator call.
     The VM reshape operator calls will be further lowered to a CreateView
diff --git a/src/relax/transform/canonicalize_bindings.cc 
b/src/relax/transform/canonicalize_bindings.cc
new file mode 100644
index 0000000000..962f76a376
--- /dev/null
+++ b/src/relax/transform/canonicalize_bindings.cc
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/transform/canonicalize_bindings.cc
+ * \brief Pass for simplifying modules by folding var bindings and match cast nodes.
+ *        May include other forms of simplification in the future.
+ *        Best used alongside constant folding and eliminating unused bindings.
+ */
+
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/relax/transform.h>
+
+namespace tvm {
+namespace relax {
+
+class BindingCanonicalizer : public ExprMutator {
+ public:
+  BindingCanonicalizer() {}
+
+  Expr VisitExpr_(const VarNode* op) override {
+    // remap first
+    Var v = Downcast<Var>(ExprMutator::VisitExpr_(op));
+    if (!CanCanonicalizeVar(v)) {
+      return Downcast<Expr>(v);
+    }
+    // visit again in case we need to do a substitution in the value
+    return ExprMutator::VisitExpr_(LookupBinding(v).as<VarNode>());
+  }
+
+  Expr VisitExpr_(const DataflowVarNode* op) override {
+    Var v = Downcast<Var>(ExprMutator::VisitExpr_(op));
+    if (!CanCanonicalizeVar(v)) {
+      return Downcast<Expr>(v);
+    }
+    return ExprMutator::VisitExpr_(LookupBinding(v).as<DataflowVarNode>());
+  }
+
+  void VisitBinding_(const VarBindingNode* binding) override {
+    // Unlike default visitor, we do not permit the checked type to change
+    // if the new value's checked type is different (this preserves user 
annotations)
+    Expr new_value = this->VisitExpr(binding->value);
+    Var new_var = this->VisitVarDef(binding->var);
+
+    if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) {
+      this->builder_->EmitNormalized(GetRef<VarBinding>(binding));
+      return;
+    }
+
+    this->builder_->EmitNormalized(VarBinding(new_var, new_value));
+  }
+
+  void VisitBinding_(const MatchCastNode* binding) override {
+    // If we have a trivial shape check (the shape_ of LHS and RHS is the 
same),
+    // we can canonicalize to a var binding
+    Expr new_value = this->VisitExpr(binding->value);
+
+    // if the LHS and RHS have the same struct info, we canonicalize to a var 
binding instead
+    if (StructuralEqual()(binding->struct_info, GetStructInfo(new_value))) {
+      builder_->EmitNormalized(VarBinding(binding->var, new_value));
+    } else if (new_value.same_as(binding->value)) {
+      builder_->EmitNormalized(GetRef<MatchCast>(binding));
+    } else {
+      builder_->EmitNormalized(MatchCast(binding->var, new_value, 
binding->struct_info));
+    }
+  }
+
+ private:
+  bool AnnotationsDiffer(const ObjectRef& obj1, const ObjectRef& obj2,
+                         std::function<bool(const ObjectRef&, const 
ObjectRef&)> check_eq) {
+    // annotations differ if one is present but not the other
+    // or they're both present and they differ
+    bool both_present = obj1.defined() && obj2.defined();
+    bool neither_present = !obj1.defined() && !obj2.defined();
+    return !(both_present || neither_present) || (both_present && 
!check_eq(obj1, obj2));
+  }
+
+  bool CanCanonicalizeVar(Var v) {
+    Optional<Expr> value = LookupBinding(v);
+    // can replace only if the value is also a var
+    if (!value || !value.as<VarNode>()) {
+      return false;
+    }
+    Var parent_var = Downcast<Var>(value);
+
+    // Cases when we conservatively do not unify:
+    // 1. checked_type_ or shape_ of the child differs from that of the parent
+    //    In this case, we could be overriding user annotations.
+    // 2. If the child is a Var and the parent is a DataflowVar.
+    //    That could result in a DataflowVar leaving the current DataflowBlock.
+    bool annotations_differ = AnnotationsDiffer(v->struct_info_, 
parent_var->struct_info_,
+                                                [&](const ObjectRef& lhs, 
const ObjectRef& rhs) {
+                                                  return 
tvm::StructuralEqual()(lhs, rhs);
+                                                });
+    bool var_to_dataflow = (!v.as<DataflowVarNode>() && 
parent_var.as<DataflowVarNode>());
+    return !annotations_differ && !var_to_dataflow;
+  }
+};
+
+Expr CanonicalizeBindings(const Expr& e) { return 
BindingCanonicalizer().VisitExpr(e); }
+
+namespace transform {
+
+Pass CanonicalizeBindings() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> 
pass_func =
+      [=](Function f, IRModule m, PassContext pc) {
+        return Downcast<Function>(CanonicalizeBindings(f));
+      };
+  return CreateFunctionPass(pass_func, 1, "CanonicalizeBindings", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.CanonicalizeBindings").set_body_typed(CanonicalizeBindings);
+
+}  // namespace transform
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/tests/python/relax/test_transform_canonicalize_bindings.py 
b/tests/python/relax/test_transform_canonicalize_bindings.py
new file mode 100644
index 0000000000..4694e98973
--- /dev/null
+++ b/tests/python/relax/test_transform_canonicalize_bindings.py
@@ -0,0 +1,224 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.script
+import tvm.testing
+import pytest
+from tvm import relax
+from tvm.ir.base import assert_structural_equal
+from tvm.script import relax as R, tir as T
+
+
+def test_simple_assignments():
+    @tvm.script.ir_module
+    class TestChainAssignments:
+        @R.function
+        def main(x: R.Tensor):
+            y = x
+            z = y
+            q = z
+            p = q
+            o = p
+            return o
+
+    # a little annoying to have these unused bindings around
+    # but they can be eliminated in a separate pass
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor):
+            y = x
+            z = x
+            q = x
+            p = x
+            o = x
+            return x
+
+    new_mod = relax.transform.CanonicalizeBindings()(TestChainAssignments)
+    assert_structural_equal(new_mod, Expected)
+
+
+def test_dataflow_block():
+    @tvm.script.ir_module
+    class TestDataflowAssignments:
+        @R.function
+        def main(x: R.Tensor):
+            with R.dataflow():
+                y = R.const(1)
+                z = y
+                o = z
+                p = o
+                m = p
+                n = m
+                R.output(n)
+            return n
+
+    # a little annoying to have these unused bindings around
+    # but they can be eliminated in a separate pass
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor):
+            with R.dataflow():
+                y = R.const(1)
+                z = y
+                o = y
+                p = y
+                m = y
+                # we can't get rid of n because it leaves the block
+                n = y
+                R.output(n)
+            return n
+
+    new_mod = relax.transform.CanonicalizeBindings()(TestDataflowAssignments)
+    assert_structural_equal(new_mod, Expected)
+
+
+def test_ops():
+    @tvm.script.ir_module
+    class TestOps:
+        @R.function
+        def main(x: R.Tensor, y: R.Tensor):
+            w = y
+            q = x
+            z = R.add(w, q)
+            return R.add(q, z)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor, y: R.Tensor):
+            w = y
+            q = x
+            z = R.add(y, x)
+            return R.add(x, z)
+
+    new_mod = relax.transform.CanonicalizeBindings()(TestOps)
+    assert_structural_equal(new_mod, Expected)
+
+
[email protected](reason="The lhs and rhs of an assignment should have the 
same struct info.")
+def test_casting():
+    @tvm.script.ir_module
+    class TestCasting:
+        @R.function
+        def main(x: R.Tensor) -> R.Object:
+            y = x
+            # z will be treated as object type even though it's a tensor
+            z: R.Object = y
+            return z
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor) -> R.Object:
+            y = x
+            # Cannot unify because the cast indicates user intent
+            z: R.Object = x
+            return z
+
+    new_mod = relax.transform.CanonicalizeBindings()(TestCasting)
+    assert_structural_equal(new_mod, Expected)
+
+
+def test_match_cast():
+    @tvm.script.ir_module
+    class TestMatchCast:
+        @R.function
+        def main(x: R.Tensor):
+            q = x
+            m, n = T.var("int64"), T.var("int64")
+            z = R.match_cast(q, R.Tensor((m, n)))
+            w = z
+            return w
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor):
+            q = x
+            # can't get rid of z because its shape_ is different from x's
+            m, n = T.var("int64"), T.var("int64")
+            z = R.match_cast(x, R.Tensor((m, n)))
+            w = z
+            return z
+
+    new_mod = relax.transform.CanonicalizeBindings()(TestMatchCast)
+    assert_structural_equal(new_mod, Expected)
+
+
+def test_same_shape():
+    @tvm.script.ir_module
+    class TestSameShape:
+        @R.function
+        def main(x: R.Tensor(("m", "n"), "float32")):
+            m, n = T.var("int64"), T.var("int64")
+            y = x
+            # trivial check
+            z = R.match_cast(x, R.Tensor((m, n), "float32"))
+            w = z
+            q = R.add(w, y)
+            return R.add(q, w)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor(("m", "n"), "float32")):
+            m, n = T.var("int64"), T.var("int64")
+            y = x
+            # canonicalized into a var binding
+            z = x
+            w = x
+            q = R.add(x, x)
+            return R.add(q, x)
+
+    new_mod = relax.transform.CanonicalizeBindings()(TestSameShape)
+    assert_structural_equal(new_mod, Expected)
+
+
+def test_change_shape():
+    @tvm.script.ir_module
+    class TestChangeShape:
+        @R.function
+        def main(x: R.Tensor(("m", "n"))):
+            y = x
+            # not trivial: introduces new shape vars
+            o, p = T.var("int64"), T.var("int64")
+            z = R.match_cast(x, R.Tensor((o, p)))
+            w = z
+            q = R.add(w, y)
+            return R.add(q, w)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor(("m", "n"))):
+            y = x
+            o, p = T.var("int64"), T.var("int64")
+            z = R.match_cast(x, R.Tensor((o, p)))
+            w = z
+            # the shape_ field on q will need to be updated
+            q = R.add(z, x)
+            return R.add(q, z)
+
+    new_mod = relax.transform.CanonicalizeBindings()(TestChangeShape)
+    assert_structural_equal(new_mod, Expected)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to