This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new bb5c8df4c7 [Unity][Pass] Normalize Pass (#14031)
bb5c8df4c7 is described below
commit bb5c8df4c79e599dbf82fbab9a32543eb2b5a6d9
Author: Lesheng Jin <[email protected]>
AuthorDate: Fri Feb 17 16:25:37 2023 -0800
[Unity][Pass] Normalize Pass (#14031)
This PR implements relax `Normalize` Pass, which allows users to transform
Relax IR to normal form, i.e., the expressions are normalized (no nesting and
hence the AST is in ANF), and all `checked_type_` and `shape_` of expressions
are available. (tests are added).
Co-Authored-by: Yuchen Jin <[email protected]>
Co-Authored-by: Ruihang Lai <[email protected]>
Co-Authored-by: Siyuan Feng <[email protected]>
Co-Authored-by: Tianqi Chen <[email protected]>
---
include/tvm/relax/transform.h | 9 +
python/tvm/relax/transform/transform.py | 11 +
src/relax/transform/normalize.cc | 186 +++++++++
tests/python/relax/test_transform_normalize.py | 554 +++++++++++++++++++++++++
4 files changed, 760 insertions(+)
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index e9f63ee9db..7a4054d414 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -133,6 +133,15 @@ TVM_DLL Pass BindParams(String func_name, Map<String,
runtime::NDArray> params);
* \return The Pass.
*/
TVM_DLL Pass FoldConstant();
+
/*!
 * \brief Transform Relax IR to normal form: transform AST to A-normal form,
 * and fill the checked_type_ and shape_ of expressions.
 *
 * \return The Pass.
 */
TVM_DLL Pass Normalize();
+
} // namespace transform
} // namespace relax
} // namespace tvm
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index c0ac180ff1..7fcf0b1121 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -68,6 +68,17 @@ def CallTIRRewrite() -> tvm.ir.transform.Pass:
return _ffi_api.CallTIRRewrite() # type: ignore
def Normalize() -> tvm.ir.transform.Pass:
    """Transform Relax IR to normal form, i.e., the expressions are
    normalized (no nesting and hence the AST is in ANF), and all
    checked_type_ and shape_ of expressions are available.

    Returns
    -------
    ret: tvm.ir.transform.Pass
        The registered pass that normalizes each Relax function in the module.
    """
    return _ffi_api.Normalize()  # type: ignore
