This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ef46f4e8d3 Revert "[SLM] Allow modules to define pre-processing of weights" (#16777)
ef46f4e8d3 is described below

commit ef46f4e8d33f1946dca9cd61f6db5eec79c7deab
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Mar 25 11:10:19 2024 -0400

    Revert "[SLM] Allow modules to define pre-processing of weights" (#16777)
    
    Revert "[SLM] Allow modules to define pre-processing of weights (#16757)"
    
    This reverts commit 1cccc3b5d65cae743a2becb7e256c05897af29ca.
---
 python/tvm/relax/frontend/nn/core.py               |  17 +-
 python/tvm/relax/frontend/nn/exporter.py           |  40 +-
 tests/python/relax/test_frontend_nn_exporter.py    | 443 ---------------------
 .../python/relax/test_frontend_nn_extern_module.py |  10 +-
 tests/python/relax/test_frontend_nn_modules.py     |   3 +-
 tests/python/relax/test_frontend_nn_op.py          |  27 +-
 tests/python/relax/test_frontend_nn_packing.py     |   3 +-
 tests/python/relax/test_frontend_nn_subroutines.py |  13 +-
 8 files changed, 58 insertions(+), 498 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py
index 820acd235d..b7b3f411ed 100644
--- a/python/tvm/relax/frontend/nn/core.py
+++ b/python/tvm/relax/frontend/nn/core.py
@@ -591,22 +591,7 @@ def wrap_nested(expr: rx.Expr, name: str) -> Union[Tensor, 
Sequence[Tensor]]:
         The computed result.
     """
     if not isinstance(expr, rx.DataflowVar):
-        block_builder = BlockBuilder.current()
-        if block_builder is None:
-            # Normalize to make sure we have valid StructInfo, but
-            # wait until we are actually building the function to
-            # flatten nested expressions.
-            #
-            # TODO(Lunderberg): Make this easier to call.  Infering
-            # struct info for a nested expression should be doable in
-            # a free function, without requiring an active
-            # BlockBuilder and an active FunctionFrame.
-            builder = BlockBuilder()
-            with builder.function("dummy_scope", params=[]):
-                expr = builder.normalize(expr)
-                builder.emit_func_output([])
-        else:
-            expr = BlockBuilder.current().emit(expr, name)
+        expr = BlockBuilder.current().emit(expr, name)
     if isinstance(expr.struct_info_, TensorStructInfo):
         return Tensor(_expr=expr)
     if isinstance(expr.struct_info_, TupleStructInfo):
diff --git a/python/tvm/relax/frontend/nn/exporter.py b/python/tvm/relax/frontend/nn/exporter.py
index 525d689f49..1a7dcd6a64 100644
--- a/python/tvm/relax/frontend/nn/exporter.py
+++ b/python/tvm/relax/frontend/nn/exporter.py
@@ -111,8 +111,7 @@ class Exporter:
             return result
 
         # pylint: enable=protected-access
-
-        params = _params()
+        params = None
         effects = _effects()
         ext_mods = self.extern_mods
         with self:
@@ -122,6 +121,7 @@ class Exporter:
                         outputs = _emit_effect_init(self.builder, effects)
                     self.builder.emit_func_output(outputs, params=[])
             for method_name, method_spec in zip(spec.method_names, 
spec.method_specs):
+                params = _params()  # Re-initialize so symbolic shapes not 
shared across methods
                 len_args = len(method_spec.arg_specs)
                 len_effects = {
                     "packed": 1,
@@ -135,18 +135,9 @@ class Exporter:
                     with self.builder.dataflow():
                         outputs, inputs = _emit_method(self.builder, 
method_spec, params, effects)
                     self.builder.emit_func_output(outputs, inputs)
-
-                # TODO(Lunderberg): Make a `ir.transform.ConvertSSA`,
-                # similar to the existing `tir.transform.ConvertSSA`,
-                # that converts an entire module to SSA, including TIR
-                # variable definitions used in either TIR or Relax.
-                mod = self.builder.get()
-                mod[method_name] = 
rx.utils.copy_with_new_vars(mod[method_name])
-
         mod = self.builder.finalize()
         assert rx.analysis.well_formed(mod)
 
-        mod = rx.transform.CanonicalizeBindings()(mod)
         return mod, params, ext_mods
 
 
@@ -170,6 +161,8 @@ def _emit_method(  # pylint: 
disable=too-many-locals,too-many-branches,too-many-
     effects: typing.Optional[typing.List[typing.Tuple[str, core.Effect]]],
 ):
     # pylint: disable=protected-access
+    # symbolic shape's name mapping to its tir.Var for reuse
+    str2var_params: typing.Dict[str, tir.Var] = {}
 
     def _unwrap_ret(expr: typing.Any) -> typing.Any:
         if isinstance(expr, (core.Tensor, core.Object)):
@@ -183,26 +176,35 @@ def _emit_method(  # pylint: 
disable=too-many-locals,too-many-branches,too-many-
     def _convert_input(arg):
         if isinstance(arg, tir.Var):
             return rx.Var(arg.name, struct_info=ShapeStructInfo(values=[arg]))
-        elif isinstance(arg, (core.Tensor, core.Object)):
+        if isinstance(arg, (core.Tensor, core.Object)):
             return arg._expr  # pylint: disable=protected-access
-        elif isinstance(arg, _spec.Tuple):
+        if isinstance(arg, _spec.Tuple):
             return rx.Var(
                 arg.name,
                 struct_info=TupleStructInfo(
                     [_convert_input(arg_i).struct_info for arg_i in 
arg.elements]
                 ),
             )
-        elif isinstance(arg, rx.Expr):
-            return arg
-        else:
-            raise TypeError(f"Unsupported input type: {type(arg)}")
+        raise TypeError(f"Unsupported input type: {type(arg)}")
 
     def _params(mode: str) -> typing.List[rx.Var]:
         inputs: typing.List[rx.Var] = []
 
-        for name, param in params:
-            inputs.append(param._expr)
+        def _get_var(shape_var: tir.Var) -> tir.Var:
+            name = shape_var.name
+            if name in str2var_params:
+                return str2var_params[name]
+            var = tir.Var(name, "int64")
+            str2var_params[name] = var
+            return var
 
+        for name, param in params:
+            # Make sure the a symbolic shape is not re-registered (same as 
_method_spec_to_inputs)
+            # e.g. we do not see `vocab_size` for `lm_head` and `vocab_size_1` 
for `embed_tokens`
+            new_shape = [_get_var(x) if isinstance(x, tir.Var) else x for x in 
param.shape]
+            var = core.Tensor.placeholder(new_shape, param.dtype, name)._expr
+            inputs.append(var)
+            param._expr = var
         if mode == "none":
             return []
         if mode == "plain":
diff --git a/tests/python/relax/test_frontend_nn_exporter.py b/tests/python/relax/test_frontend_nn_exporter.py
deleted file mode 100644
index de8900238b..0000000000
--- a/tests/python/relax/test_frontend_nn_exporter.py
+++ /dev/null
@@ -1,443 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-
-import tvm
-import tvm.testing
-
-from tvm import relax, tir
-from tvm.ir import assert_structural_equal
-from tvm.relax.frontend import nn
-from tvm.script import ir as I, relax as R, tir as T
-
-
-def test_simple():
-    """A module may be exported from nn.Module to Relax"""
-
-    slm_mod = nn.modules.ReLU()
-    exported_mod, _ = slm_mod.export_tvm(
-        spec={"forward": {"x": nn.spec.Tensor((3, 3), "float32")}},
-        debug=False,
-    )
-
-    @I.ir_module
-    class Expected:
-        @R.function
-        def forward(x: R.Tensor([3, 3], dtype="float32")):
-            R.func_attr({"num_input": 1})
-            with R.dataflow():
-                relu = R.nn.relu(x)
-                R.output(relu)
-            return relu
-
-    assert_structural_equal(exported_mod, Expected)
-
-
-def test_custom_module():
-    """A module may be exported from nn.Module to Relax"""
-
-    class Before(nn.Module):
-        def forward(self, x: R.Tensor):
-            return nn.op.relu(x)
-
-    slm_mod = Before()
-    exported_mod, _ = slm_mod.export_tvm(
-        spec={"forward": {"x": nn.spec.Tensor((3, 3), "float32")}},
-        debug=False,
-    )
-
-    @I.ir_module
-    class Expected:
-        @R.function
-        def forward(x: R.Tensor([3, 3], dtype="float32")):
-            R.func_attr({"num_input": 1})
-            with R.dataflow():
-                relu = R.nn.relu(x)
-                R.output(relu)
-            return relu
-
-    assert_structural_equal(exported_mod, Expected)
-
-
-def test_debug_effect():
-    """Passing debug=True provides an argument for IO effect"""
-
-    slm_mod = nn.modules.ReLU()
-    exported_mod, _ = slm_mod.export_tvm(
-        spec={"forward": {"x": nn.spec.Tensor((3, 3), "float32")}},
-        debug=True,
-    )
-
-    @I.ir_module
-    class Expected:
-        @R.function
-        def forward(
-            x: R.Tensor([3, 3], dtype="float32"),
-            _io: R.Object,
-        ):
-            R.func_attr({"num_input": 2})
-            with R.dataflow():
-                relu = R.nn.relu(x)
-                output = relu, (_io,)
-                R.output(output)
-            return output
-
-        @R.function
-        def _initialize_effect():
-            with R.dataflow():
-                _io = R.null_value()
-                output = (_io,)
-                R.output(output)
-            return output
-
-    assert_structural_equal(exported_mod, Expected)
-
-
-def test_dynamic_shape():
-    """An argument may have a dynamic shape"""
-
-    slm_mod = nn.modules.ReLU()
-    exported_mod, _ = slm_mod.export_tvm(
-        spec={"forward": {"x": nn.spec.Tensor([tir.Var("batch_size", "int64"), 
8], "float32")}},
-        debug=False,
-    )
-
-    @I.ir_module
-    class Expected:
-        @R.function
-        def forward(x: R.Tensor(["batch_size", 8], dtype="float32")):
-            R.func_attr({"num_input": 1})
-            with R.dataflow():
-                relu = R.nn.relu(x)
-                R.output(relu)
-            return relu
-
-    assert_structural_equal(exported_mod, Expected)
-
-
-def test_dynamic_shape_in_multiple_functions():
-    """A dynamic shape may be used in multiple functions"""
-
-    class Before(nn.Module):
-        def forward_relu(self, x: nn.Tensor):
-            return nn.relu(x)
-
-        def forward_silu(self, x: nn.Tensor):
-            return nn.silu(x)
-
-    slm_mod = Before()
-    exported_mod, _ = slm_mod.export_tvm(
-        spec={
-            "forward_relu": {"x": nn.spec.Tensor((tir.Var("batch_size", 
"int64"), 8), "float32")},
-            "forward_silu": {"x": nn.spec.Tensor((tir.Var("batch_size", 
"int64"), 8), "float32")},
-        },
-        debug=False,
-    )
-
-    @I.ir_module
-    class Expected:
-        @R.function
-        def forward_relu(x: R.Tensor(["batch_size", 8], dtype="float32")):
-            R.func_attr({"num_input": 1})
-            with R.dataflow():
-                relu = R.nn.relu(x)
-                R.output(relu)
-            return relu
-
-        @R.function
-        def forward_silu(x: R.Tensor(["batch_size", 8], dtype="float32")):
-            R.func_attr({"num_input": 1})
-            with R.dataflow():
-                silu = R.nn.silu(x)
-                R.output(silu)
-            return silu
-
-    assert_structural_equal(exported_mod, Expected)
-
-
-def test_export_nested_module():
-    """nn.Module instances may contain other nn.Module
-
-    When exporting to a Relax IRModule, all `nn.Parameter` instances
-    within the `nn.Module` become Relax function parameters.
-    """
-
-    class LlamaMLP(nn.Module):
-        def __init__(self, hidden_size: int, intermediate_size: int):
-            super().__init__()
-            self.gate_proj = nn.Linear(
-                in_features=hidden_size,
-                out_features=intermediate_size,
-                dtype="float16",
-                bias=False,
-            )
-            self.up_proj = nn.Linear(
-                in_features=hidden_size,
-                out_features=intermediate_size,
-                dtype="float16",
-                bias=False,
-            )
-            self.down_proj = nn.Linear(
-                intermediate_size,
-                hidden_size,
-                dtype="float16",
-                bias=False,
-            )
-
-        def forward(self, x: nn.Tensor):
-            gate = self.gate_proj(x)
-            up = self.up_proj(x)
-            return self.down_proj(nn.op.silu(gate) * up)
-
-    hidden_size = 4096
-    intermediate_size = 11008
-    slm_mod = LlamaMLP(hidden_size=hidden_size, 
intermediate_size=intermediate_size)
-    exported_mod, _ = slm_mod.export_tvm(
-        spec={
-            "forward": {
-                "x": nn.spec.Tensor((tir.Var("batch_size", "int64"), 
hidden_size), "float16")
-            },
-        },
-        debug=False,
-    )
-
-    @I.ir_module
-    class Expected:
-        @R.function
-        def forward(
-            x: R.Tensor(["batch_size", hidden_size], "float16"),
-            gate_proj_weights: R.Tensor([intermediate_size, hidden_size], 
"float16"),
-            up_proj_weights: R.Tensor([intermediate_size, hidden_size], 
"float16"),
-            down_proj_weights: R.Tensor([hidden_size, intermediate_size], 
"float16"),
-        ):
-            R.func_attr({"num_input": 1})
-            batch_size = T.int64()
-            with R.dataflow():
-                gate: R.Tensor([batch_size, intermediate_size]) = R.matmul(
-                    x, R.permute_dims(gate_proj_weights)
-                )
-                up: R.Tensor([batch_size, intermediate_size]) = R.matmul(
-                    x, R.permute_dims(up_proj_weights)
-                )
-                down: R.Tensor([batch_size, hidden_size]) = R.matmul(
-                    R.nn.silu(gate) * up, R.permute_dims(down_proj_weights)
-                )
-                R.output(down)
-            return down
-
-    assert_structural_equal(exported_mod, Expected)
-
-
-def test_generate_parameters():
-    """Weights may be expressions in terms of other parameters
-
-    Optimizations often require preprocessing of the model weights.
-
-    1. Declare the `nn.Module` members that contain the original model
-       weights.  These are used to define the parameter names when
-       reading from a Pytorch or Safetensors file.
-
-    2. Declare the `nn.Module` members, with the `weight` field
-       in terms of the un-optimized weights.  These `nn.Module`
-       do not generate any parameters in the Relax function.
-
-    3. Define the `forward` function in terms of the `nn.Module`
-       members for the updated weight tensors.
-
-    The exported Relax function accepts the original model parameters,
-    computes the pre-processed weights, and then performs computations
-    using the pre-processed weights.
-
-    In this example, the `LiftTransformParams` transform is applied
-    immediately, splitting the Relax function into a pre-processing
-    step and an execution step.  In practice, this transform would be
-    applied much later in an optimization pipeline, to allow optimized
-    compute kernels to be recognized.  For example, in some cases
-    `R.matmul(x, R.permute_dims(weight))` may be computed more
-    efficiently than `R.matmul(x, weight_transpose)`.  For this
-    reason, we do *not* apply `LiftTransformParams` as part of the
-    export from `nn.Module` to Relax.
-
-    """
-
-    class LlamaMLP(nn.Module):
-        def __init__(self, hidden_size: int, intermediate_size: int):
-            super().__init__()
-            # The nn.Linear for the original parameters are present in
-            # the model definition, and are still found when
-            # collecting a function's parameters.
-            self.gate_proj = nn.Linear(
-                in_features=hidden_size,
-                out_features=intermediate_size,
-                dtype="float16",
-                bias=False,
-            )
-            self.up_proj = nn.Linear(
-                in_features=hidden_size,
-                out_features=intermediate_size,
-                dtype="float16",
-                bias=False,
-            )
-            self.down_proj = nn.Linear(
-                intermediate_size,
-                hidden_size,
-                dtype="float16",
-                bias=False,
-            )
-
-            # At runtime, we'd like to have a single concatenated
-            # tensor containing both the gate and up projection
-            # weights.  We also want to use it in the `forward`
-            # function as if it owned its own weights.
-            self.gate_up_proj = nn.Linear(
-                in_features=hidden_size,
-                out_features=intermediate_size,
-                dtype="float16",
-                bias=False,
-            )
-
-            # The weight tensor of `gate_up_proj` can be overwritten
-            # in terms of the original `gate_proj` and `up_proj`
-            # tensors.
-            self.gate_up_proj.weight = nn.op.concat(
-                [self.gate_proj.weight, self.up_proj.weight], dim=0, 
name="gate_up_proj_weights"
-            )
-
-        def forward(self, x: nn.Tensor):
-            # Even though the `gate_up_proj` weights are defined as an
-            # expression rather than a `nn.Parameter`, the `forward`
-            # function does not require any special handling for it.
-            concat_gate_up = self.gate_up_proj(x)
-            gate, up = nn.op.split(concat_gate_up, 2, axis=-1)
-            return self.down_proj(nn.op.silu(gate) * up)
-
-    hidden_size = 4096
-    intermediate_size = 11008
-    slm_mod = LlamaMLP(hidden_size=hidden_size, 
intermediate_size=intermediate_size)
-    exported_mod, _ = slm_mod.export_tvm(
-        spec={
-            "forward": {
-                "x": nn.spec.Tensor((tir.Var("batch_size", "int64"), 
hidden_size), "float16")
-            },
-        },
-        debug=False,
-    )
-
-    @I.ir_module
-    class Expected:
-        @R.function
-        def forward(
-            x: R.Tensor(["batch_size", hidden_size], "float16"),
-            # The function's parameters are defined by the
-            # `nn.Parameter` instances, and still reference the
-            # original `gate_proj` and `up_proj` weights.  This
-            # maintains compatibility with named model weights in a
-            # Pytorch or Safetensors file.
-            gate_proj_weights: R.Tensor([intermediate_size, hidden_size], 
"float16"),
-            up_proj_weights: R.Tensor([intermediate_size, hidden_size], 
"float16"),
-            down_proj_weights: R.Tensor([hidden_size, intermediate_size], 
"float16"),
-        ):
-            R.func_attr({"num_input": 1})
-            batch_size = T.int64()
-            with R.dataflow():
-                # At this stage of compilation, the concatenation is
-                # written within the body of the function.  This will
-                # later be extracted into a pre-processing step using
-                # `relax.transform.LiftTransformParams`.
-                gate_up_proj_weights: R.Tensor(
-                    [intermediate_size * 2, hidden_size], "float16"
-                ) = R.concat([gate_proj_weights, up_proj_weights], axis=0)
-                gate_up: R.Tensor([batch_size, intermediate_size * 2], 
"float16") = R.matmul(
-                    x, R.permute_dims(gate_up_proj_weights)
-                )
-                gate_up_split = R.split(gate_up, 2, axis=-1)
-                gate = gate_up_split[0]
-                up = gate_up_split[1]
-                down: R.Tensor([batch_size, hidden_size], "float16") = 
R.matmul(
-                    R.nn.silu(gate) * up, R.permute_dims(down_proj_weights)
-                )
-                R.output(down)
-            return down
-
-    assert_structural_equal(exported_mod, Expected)
-
-    @I.ir_module
-    class ExpectedAfterLift:
-        @R.function
-        def forward(
-            x: R.Tensor(["batch_size", hidden_size], "float16"),
-            # After `relax.transform.LiftTransformParams`, the
-            # `gate_proj` and `up_proj` weights have been concatenated
-            # together.
-            gate_up_proj_weights_transpose: R.Tensor(
-                [hidden_size, intermediate_size * 2], "float16"
-            ),
-            down_proj_weights_transpose: R.Tensor([intermediate_size, 
hidden_size], "float16"),
-        ):
-            R.func_attr({"num_input": 1})
-            batch_size = T.int64()
-            with R.dataflow():
-                gate_up: R.Tensor([batch_size, intermediate_size * 2], 
"float16") = R.matmul(
-                    x, gate_up_proj_weights_transpose
-                )
-                gate_up_split = R.split(gate_up, 2, axis=-1)
-                gate = gate_up_split[0]
-                up = gate_up_split[1]
-                down: R.Tensor([batch_size, hidden_size], "float16") = 
R.matmul(
-                    R.nn.silu(gate) * up, down_proj_weights_transpose
-                )
-                R.output(down)
-            return down
-
-        @R.function
-        def transform_params(
-            model_params: R.Tuple(
-                R.Tensor([intermediate_size, hidden_size], "float16"),
-                R.Tensor([intermediate_size, hidden_size], "float16"),
-                R.Tensor([hidden_size, intermediate_size], "float16"),
-            )
-        ):
-            R.func_attr({"num_input": 0})
-            with R.dataflow():
-                gate_proj_weights: R.Tensor(
-                    [intermediate_size, hidden_size], "float16"
-                ) = model_params[0]
-                up_proj_weights: R.Tensor(
-                    [intermediate_size, hidden_size], "float16"
-                ) = model_params[1]
-                gate_up_proj_weights: R.Tensor(
-                    [intermediate_size * 2, hidden_size], "float16"
-                ) = R.concat([gate_proj_weights, up_proj_weights], axis=0)
-                gate_up_proj_weights_transpose: R.Tensor(
-                    [hidden_size, intermediate_size * 2], "float16"
-                ) = R.permute_dims(gate_up_proj_weights)
-                down_proj_weights: R.Tensor(
-                    [hidden_size, intermediate_size], "float16"
-                ) = model_params[2]
-                down_proj_weights_transpose: R.Tensor(
-                    [intermediate_size, hidden_size], "float16"
-                ) = R.permute_dims(down_proj_weights)
-                output = (gate_up_proj_weights_transpose, 
down_proj_weights_transpose)
-                R.output(output)
-            return output
-
-    lifted_mod = 
relax.transform.LiftTransformParams(shared_transform=True)(exported_mod)
-    assert_structural_equal(lifted_mod, ExpectedAfterLift)
-
-
-if __name__ == "__main__":
-    tvm.testing.main()
diff --git a/tests/python/relax/test_frontend_nn_extern_module.py b/tests/python/relax/test_frontend_nn_extern_module.py
index 6ca7742422..6eaf1fbfc8 100644
--- a/tests/python/relax/test_frontend_nn_extern_module.py
+++ b/tests/python/relax/test_frontend_nn_extern_module.py
@@ -94,8 +94,9 @@ def _check_ir_equality(mod):
                 ext_scalar_add = R.call_dps_packed(
                     "ext_scalar_add", (a, b), out_sinfo=R.Tensor((), 
dtype="float32")
                 )
-                R.output(ext_scalar_add)
-            return ext_scalar_add
+                gv: R.Tensor((), dtype="float32") = ext_scalar_add
+                R.output(gv)
+            return gv
 
         @R.function
         def test_sym(
@@ -109,8 +110,9 @@ def _check_ir_equality(mod):
                 ext_test_sym = R.call_dps_packed(
                     "ext_test_sym", (a, b), out_sinfo=R.Tensor((x, y, z, 9), 
dtype="float32")
                 )
-                R.output(ext_test_sym)
-            return ext_test_sym
+                gv1: R.Tensor((x, y, z, 9), dtype="float32") = ext_test_sym
+                R.output(gv1)
+            return gv1
 
     tvm.ir.assert_structural_equal(ExpectedModule, mod)
 
diff --git a/tests/python/relax/test_frontend_nn_modules.py b/tests/python/relax/test_frontend_nn_modules.py
index 45128749e2..5ddc105055 100644
--- a/tests/python/relax/test_frontend_nn_modules.py
+++ b/tests/python/relax/test_frontend_nn_modules.py
@@ -493,7 +493,8 @@ def test_kv_cache():
                     R.prim_value(0),
                     sinfo_args=[R.Object()],
                 )
-                gv = _io, cache
+                lv1 = _io, cache
+                gv = lv1
                 R.output(gv)
             return gv
 
diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py
index 68f86bba50..7d78e47c94 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -538,7 +538,8 @@ def test_tensor_expr_op():
         def _initialize_effect() -> R.Tuple(R.Object):
             with R.dataflow():
                 _io: R.Object = R.null_value()
-                gv = (_io,)
+                lv: R.Tuple(R.Object) = (_io,)
+                gv: R.Tuple(R.Object) = lv
                 R.output(gv)
             return gv
 
@@ -610,7 +611,8 @@ def test_tensor_ir_op():
         def _initialize_effect() -> R.Tuple(R.Object):
             with R.dataflow():
                 _io: R.Object = R.null_value()
-                gv = (_io,)
+                lv: R.Tuple(R.Object) = (_io,)
+                gv: R.Tuple(R.Object) = lv
                 R.output(gv)
             return gv
 
@@ -697,7 +699,8 @@ def test_tensor_ir_inplace_op():
         def _initialize_effect() -> R.Tuple(R.Object):
             with R.dataflow():
                 _io: R.Object = R.null_value()
-                gv = (_io,)
+                lv: R.Tuple(R.Object) = (_io,)
+                gv: R.Tuple(R.Object) = lv
                 R.output(gv)
             return gv
 
@@ -714,12 +717,13 @@ def test_tensor_ir_inplace_op():
             R.func_attr({"num_input": 4})
             cls = Expected
             with R.dataflow():
-                gv1 = R.call_tir(
+                lv1 = R.call_tir(
                     cls.inplace_take,
                     (embedding_table, input_ids, embedding_dst),
                     out_sinfo=R.Tensor((total_seq_len, hidden_size), dtype),
                     tir_vars=R.shape([offset_1]),
                 )
+                gv1: R.Tensor((total_seq_len, hidden_size), dtype) = lv1
                 R.output(gv1)
             return gv1
 
@@ -768,7 +772,8 @@ def test_tensor_ir_op_no_tir_var():
             R.func_attr({"num_input": 1})
             cls = Expected
             with R.dataflow():
-                gv = R.call_tir(cls.tir_func, (A,), out_sinfo=R.Tensor((16, 
16), dtype="float32"))
+                lv = R.call_tir(cls.tir_func, (A,), out_sinfo=R.Tensor((16, 
16), dtype="float32"))
+                gv: R.Tensor((16, 16), dtype="float32") = lv
                 R.output(gv)
             return gv
 
@@ -795,7 +800,8 @@ def test_extern():
         def _initialize_effect() -> R.Tuple(R.Object):
             with R.dataflow():
                 _io: R.Object = R.null_value()
-                gv = (_io,)
+                lv: R.Tuple(R.Object) = (_io,)
+                gv: R.Tuple(R.Object) = lv
                 R.output(gv)
             return gv
 
@@ -882,7 +888,8 @@ def test_multinomial_from_uniform():
         def _initialize_effect() -> R.Tuple(R.Object):
             with R.dataflow():
                 _io: R.Object = R.null_value()
-                gv = (_io,)
+                lv: R.Tuple(R.Object) = (_io,)
+                gv: R.Tuple(R.Object) = lv
                 R.output(gv)
             return gv
 
@@ -1008,7 +1015,8 @@ def test_sample_top_p_top_k_from_sorted_prob():
         def _initialize_effect() -> R.Tuple(R.Object):
             with R.dataflow():
                 _io: R.Object = R.null_value()
-                gv: R.Tuple(R.Object) = (_io,)
+                lv: R.Tuple(R.Object) = (_io,)
+                gv: R.Tuple(R.Object) = lv
                 R.output(gv)
             return gv
 
@@ -1122,7 +1130,8 @@ def test_renormalize_top_p_top_k_prob():
         def _initialize_effect() -> R.Tuple(R.Object):
             with R.dataflow():
                 _io: R.Object = R.null_value()
-                gv: R.Tuple(R.Object) = (_io,)
+                lv: R.Tuple(R.Object) = (_io,)
+                gv: R.Tuple(R.Object) = lv
                 R.output(gv)
             return gv
 
diff --git a/tests/python/relax/test_frontend_nn_packing.py b/tests/python/relax/test_frontend_nn_packing.py
index c2cc22c17d..56b614a807 100644
--- a/tests/python/relax/test_frontend_nn_packing.py
+++ b/tests/python/relax/test_frontend_nn_packing.py
@@ -59,7 +59,8 @@ def test_nn_export_to_relax():
                 matmul = R.matmul(x, matmul_1_weight)
                 matmul_2_weight = R.permute_dims(linear_2_weight)
                 matmul1 = R.matmul(x, matmul_2_weight)
-                gv = R.add(matmul, matmul1)
+                add = R.add(matmul, matmul1)
+                gv = add
                 R.output(gv)
             return gv
 
diff --git a/tests/python/relax/test_frontend_nn_subroutines.py b/tests/python/relax/test_frontend_nn_subroutines.py
index 32ae967916..6bbf57aead 100644
--- a/tests/python/relax/test_frontend_nn_subroutines.py
+++ b/tests/python/relax/test_frontend_nn_subroutines.py
@@ -61,7 +61,8 @@ def test_linear():
         def _initialize_effect() -> R.Tuple(R.Object):
             with R.dataflow():
                 _io: R.Object = R.null_value()
-                gv = (_io,)
+                lv: R.Tuple(R.Object) = (_io,)
+                gv: R.Tuple(R.Object) = lv
                 R.output(gv)
 
             return gv
@@ -74,8 +75,9 @@ def test_linear():
             with R.dataflow():
                 state = R.matmul(state, weights)
                 state = Expected.activation(state)
-                R.output(state)
-            return state
+                dataflow_output = state
+                R.output(dataflow_output)
+            return dataflow_output
 
         @R.function(private=True)
         def activation(
@@ -83,8 +85,9 @@ def test_linear():
         ) -> R.Tensor(("batch_size", 32), dtype="float32"):
             with R.dataflow():
                 state = R.nn.silu(state)
-                R.output(state)
-            return state
+                dataflow_output = state
+                R.output(dataflow_output)
+            return dataflow_output
 
     mod = Layer(64, 32)
     batch_size = tvm.tir.Var("batch_size", "int64")

Reply via email to