This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new bb5c8df4c7 [Unity][Pass] Normalize Pass (#14031)
bb5c8df4c7 is described below

commit bb5c8df4c79e599dbf82fbab9a32543eb2b5a6d9
Author: Lesheng Jin <[email protected]>
AuthorDate: Fri Feb 17 16:25:37 2023 -0800

    [Unity][Pass] Normalize Pass (#14031)
    
    This PR implements the Relax `Normalize` pass, which allows users to transform
    Relax IR to normal form, i.e., the expressions are normalized (no nesting and
    hence the AST is in ANF), and all `checked_type_` and `shape_` of expressions
    are available. Tests are added.
    
    Co-Authored-by: Yuchen Jin <[email protected]>
    Co-Authored-by: Ruihang Lai <[email protected]>
    Co-Authored-by: Siyuan Feng <[email protected]>
    Co-Authored-by: Tianqi Chen <[email protected]>
---
 include/tvm/relax/transform.h                  |   9 +
 python/tvm/relax/transform/transform.py        |  11 +
 src/relax/transform/normalize.cc               | 186 +++++++++
 tests/python/relax/test_transform_normalize.py | 554 +++++++++++++++++++++++++
 4 files changed, 760 insertions(+)

diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index e9f63ee9db..7a4054d414 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -133,6 +133,15 @@ TVM_DLL Pass BindParams(String func_name, Map<String, runtime::NDArray> params);
  * \return The Pass.
  */
 TVM_DLL Pass FoldConstant();
+
+/*!
+ * \brief Transform Relax IR to normal form: rewrite the AST into A-normal form (no nested
+ * non-leaf expressions) and fill in the type and shape information (checked_type_ and
+ * shape_) of expressions.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass Normalize();
+
 }  // namespace transform
 }  // namespace relax
 }  // namespace tvm
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index c0ac180ff1..7fcf0b1121 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -68,6 +68,17 @@ def CallTIRRewrite() -> tvm.ir.transform.Pass:
     return _ffi_api.CallTIRRewrite()  # type: ignore
 
 
+def Normalize() -> tvm.ir.transform.Pass:
+    """Transform Relax IR to normal form, i.e., the expressions are normalized (no nesting
+    and hence the AST is in ANF), and all checked_type_ and shape_ of expressions are available.
+
+    Returns
+    -------
+    ret: tvm.ir.transform.Pass
+        The Normalize pass, obtained from the C++ side through the FFI.
+    """
+    return _ffi_api.Normalize()  # type: ignore
+
+
 def RewriteDataflowReshape() -> tvm.ir.transform.Pass:
     """Convert all reshape-like call_tir to VM reshape operator call.
     The VM reshape operator calls will be further lowered to a CreateView
diff --git a/src/relax/transform/normalize.cc b/src/relax/transform/normalize.cc
new file mode 100644
index 0000000000..915498178f
--- /dev/null
+++ b/src/relax/transform/normalize.cc
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/transform/normalize.cc
+ * \brief Pass for transforming Relax IR to normal form, i.e., the expressions 
are normalized(no
+ * nesting and hence the AST is in ANF), and all checked_type_ and shape_ of 
expressions are
+ * available.
+ */
+
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/relax/transform.h>
+
+namespace tvm {
+namespace relax {
+
+// TODO(@altanh): LCA binding lifting
+/*!
+ * \brief Mutator that rewrites a Relax expression into normal form.
+ *
+ * Every visited expression is passed through BlockBuilder::Normalize. Bindings the builder
+ * emits while doing so are collected into whichever binding block is currently open, so
+ * function bodies and If branches (each visited in its own scope) come back wrapped in a
+ * SeqExpr whenever new bindings were produced for them.
+ */
+class NormalizeMutator : public ExprMutatorBase {
+ public:
+  NormalizeMutator() { builder_ = BlockBuilder::Create(NullOpt); }
+
+  /*! \brief Visit the children of \p expr, then normalize the rebuilt node. */
+  Expr VisitExpr(const Expr& expr) override {
+    return builder_->Normalize(ExprMutatorBase::VisitExpr(expr));
+  }
+
+  Expr VisitExpr_(const FunctionNode* op) final {
+    // The parameters are brought into scope while the body is normalized.
+    Expr body = this->VisitWithNewScope(op->body, op->params);
+
+    if (body.same_as(op->body)) {
+      return GetRef<Expr>(op);
+    } else {
+      // Keep the source span, as the If rewrite below does.
+      return Function(op->params, body, op->ret_struct_info, op->attrs, op->span);
+    }
+  }
+
+  Expr VisitExpr_(const IfNode* op) final {
+    // The condition is normalized in the enclosing scope. Each branch gets its own scope, so
+    // bindings created for a branch end up inside that branch's SeqExpr.
+    Expr guard = this->VisitExpr(op->cond);
+    Expr true_b = this->VisitWithNewScope(op->true_branch);
+    Expr false_b = this->VisitWithNewScope(op->false_branch);
+    if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) &&
+        op->false_branch.same_as(false_b)) {
+      return GetRef<Expr>(op);
+    } else {
+      return If(guard, true_b, false_b, op->span);
+    }
+  }
+
+  /*!
+   * \brief Visit \p expr inside a fresh binding block and builder scope.
+   * \param expr The expression to visit (a function body or an If branch).
+   * \param params Variables to bring into scope (the function parameters), if any.
+   * \return The normalized expression, wrapped in a SeqExpr when bindings were emitted.
+   */
+  Expr VisitWithNewScope(const Expr& expr, Optional<Array<Var>> params = NullOpt) {
+    builder_->BeginBindingBlock();
+    builder_->BeginScope(params);
+    Expr ret = this->VisitExpr(expr);
+    BindingBlock prologue = builder_->EndBlock();
+    if (!prologue->bindings.empty()) {
+      ret = SeqExpr({prologue}, ret);
+    }
+    builder_->EndScope();
+    return ret;
+  }
+
+  Expr VisitExpr_(const SeqExprNode* op) final {
+    bool all_blocks_unchanged = true;
+    Array<BindingBlock> blocks;
+    for (auto block : op->blocks) {
+      BindingBlock new_block = this->VisitBindingBlock(block);
+      // Blocks that end up with no bindings are dropped.
+      if (!new_block->bindings.empty()) {
+        blocks.push_back(new_block);
+      }
+      all_blocks_unchanged &= block.same_as(new_block);
+    }
+
+    // Bindings emitted while normalizing the body are appended as a trailing block.
+    builder_->BeginBindingBlock();
+    Expr body = this->VisitExpr(op->body);
+    BindingBlock prologue = builder_->EndBlock();
+    if (!prologue->bindings.empty()) {
+      blocks.push_back(prologue);
+      all_blocks_unchanged = false;
+    }
+
+    if (all_blocks_unchanged && body.same_as(op->body)) {
+      return GetRef<Expr>(op);
+    } else {
+      return SeqExpr(blocks, body, op->span);
+    }
+  }
+
+  BindingBlock VisitBindingBlock(const BindingBlock& block) final {
+    BindingBlock ret;
+    // Test the dataflow variant first: a DataflowBlock would also match BindingBlockNode.
+    if (const auto* node = block.as<DataflowBlockNode>()) {
+      ret = VisitBindingBlock_(node);
+    } else if (const auto* node = block.as<BindingBlockNode>()) {
+      ret = VisitBindingBlock_(node);
+    } else {
+      LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey();
+    }
+    return ret;
+  }
+
+  /*! \brief Re-emit every binding of a plain binding block into a new builder block. */
+  BindingBlock VisitBindingBlock_(const BindingBlockNode* block) {
+    builder_->BeginBindingBlock();
+    for (Binding binding : block->bindings) {
+      this->VisitBinding(binding);
+    }
+    return builder_->EndBlock();
+  }
+
+  /*! \brief Re-emit every binding of a dataflow block into a new builder dataflow block. */
+  BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) {
+    builder_->BeginDataflowBlock();
+    for (Binding binding : block->bindings) {
+      this->VisitBinding(binding);
+    }
+    return builder_->EndBlock();
+  }
+
+  void VisitBinding(const Binding& binding) {
+    if (const auto* node = binding.as<VarBindingNode>()) {
+      VisitBinding_(node);
+    } else if (const auto* node = binding.as<MatchCastNode>()) {
+      VisitBinding_(node);
+    } else {
+      LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey();
+    }
+  }
+
+  /*!
+   * \brief Normalize the bound value and emit the binding.
+   *
+   * If the bound variable has no struct info yet, it is filled in (in place) from the
+   * normalized value.
+   */
+  void VisitBinding_(const VarBindingNode* binding) {
+    Expr new_value = this->VisitExpr(binding->value);
+    if (!binding->var->struct_info_.defined()) {
+      UpdateStructInfo(binding->var, GetStructInfo(new_value));
+    }
+
+    if (new_value.same_as(binding->value)) {
+      builder_->EmitNormalized(GetRef<VarBinding>(binding));
+    } else {
+      builder_->EmitNormalized(VarBinding(binding->var, new_value, binding->span));
+    }
+  }
+
+  /*! \brief Normalize the cast value and emit the MatchCast, keeping its target struct info. */
+  void VisitBinding_(const MatchCastNode* binding) {
+    Expr new_value = this->VisitExpr(binding->value);
+
+    if (new_value.same_as(binding->value)) {
+      builder_->EmitNormalized(GetRef<MatchCast>(binding));
+    } else {
+      builder_->EmitNormalized(MatchCast(binding->var, builder_->NormalizeArgument(new_value),
+                                         binding->struct_info, binding->span));
+    }
+  }
+
+ private:
+  /*! \brief Internal block builder to emit bindings during rewriting. */
+  BlockBuilder builder_;
+};  // class NormalizeMutator
+
+/*!
+ * \brief Normalize a Relax expression using a fresh NormalizeMutator.
+ * \param e The expression to normalize.
+ * \return The expression in normal form.
+ */
+Expr Normalize(const Expr& e) { return NormalizeMutator().VisitExpr(e); }
+
+namespace transform {
+
+/*!
+ * \brief Function-level pass (opt level 1, name "Normalize") that normalizes every function.
+ */
+Pass Normalize() {
+  // The lambda needs no captures. The call is qualified so it names relax::Normalize(const
+  // Expr&) directly: unqualified lookup from here first finds this zero-argument
+  // transform::Normalize and would only reach the intended overload through ADL.
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [](Function f, IRModule m, PassContext pc) {
+        return Downcast<Function>(relax::Normalize(f));
+      };
+  return CreateFunctionPass(pass_func, 1, "Normalize", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.Normalize").set_body_typed(Normalize);
+
+}  // namespace transform
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/tests/python/relax/test_transform_normalize.py b/tests/python/relax/test_transform_normalize.py
new file mode 100644
index 0000000000..9e9533a5ed
--- /dev/null
+++ b/tests/python/relax/test_transform_normalize.py
@@ -0,0 +1,554 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm import tir
+from tvm.ir.base import assert_structural_equal
+
+import tvm.script
+from tvm.script import tir as T, relax as R
+
+
+def test_normalize_function():
+    """Nested calls in a function body are each bound to a fresh variable."""
+    m = tir.Var("m", "int64")
+    n = tir.Var("n", "int64")
+    x = relax.Var("x", R.Tensor([m, n], "float16"))
+
+    # Note: the parser automatically normalizes the IR written in TVMScript,
+    # so we manually construct the function here.
+    mul_add = relax.Function(
+        [x],
+        relax.op.multiply(relax.op.add(x, x), relax.op.add(x, x)),
+        ret_struct_info=R.Tensor("float16", ndim=2),
+    )
+
+    # Note: from_expr api names private function (function without global_symbol) as "main"
+    before_mod = tvm.IRModule.from_expr(mul_add)
+
+    after_mod = relax.transform.Normalize()(before_mod)
+
+    @R.function
+    def expected(x: R.Tensor(("m", "n"), "float16")) -> R.Tensor(dtype="float16", ndim=2):
+        gv = R.add(x, x)
+        gv1 = R.add(x, x)
+        return R.multiply(gv, gv1)
+
+    assert_structural_equal(after_mod["main"], expected)
+
+
+def test_normalize_if():
+    """Nested calls in each If branch are bound inside that branch's own scope."""
+    cond = relax.Var("cond", R.Tensor([], "bool"))
+    x = relax.Var("x", R.Tensor([1], "float32"))
+    # TODO(relax-team): add type and shape inference for IfNode
+    y = relax.Var("y")
+
+    # Note: the parser automatically normalizes the IR written in TVMScript,
+    # so we manually construct the function and If here.
+    f = relax.Function(
+        [cond, x],
+        relax.SeqExpr(
+            [
+                relax.BindingBlock(
+                    [
+                        relax.VarBinding(
+                            y,
+                            relax.If(
+                                cond,
+                                relax.op.multiply(relax.op.add(x, x), relax.op.add(x, x)),
+                                relax.op.add(relax.op.multiply(x, x), relax.op.multiply(x, x)),
+                            ),
+                        )
+                    ]
+                )
+            ],
+            y,
+        ),
+        ret_struct_info=R.Tensor("float32", ndim=1),
+    )
+
+    before_mod = tvm.IRModule.from_expr(f)
+    after_mod = relax.transform.Normalize()(before_mod)
+
+    @R.function
+    def expected(
+        cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")
+    ) -> R.Tensor(dtype="float32", ndim=1):
+        if cond:
+            gv = R.add(x, x)
+            gv1 = R.add(x, x)
+            y = R.multiply(gv, gv1)
+        else:
+            gv = R.multiply(x, x)
+            gv1 = R.multiply(x, x)
+            y = R.add(gv, gv1)
+        return y
+
+    assert_structural_equal(after_mod["main"], expected)
+
+
+def test_normalize_no_op():
+    """The normalize pass should be a no-op for IR that is already in ANF."""
+
+    @tvm.script.ir_module
+    class ANFMod1:
+        @R.function
+        def f(x: R.Tensor(dtype="float32")):
+            gv = R.add(x, x)
+            gv1 = R.add(gv, gv)
+            gv2 = R.add(gv, gv1)
+            return (gv, gv2)
+
+    before_mod = ANFMod1
+    after_mod = relax.transform.Normalize()(before_mod)
+    assert_structural_equal(before_mod, after_mod, map_free_vars=True)
+
+    # Same check on a module whose function body uses a dataflow block.
+    @tvm.script.ir_module
+    class ANFMod2:
+        @R.function
+        def foo(x: R.Tensor(("m", "n"), "float32")):
+            m, n = T.var("int64"), T.var("int64")
+            with R.dataflow():
+                lv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), dtype="float32"))
+                gv0 = R.call_tir("test.op.identity", (lv0,), R.Tensor((m, n), dtype="float32"))
+                R.output(gv0)
+            return gv0
+
+    mod = ANFMod2
+    mod_post = relax.transform.Normalize()(mod)
+
+    assert_structural_equal(mod, mod_post)
+
+
+def test_normalize_seq_body():
+    """A seq expression with a non-leaf body should bind the body to a var as well."""
+    x = relax.Var("x", R.Tensor([], "int32"))
+    y = relax.Var("y", R.Tensor([], "int32"))
+    seq = relax.SeqExpr([], relax.op.add(x, y))
+    f = relax.Function(
+        [x, y],
+        seq,
+        ret_struct_info=R.Tensor([], "int32"),
+    )
+
+    before_mod = tvm.IRModule.from_expr(f)
+    after_mod = relax.transform.Normalize()(before_mod)
+
+    @R.function
+    def expected(
+        x: R.Tensor((), dtype="int32"), y: R.Tensor((), dtype="int32")
+    ) -> R.Tensor(ndim=0, dtype="int32"):
+        # normalization inserts a binding like this
+        z = R.add(x, y)
+        return z
+
+    assert_structural_equal(after_mod["main"], expected)
+
+
+def test_normalize_func_body():
+    """A function body that is not a seq expr should get wrapped in a seq expr."""
+    x = relax.Var("x", R.Tensor([], "int32"))
+    y = relax.Var("y", R.Tensor([], "int32"))
+    f = relax.Function(
+        [x, y],
+        relax.op.add(x, y),
+        ret_struct_info=R.Tensor([], "int32"),
+    )
+
+    before_mod = tvm.IRModule.from_expr(f)
+    after_mod = relax.transform.Normalize()(before_mod)
+
+    @R.function
+    def expected(
+        x: R.Tensor((), dtype="int32"), y: R.Tensor((), dtype="int32")
+    ) -> R.Tensor(ndim=0, dtype="int32"):
+        # result will be a seq expr where the body is a var
+        z = R.add(x, y)
+        return z
+
+    assert_structural_equal(after_mod["main"], expected)
+
+
+def test_normalize_if_branches():
+    """An If node's branches must become seq exprs that bind each branch's call to a var."""
+    x = relax.Var("x", R.Tensor([], "int32"))
+    y = relax.Var("y", R.Tensor([], "int32"))
+    # TODO(@relax-team): z has a shape of () and type of DynTensorType(ndim=0),
+    # but normalization fails to infer these even though it should
+    z = relax.Var("z")
+    cond = relax.Var("cond", R.Tensor([], "bool"))
+    plus = relax.op.add(x, y)
+    mult = relax.op.multiply(x, y)
+    if_node = relax.If(cond, plus, mult)
+    seq = relax.SeqExpr([relax.BindingBlock([relax.VarBinding(z, if_node)])], z)
+    f = relax.Function(
+        [cond, x, y],
+        seq,
+        ret_struct_info=R.Tensor([], "int32"),
+    )
+
+    before_mod = tvm.IRModule.from_expr(f)
+    after_mod = relax.transform.Normalize()(before_mod)
+
+    @R.function
+    def expected(
+        cond: R.Tensor((), dtype="bool"),
+        x: R.Tensor((), dtype="int32"),
+        y: R.Tensor((), dtype="int32"),
+    ) -> R.Tensor(ndim=0, dtype="int32"):
+        # the bodies of the branches will be seq exprs with a binding
+        if cond:
+            w = R.add(x, y)
+            z = w
+        else:
+            w = R.multiply(x, y)
+            z = w
+        return z
+
+    assert_structural_equal(after_mod["main"], expected)
+
+
+def test_normalize_if_condition():
+    """A non-leaf If condition is bound to a variable ahead of the If."""
+    cond = relax.Var("cond", R.Tensor([], "bool"))
+    x = relax.Var("x", R.Tensor([1], "float32"))
+    # TODO(relax-team): add type and shape inference for IfNode
+    y = relax.Var("y")
+
+    # The condition is wrapped in a tuple and then indexed
+    f = relax.Function(
+        [cond, x],
+        relax.SeqExpr(
+            [
+                relax.BindingBlock(
+                    [
+                        relax.VarBinding(
+                            y,
+                            relax.If(
+                                relax.TupleGetItem(relax.Tuple([cond]), 0),
+                                relax.op.add(x, x),
+                                relax.op.multiply(x, x),
+                            ),
+                        )
+                    ]
+                )
+            ],
+            y,
+        ),
+        ret_struct_info=R.Tensor("float32", ndim=1),
+    )
+
+    before_mod = tvm.IRModule.from_expr(f)
+    after_mod = relax.transform.Normalize()(before_mod)
+
+    @R.function
+    def expected(
+        cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")
+    ) -> R.Tensor(dtype="float32", ndim=1):
+        c = R.TupleGetItem(R.tuple(cond), 0)
+        if c:
+            gv = R.add(x, x)
+            y = gv
+        else:
+            gv = R.multiply(x, x)
+            y = gv
+        return y
+
+    assert_structural_equal(after_mod["main"], expected)
+
+
+def test_normalize_tuple_get_item():
+    """A nested TupleGetItem body is split so the inner access is bound to its own variable."""
+    x = relax.Var("x", R.Tensor([], "int32"))
+    f = relax.Function(
+        [x],
+        relax.TupleGetItem(
+            relax.TupleGetItem(
+                relax.Tuple([relax.Tuple([x])]),
+                0,
+            ),
+            0,
+        ),
+        ret_struct_info=R.Tensor([], "int32"),
+    )
+
+    before_mod = tvm.IRModule.from_expr(f)
+    after_mod = relax.transform.Normalize()(before_mod)
+
+    # TODO: Revisit once we canonicalize SeqExprs (part of normalization?)
+    # Not using the parser this time because writing it out correctly results in
+    # *one* binding block, whereas the normalized version has *two*
+    idx_var = relax.Var("idx_var", R.Tuple([R.Tensor([], "int32")]))
+    ret_var = relax.Var("ret", R.Tensor([], "int32"))
+    expected_f = relax.Function(
+        [x],
+        relax.SeqExpr(
+            [
+                relax.BindingBlock(
+                    [
+                        relax.VarBinding(
+                            idx_var, relax.TupleGetItem(relax.Tuple([relax.Tuple([x])]), 0)
+                        )
+                    ]
+                ),
+                relax.BindingBlock([relax.VarBinding(ret_var, relax.TupleGetItem(idx_var, 0))]),
+            ],
+            ret_var,
+        ),
+        ret_struct_info=R.Tensor([], "int32"),
+    )
+    expected_mod = tvm.IRModule.from_expr(expected_f)
+    # apply normalization to fill in type and shape annotations (tedious otherwise)
+    final_mod = relax.transform.Normalize()(expected_mod)
+
+    assert_structural_equal(after_mod, final_mod)
+
+
+def test_normalize_combine_nearby_blocks():
+    """Adjacent blocks of the same kind are expected to be merged into a single block."""
+    x = relax.Var("x", R.Tensor([], "int32"))
+    v0 = relax.Var("v0", R.Tensor([], "int32"))
+    v1 = relax.Var("v1", R.Tensor([], "int32"))
+    v2 = relax.Var("v2", R.Tensor([], "int32"))
+    v3 = relax.Var("v3", R.Tensor([], "int32"))
+    f = relax.Function(
+        [x],
+        relax.SeqExpr(
+            [
+                relax.DataflowBlock([relax.VarBinding(v0, x)]),
+                relax.DataflowBlock([relax.VarBinding(v1, v0)]),
+                relax.BindingBlock([relax.VarBinding(v2, v1)]),
+                relax.BindingBlock([relax.VarBinding(v3, v2)]),
+            ],
+            v3,
+        ),
+        ret_struct_info=R.Tensor([], "int32"),
+    )
+
+    after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f))
+
+    # Expected: one dataflow block (v0, v1) followed by one binding block (v2, v3).
+    @R.function
+    def expected(x: R.Tensor((), "int32")):
+        with R.dataflow():
+            v0 = x
+            v1 = v0
+            R.output(v0, v1)
+        v2 = v1
+        v3 = v2
+        return v3
+
+    assert_structural_equal(after_mod["main"], expected)
+
+
+def test_normalize_nested_seq():
+    """A SeqExpr bound to a variable is flattened into the enclosing binding block."""
+    x = relax.Var("x", R.Tensor([], "int32"))
+    y = relax.Var("y", R.Tensor([], "int32"))
+    z = relax.Var("z", R.Tensor([], "int32"))
+    seq = relax.SeqExpr(
+        [
+            relax.BindingBlock(
+                [
+                    relax.VarBinding(x, relax.const(1)),
+                    relax.VarBinding(
+                        y,
+                        relax.SeqExpr(
+                            [relax.BindingBlock([relax.VarBinding(z, relax.const(2))])],
+                            z,
+                        ),
+                    ),
+                ]
+            )
+        ],
+        y,
+    )
+
+    f = relax.Function(
+        [],
+        seq,
+        ret_struct_info=R.Tensor([], "int32"),
+    )
+    after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f))
+
+    @R.function
+    def expected():
+        x = relax.const(1)
+        z = relax.const(2)
+        y = z
+        return y
+
+    assert_structural_equal(after_mod["main"], expected)
+
+
+def test_normalize_nested_seq_dataflow():
+    x = relax.Var("x", R.Tensor([], "int32"))
+    y = relax.Var("y", R.Tensor([], "int32"))
+    z = relax.Var("z", R.Tensor([], "int32"))
+    q = relax.Var("u", R.Tensor([], "int32"))
+    w = relax.DataflowVar("w", R.Tensor([], "int32"))
+    u = relax.Var("u", R.Tensor([], "int32"))
+    seq = relax.SeqExpr(
+        [
+            relax.BindingBlock(
+                [
+                    relax.VarBinding(x, relax.const(1)),
+                    relax.VarBinding(
+                        y,
+                        relax.SeqExpr(
+                            [
+                                relax.BindingBlock([relax.VarBinding(q, 
relax.const(2))]),
+                                relax.DataflowBlock(
+                                    [
+                                        relax.VarBinding(w, q),
+                                        relax.VarBinding(u, w),
+                                    ]
+                                ),
+                                relax.BindingBlock([relax.VarBinding(z, u)]),
+                            ],
+                            z,
+                        ),
+                    ),
+                ]
+            )
+        ],
+        y,
+    )
+
+    f = relax.Function(
+        [],
+        seq,
+        ret_struct_info=R.Tensor([], "int32"),
+    )
+    after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f))
+
+    @R.function
+    def expected():
+        x = relax.const(1)
+        q = relax.const(2)
+        with R.dataflow():
+            w = q
+            u = w
+            R.output(u)
+        z = u
+        y = z
+        return y
+
+    assert_structural_equal(after_mod["main"], expected)
+
+
+def test_normalize_deeply_nested_seq():
+    x = relax.Var("x", R.Tensor([], "int32"))
+    y = relax.Var("y", R.Tensor([], "int32"))
+    z = relax.Var("z", R.Tensor([], "int32"))
+    u = relax.Var("u", R.Tensor([], "int32"))
+    v = relax.Var("v", R.Tensor([], "int32"))
+    w = relax.Var("w", R.Tensor([], "int32"))
+    _ = relax.Var("w", R.Tensor([], "int32"))
+    seq = relax.SeqExpr(
+        [
+            relax.BindingBlock(
+                [
+                    relax.VarBinding(x, relax.const(1)),
+                    relax.VarBinding(
+                        y,
+                        relax.SeqExpr(
+                            [
+                                relax.BindingBlock(
+                                    [
+                                        relax.VarBinding(
+                                            z,
+                                            relax.SeqExpr(
+                                                [
+                                                    relax.BindingBlock(
+                                                        [
+                                                            
relax.VarBinding(u, relax.const(2)),
+                                                            relax.MatchCast(
+                                                                _, u, 
R.Tensor([], "int32")
+                                                            ),
+                                                            
relax.VarBinding(v, u),
+                                                            relax.MatchCast(
+                                                                w, v, 
R.Tensor([], "int32")
+                                                            ),
+                                                        ]
+                                                    )
+                                                ],
+                                                w,
+                                            ),
+                                        )
+                                    ]
+                                )
+                            ],
+                            z,
+                        ),
+                    ),
+                ]
+            )
+        ],
+        y,
+    )
+
+    f = relax.Function(
+        [],
+        seq,
+        ret_struct_info=R.Tensor([], "int32"),
+    )
+    after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f))
+
+    @R.function
+    def expected():
+        x = relax.const(1)
+        u = relax.const(2)
+        _ = R.match_cast(u, R.Tensor((), "int32"))
+        v = u
+        w = R.match_cast(v, R.Tensor((), "int32"))
+        z = w
+        y = z
+        return y
+
+    assert_structural_equal(after_mod["main"], expected)
+
+
[email protected]()
+def test_nesting_non_dataflow_in_dataflow_error():
+    x = relax.DataflowVar("x", R.Tensor([], "int32"))
+    y = relax.Var("y", R.Tensor([], "int32"))
+    z = relax.Var("z", R.Tensor([], "int32"))
+    seq = relax.SeqExpr(
+        [
+            relax.DataflowBlock(
+                [
+                    relax.VarBinding(x, relax.const(1)),
+                    relax.VarBinding(
+                        y,
+                        relax.SeqExpr(
+                            [relax.BindingBlock([relax.VarBinding(z, 
relax.const(2))])],
+                            z,
+                        ),
+                    ),
+                ]
+            )
+        ],
+        y,
+    )
+    f = relax.Function(
+        [],
+        seq,
+        ret_struct_info=R.Tensor([], "int32"),
+    )
+    relax.transform.Normalize()(tvm.IRModule.from_expr(f))
+    # should fail due to a normal binding block being inside a dataflowblock
+
+
+if __name__ == "__main__":
+    # Allow running this test file directly as a script.
+    tvm.testing.main()

Reply via email to