+
+
def RewriteDataflowReshape() -> tvm.ir.transform.Pass:
"""Convert all reshape-like call_tir to VM reshape operator call.
The VM reshape operator calls will be further lowered to a CreateView
diff --git a/src/relax/transform/normalize.cc b/src/relax/transform/normalize.cc
new file mode 100644
index 0000000000..915498178f
--- /dev/null
+++ b/src/relax/transform/normalize.cc
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/transform/normalize.cc
+ * \brief Pass for transforming Relax IR to normal form, i.e., the expressions
+ * are normalized (no nesting and hence the AST is in ANF), and all
+ * checked_type_ and shape_ of expressions are available.
+ */
+
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/relax/transform.h>
+
+namespace tvm {
+namespace relax {
+
// TODO(@altanh): LCA binding lifting
/*!
 * \brief Mutator that rewrites Relax expressions into A-normal form.
 *
 * Nested (non-leaf) sub-expressions are bound to fresh variables through the
 * internal BlockBuilder, whose Normalize() also fills in checked_type_ and
 * shape_ information on the rewritten expressions.
 */
class NormalizeMutator : public ExprMutatorBase {
 public:
  NormalizeMutator() { builder_ = BlockBuilder::Create(NullOpt); }

  /*! \brief Visit an expression and let the builder normalize the result. */
  Expr VisitExpr(const Expr& expr) override {
    return builder_->Normalize(ExprMutatorBase::VisitExpr(expr));
  }

  Expr VisitExpr_(const FunctionNode* op) final {
    // The function body gets its own scope so bindings emitted while
    // normalizing it stay local to the function.
    Expr body = this->VisitWithNewScope(op->body, op->params);

    if (body.same_as(op->body)) {
      return GetRef<Expr>(op);
    } else {
      return Function(op->params, body, op->ret_struct_info, op->attrs);
    }
  }

  Expr VisitExpr_(const IfNode* op) final {
    Expr guard = this->VisitExpr(op->cond);
    // Each branch must itself be a normal-form SeqExpr, hence a fresh scope
    // per branch.
    Expr true_b = this->VisitWithNewScope(op->true_branch);
    Expr false_b = this->VisitWithNewScope(op->false_branch);
    if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) &&
        op->false_branch.same_as(false_b)) {
      return GetRef<Expr>(op);
    } else {
      return If(guard, true_b, false_b, op->span);
    }
  }

  /*!
   * \brief Visit \p expr inside a fresh binding block and scope; any bindings
   * emitted during the visit are wrapped into a SeqExpr prologue around the
   * result.
   * \param expr The expression to normalize.
   * \param params Optional function parameters defining the new scope.
   */
  Expr VisitWithNewScope(const Expr& expr, Optional<Array<Var>> params = NullOpt) {
    builder_->BeginBindingBlock();
    builder_->BeginScope(params);
    Expr ret = this->VisitExpr(expr);
    BindingBlock prologue = builder_->EndBlock();
    if (!prologue->bindings.empty()) {
      ret = SeqExpr({prologue}, ret);
    }
    builder_->EndScope();
    return ret;
  }

  Expr VisitExpr_(const SeqExprNode* op) final {
    bool all_blocks_unchanged = true;
    Array<BindingBlock> blocks;
    for (auto block : op->blocks) {
      BindingBlock new_block = this->VisitBindingBlock(block);
      // Blocks that end up empty after normalization are dropped.
      if (!new_block->bindings.empty()) {
        blocks.push_back(new_block);
      }
      all_blocks_unchanged &= block.same_as(new_block);
    }

    // Normalizing the body may emit extra bindings (the body must become a
    // leaf expression); collect them into a trailing block.
    builder_->BeginBindingBlock();
    Expr body = this->VisitExpr(op->body);
    BindingBlock prologue = builder_->EndBlock();
    if (!prologue->bindings.empty()) {
      blocks.push_back(prologue);
      all_blocks_unchanged = false;
    }

    if (all_blocks_unchanged && body.same_as(op->body)) {
      return GetRef<Expr>(op);
    } else {
      return SeqExpr(blocks, body);
    }
  }

  BindingBlock VisitBindingBlock(const BindingBlock& block) final {
    // Dispatch on the concrete block kind (dataflow vs. plain binding block).
    BindingBlock ret;
    if (const auto* node = block.as<DataflowBlockNode>()) {
      ret = VisitBindingBlock_(node);
    } else if (const auto* node = block.as<BindingBlockNode>()) {
      ret = VisitBindingBlock_(node);
    } else {
      LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey();
    }
    return ret;
  }

  BindingBlock VisitBindingBlock_(const BindingBlockNode* block) {
    builder_->BeginBindingBlock();
    for (Binding binding : block->bindings) {
      this->VisitBinding(binding);
    }
    return builder_->EndBlock();
  }

  BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) {
    builder_->BeginDataflowBlock();
    for (Binding binding : block->bindings) {
      this->VisitBinding(binding);
    }
    return builder_->EndBlock();
  }

  void VisitBinding(const Binding& binding) {
    if (const auto* node = binding.as<VarBindingNode>()) {
      VisitBinding_(node);
    } else if (const auto* node = binding.as<MatchCastNode>()) {
      VisitBinding_(node);
    } else {
      LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey();
    }
  }

  void VisitBinding_(const VarBindingNode* binding) {
    Expr new_value = this->VisitExpr(binding->value);
    // Fill in missing struct info on the bound var from the normalized value.
    if (!binding->var->struct_info_.defined()) {
      UpdateStructInfo(binding->var, GetStructInfo(new_value));
    }

    if (new_value.same_as(binding->value)) {
      builder_->EmitNormalized(GetRef<VarBinding>(binding));
    } else {
      builder_->EmitNormalized(VarBinding(binding->var, new_value));
    }
  }

  void VisitBinding_(const MatchCastNode* binding) {
    Expr new_value = this->VisitExpr(binding->value);

    if (new_value.same_as(binding->value)) {
      builder_->EmitNormalized(GetRef<MatchCast>(binding));
    } else {
      // MatchCast requires a leaf value, so normalize the argument as well.
      builder_->EmitNormalized(
          MatchCast(binding->var, builder_->NormalizeArgument(new_value), binding->struct_info));
    }
  }

 private:
  /*! \brief Internal block builder to emit bindings during rewriting. */
  BlockBuilder builder_;
};  // class NormalizeMutator
+
+Expr Normalize(const Expr& e) { return NormalizeMutator().VisitExpr(e); }
+
+namespace transform {
+
+Pass Normalize() {
+ runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>
pass_func =
+ [=](Function f, IRModule m, PassContext pc) { return
Downcast<Function>(Normalize(f)); };
+ return CreateFunctionPass(pass_func, 1, "Normalize", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.Normalize").set_body_typed(Normalize);
+
+} // namespace transform
+
+} // namespace relax
+} // namespace tvm
diff --git a/tests/python/relax/test_transform_normalize.py
b/tests/python/relax/test_transform_normalize.py
new file mode 100644
index 0000000000..9e9533a5ed
--- /dev/null
+++ b/tests/python/relax/test_transform_normalize.py
@@ -0,0 +1,554 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm import tir
+from tvm.ir.base import assert_structural_equal
+
+import tvm.script
+from tvm.script import tir as T, relax as R
+
+
def test_normalize_function():
    """Nested operator calls in a function body are flattened into ANF bindings."""
    m = tir.Var("m", "int64")
    n = tir.Var("n", "int64")
    x = relax.Var("x", R.Tensor([m, n], "float16"))

    # Note: the parser automatically normalizes IR written in TVMScript,
    # so we manually construct the function here.
    mul_add = relax.Function(
        [x],
        relax.op.multiply(relax.op.add(x, x), relax.op.add(x, x)),
        ret_struct_info=R.Tensor("float16", ndim=2),
    )

    # Note: from_expr api names private function (function without global_symbol) as "main"
    before_mod = tvm.IRModule.from_expr(mul_add)

    after_mod = relax.transform.Normalize()(before_mod)

    @R.function
    def expected(x: R.Tensor(("m", "n"), "float16")) -> R.Tensor(dtype="float16", ndim=2):
        gv = R.add(x, x)
        gv1 = R.add(x, x)
        return R.multiply(gv, gv1)

    assert_structural_equal(after_mod["main"], expected)
+
+
def test_normalize_if():
    """Nested calls inside both branches of an If are normalized per branch."""
    cond = relax.Var("cond", R.Tensor([], "bool"))
    x = relax.Var("x", R.Tensor([1], "float32"))
    # TODO(relax-team): add type and shape inference for IfNode
    y = relax.Var("y")

    # Note: the parser automatically normalizes IR written in TVMScript,
    # so we manually construct the function and If here.
    f = relax.Function(
        [cond, x],
        relax.SeqExpr(
            [
                relax.BindingBlock(
                    [
                        relax.VarBinding(
                            y,
                            relax.If(
                                cond,
                                relax.op.multiply(relax.op.add(x, x), relax.op.add(x, x)),
                                relax.op.add(relax.op.multiply(x, x), relax.op.multiply(x, x)),
                            ),
                        )
                    ]
                )
            ],
            y,
        ),
        ret_struct_info=R.Tensor("float32", ndim=1),
    )

    before_mod = tvm.IRModule.from_expr(f)
    after_mod = relax.transform.Normalize()(before_mod)

    @R.function
    def expected(
        cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")
    ) -> R.Tensor(dtype="float32", ndim=1):
        if cond:
            gv = R.add(x, x)
            gv1 = R.add(x, x)
            y = R.multiply(gv, gv1)
        else:
            gv = R.multiply(x, x)
            gv1 = R.multiply(x, x)
            y = R.add(gv, gv1)
        return y

    assert_structural_equal(after_mod["main"], expected)
+
+
def test_normalize_no_op():
    """The Normalize pass should be a no-op for IR that is already in ANF."""

    @tvm.script.ir_module
    class ANFMod1:
        @R.function
        def f(x: R.Tensor(dtype="float32")):
            gv = R.add(x, x)
            gv1 = R.add(gv, gv)
            gv2 = R.add(gv, gv1)
            return (gv, gv2)

    before_mod = ANFMod1
    after_mod = relax.transform.Normalize()(before_mod)
    assert_structural_equal(before_mod, after_mod, map_free_vars=True)

    @tvm.script.ir_module
    class ANFMod2:
        @R.function
        def foo(x: R.Tensor(("m", "n"), "float32")):
            m, n = T.var("int64"), T.var("int64")
            with R.dataflow():
                lv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), dtype="float32"))
                gv0 = R.call_tir("test.op.identity", (lv0,), R.Tensor((m, n), dtype="float32"))
                R.output(gv0)
            return gv0

    mod = ANFMod2
    mod_post = relax.transform.Normalize()(mod)

    assert_structural_equal(mod, mod_post)
+
+
def test_normalize_seq_body():
    """A SeqExpr with a non-leaf body should bind the body to a var as well."""
    x = relax.Var("x", R.Tensor([], "int32"))
    y = relax.Var("y", R.Tensor([], "int32"))
    seq = relax.SeqExpr([], relax.op.add(x, y))
    f = relax.Function(
        [x, y],
        seq,
        ret_struct_info=R.Tensor([], "int32"),
    )

    before_mod = tvm.IRModule.from_expr(f)
    after_mod = relax.transform.Normalize()(before_mod)

    @R.function
    def expected(
        x: R.Tensor((), dtype="int32"), y: R.Tensor((), dtype="int32")
    ) -> R.Tensor(ndim=0, dtype="int32"):
        # normalization inserts a binding like this
        z = R.add(x, y)
        return z

    assert_structural_equal(after_mod["main"], expected)
+
+
def test_normalize_func_body():
    """A function whose body is not a SeqExpr should have it wrapped in one."""
    x = relax.Var("x", R.Tensor([], "int32"))
    y = relax.Var("y", R.Tensor([], "int32"))
    f = relax.Function(
        [x, y],
        relax.op.add(x, y),
        ret_struct_info=R.Tensor([], "int32"),
    )

    before_mod = tvm.IRModule.from_expr(f)
    after_mod = relax.transform.Normalize()(before_mod)

    @R.function
    def expected(
        x: R.Tensor((), dtype="int32"), y: R.Tensor((), dtype="int32")
    ) -> R.Tensor(ndim=0, dtype="int32"):
        # result will be a seq expr where the body is a var
        z = R.add(x, y)
        return z

    assert_structural_equal(after_mod["main"], expected)
+
+
def test_normalize_if_branches():
    """An If node's branches must each be normalized into SeqExprs."""
    x = relax.Var("x", R.Tensor([], "int32"))
    y = relax.Var("y", R.Tensor([], "int32"))
    # TODO(@relax-team): z has a shape of () and type of DynTensorType(ndim=0),
    # but normalization fails to infer these even though it should
    z = relax.Var("z")
    cond = relax.Var("cond", R.Tensor([], "bool"))
    plus = relax.op.add(x, y)
    mult = relax.op.multiply(x, y)
    if_node = relax.If(cond, plus, mult)
    seq = relax.SeqExpr([relax.BindingBlock([relax.VarBinding(z, if_node)])], z)
    f = relax.Function(
        [cond, x, y],
        seq,
        ret_struct_info=R.Tensor([], "int32"),
    )

    before_mod = tvm.IRModule.from_expr(f)
    after_mod = relax.transform.Normalize()(before_mod)

    @R.function
    def expected(
        cond: R.Tensor((), dtype="bool"),
        x: R.Tensor((), dtype="int32"),
        y: R.Tensor((), dtype="int32"),
    ) -> R.Tensor(ndim=0, dtype="int32"):
        # the bodies of the branches will be seq exprs with a binding
        if cond:
            w = R.add(x, y)
            z = w
        else:
            w = R.multiply(x, y)
            z = w
        return z

    assert_structural_equal(after_mod["main"], expected)
+
+
def test_normalize_if_condition():
    """A non-leaf If condition is bound to a fresh var before the branch."""
    cond = relax.Var("cond", R.Tensor([], "bool"))
    x = relax.Var("x", R.Tensor([1], "float32"))
    # TODO(relax-team): add type and shape inference for IfNode
    y = relax.Var("y")

    # The condition is wrapped in a tuple and then indexed
    f = relax.Function(
        [cond, x],
        relax.SeqExpr(
            [
                relax.BindingBlock(
                    [
                        relax.VarBinding(
                            y,
                            relax.If(
                                relax.TupleGetItem(relax.Tuple([cond]), 0),
                                relax.op.add(x, x),
                                relax.op.multiply(x, x),
                            ),
                        )
                    ]
                )
            ],
            y,
        ),
        ret_struct_info=R.Tensor("float32", ndim=1),
    )

    before_mod = tvm.IRModule.from_expr(f)
    after_mod = relax.transform.Normalize()(before_mod)

    @R.function
    def expected(
        cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")
    ) -> R.Tensor(dtype="float32", ndim=1):
        c = R.TupleGetItem(R.tuple(cond), 0)
        if c:
            gv = R.add(x, x)
            y = gv
        else:
            gv = R.multiply(x, x)
            y = gv
        return y

    assert_structural_equal(after_mod["main"], expected)
+
+
def test_normalize_tuple_get_item():
    """Nested TupleGetItem expressions are split into separate bindings."""
    x = relax.Var("x", R.Tensor([], "int32"))
    f = relax.Function(
        [x],
        relax.TupleGetItem(
            relax.TupleGetItem(
                relax.Tuple([relax.Tuple([x])]),
                0,
            ),
            0,
        ),
        ret_struct_info=R.Tensor([], "int32"),
    )

    before_mod = tvm.IRModule.from_expr(f)
    after_mod = relax.transform.Normalize()(before_mod)

    # TODO: Revisit once we canonicalize SeqExprs (part of normalization?)
    # Not using the parser this time because writing it out correctly results in
    # *one* binding block, whereas the normalized version has *two*
    idx_var = relax.Var("idx_var", R.Tuple([R.Tensor([], "int32")]))
    ret_var = relax.Var("ret", R.Tensor([], "int32"))
    expected_f = relax.Function(
        [x],
        relax.SeqExpr(
            [
                relax.BindingBlock(
                    [
                        relax.VarBinding(
                            idx_var, relax.TupleGetItem(relax.Tuple([relax.Tuple([x])]), 0)
                        )
                    ]
                ),
                relax.BindingBlock([relax.VarBinding(ret_var, relax.TupleGetItem(idx_var, 0))]),
            ],
            ret_var,
        ),
        ret_struct_info=R.Tensor([], "int32"),
    )
    expected_mod = tvm.IRModule.from_expr(expected_f)
    # apply normalization to fill in type and shape annotations (tedious otherwise)
    final_mod = relax.transform.Normalize()(expected_mod)

    assert_structural_equal(after_mod, final_mod)
+
+
def test_normalize_combine_nearby_blocks():
    """Adjacent blocks of the same kind are merged during normalization."""
    x = relax.Var("x", R.Tensor([], "int32"))
    v0 = relax.Var("v0", R.Tensor([], "int32"))
    v1 = relax.Var("v1", R.Tensor([], "int32"))
    v2 = relax.Var("v2", R.Tensor([], "int32"))
    v3 = relax.Var("v3", R.Tensor([], "int32"))
    f = relax.Function(
        [x],
        relax.SeqExpr(
            [
                relax.DataflowBlock([relax.VarBinding(v0, x)]),
                relax.DataflowBlock([relax.VarBinding(v1, v0)]),
                relax.BindingBlock([relax.VarBinding(v2, v1)]),
                relax.BindingBlock([relax.VarBinding(v3, v2)]),
            ],
            v3,
        ),
        ret_struct_info=R.Tensor([], "int32"),
    )

    after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f))

    @R.function
    def expected(x: R.Tensor((), "int32")):
        with R.dataflow():
            v0 = x
            v1 = v0
            R.output(v0, v1)
        v2 = v1
        v3 = v2
        return v3

    assert_structural_equal(after_mod["main"], expected)
+
+
def test_normalize_nested_seq():
    """A SeqExpr nested as a binding value is flattened into the outer block."""
    x = relax.Var("x", R.Tensor([], "int32"))
    y = relax.Var("y", R.Tensor([], "int32"))
    z = relax.Var("z", R.Tensor([], "int32"))
    seq = relax.SeqExpr(
        [
            relax.BindingBlock(
                [
                    relax.VarBinding(x, relax.const(1)),
                    relax.VarBinding(
                        y,
                        relax.SeqExpr(
                            [relax.BindingBlock([relax.VarBinding(z, relax.const(2))])],
                            z,
                        ),
                    ),
                ]
            )
        ],
        y,
    )

    f = relax.Function(
        [],
        seq,
        ret_struct_info=R.Tensor([], "int32"),
    )
    after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f))

    @R.function
    def expected():
        x = relax.const(1)
        z = relax.const(2)
        y = z
        return y

    assert_structural_equal(after_mod["main"], expected)
+
+
def test_normalize_nested_seq_dataflow():
    """Flattening a nested SeqExpr preserves dataflow-block boundaries."""
    x = relax.Var("x", R.Tensor([], "int32"))
    y = relax.Var("y", R.Tensor([], "int32"))
    z = relax.Var("z", R.Tensor([], "int32"))
    # NOTE(review): python name `q` is given IR name "u", duplicating `u`
    # below — presumably harmless since structural equality ignores var
    # names, but confirm this was intentional.
    q = relax.Var("u", R.Tensor([], "int32"))
    w = relax.DataflowVar("w", R.Tensor([], "int32"))
    u = relax.Var("u", R.Tensor([], "int32"))
    seq = relax.SeqExpr(
        [
            relax.BindingBlock(
                [
                    relax.VarBinding(x, relax.const(1)),
                    relax.VarBinding(
                        y,
                        relax.SeqExpr(
                            [
                                relax.BindingBlock([relax.VarBinding(q, relax.const(2))]),
                                relax.DataflowBlock(
                                    [
                                        relax.VarBinding(w, q),
                                        relax.VarBinding(u, w),
                                    ]
                                ),
                                relax.BindingBlock([relax.VarBinding(z, u)]),
                            ],
                            z,
                        ),
                    ),
                ]
            )
        ],
        y,
    )

    f = relax.Function(
        [],
        seq,
        ret_struct_info=R.Tensor([], "int32"),
    )
    after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f))

    @R.function
    def expected():
        x = relax.const(1)
        q = relax.const(2)
        with R.dataflow():
            w = q
            u = w
            R.output(u)
        z = u
        y = z
        return y

    assert_structural_equal(after_mod["main"], expected)
+
+
def test_normalize_deeply_nested_seq():
    """Multiple levels of nested SeqExprs (with MatchCasts) are fully flattened."""
    x = relax.Var("x", R.Tensor([], "int32"))
    y = relax.Var("y", R.Tensor([], "int32"))
    z = relax.Var("z", R.Tensor([], "int32"))
    u = relax.Var("u", R.Tensor([], "int32"))
    v = relax.Var("v", R.Tensor([], "int32"))
    w = relax.Var("w", R.Tensor([], "int32"))
    # NOTE(review): `_` reuses IR name "w" — presumably harmless since
    # structural equality ignores var names, but confirm it was intentional.
    _ = relax.Var("w", R.Tensor([], "int32"))
    seq = relax.SeqExpr(
        [
            relax.BindingBlock(
                [
                    relax.VarBinding(x, relax.const(1)),
                    relax.VarBinding(
                        y,
                        relax.SeqExpr(
                            [
                                relax.BindingBlock(
                                    [
                                        relax.VarBinding(
                                            z,
                                            relax.SeqExpr(
                                                [
                                                    relax.BindingBlock(
                                                        [
                                                            relax.VarBinding(u, relax.const(2)),
                                                            relax.MatchCast(
                                                                _, u, R.Tensor([], "int32")
                                                            ),
                                                            relax.VarBinding(v, u),
                                                            relax.MatchCast(
                                                                w, v, R.Tensor([], "int32")
                                                            ),
                                                        ]
                                                    )
                                                ],
                                                w,
                                            ),
                                        )
                                    ]
                                )
                            ],
                            z,
                        ),
                    ),
                ]
            )
        ],
        y,
    )

    f = relax.Function(
        [],
        seq,
        ret_struct_info=R.Tensor([], "int32"),
    )
    after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f))

    @R.function
    def expected():
        x = relax.const(1)
        u = relax.const(2)
        _ = R.match_cast(u, R.Tensor((), "int32"))
        v = u
        w = R.match_cast(v, R.Tensor((), "int32"))
        z = w
        y = z
        return y

    assert_structural_equal(after_mod["main"], expected)
+
+
@pytest.mark.xfail()
def test_nesting_non_dataflow_in_dataflow_error():
    """A plain BindingBlock nested inside a DataflowBlock should be rejected."""
    x = relax.DataflowVar("x", R.Tensor([], "int32"))
    y = relax.Var("y", R.Tensor([], "int32"))
    z = relax.Var("z", R.Tensor([], "int32"))
    seq = relax.SeqExpr(
        [
            relax.DataflowBlock(
                [
                    relax.VarBinding(x, relax.const(1)),
                    relax.VarBinding(
                        y,
                        relax.SeqExpr(
                            [relax.BindingBlock([relax.VarBinding(z, relax.const(2))])],
                            z,
                        ),
                    ),
                ]
            )
        ],
        y,
    )
    f = relax.Function(
        [],
        seq,
        ret_struct_info=R.Tensor([], "int32"),
    )
    relax.transform.Normalize()(tvm.IRModule.from_expr(f))
    # should fail due to a normal binding block being inside a dataflowblock
+ # should fail due to a normal binding block being inside a dataflowblock
+
+
# Allow running this file directly: discover and run all tests in it.
if __name__ == "__main__":
    tvm.testing.main()