This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ef46f4e8d3 Revert "[SLM] Allow modules to define pre-processing of weights" (#16777)
ef46f4e8d3 is described below
commit ef46f4e8d33f1946dca9cd61f6db5eec79c7deab
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Mar 25 11:10:19 2024 -0400
Revert "[SLM] Allow modules to define pre-processing of weights" (#16777)
Revert "[SLM] Allow modules to define pre-processing of weights (#16757)"
This reverts commit 1cccc3b5d65cae743a2becb7e256c05897af29ca.
---
python/tvm/relax/frontend/nn/core.py | 17 +-
python/tvm/relax/frontend/nn/exporter.py | 40 +-
tests/python/relax/test_frontend_nn_exporter.py | 443 ---------------------
.../python/relax/test_frontend_nn_extern_module.py | 10 +-
tests/python/relax/test_frontend_nn_modules.py | 3 +-
tests/python/relax/test_frontend_nn_op.py | 27 +-
tests/python/relax/test_frontend_nn_packing.py | 3 +-
tests/python/relax/test_frontend_nn_subroutines.py | 13 +-
8 files changed, 58 insertions(+), 498 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py
index 820acd235d..b7b3f411ed 100644
--- a/python/tvm/relax/frontend/nn/core.py
+++ b/python/tvm/relax/frontend/nn/core.py
@@ -591,22 +591,7 @@ def wrap_nested(expr: rx.Expr, name: str) -> Union[Tensor,
Sequence[Tensor]]:
The computed result.
"""
if not isinstance(expr, rx.DataflowVar):
- block_builder = BlockBuilder.current()
- if block_builder is None:
- # Normalize to make sure we have valid StructInfo, but
- # wait until we are actually building the function to
- # flatten nested expressions.
- #
- # TODO(Lunderberg): Make this easier to call. Infering
- # struct info for a nested expression should be doable in
- # a free function, without requiring an active
- # BlockBuilder and an active FunctionFrame.
- builder = BlockBuilder()
- with builder.function("dummy_scope", params=[]):
- expr = builder.normalize(expr)
- builder.emit_func_output([])
- else:
- expr = BlockBuilder.current().emit(expr, name)
+ expr = BlockBuilder.current().emit(expr, name)
if isinstance(expr.struct_info_, TensorStructInfo):
return Tensor(_expr=expr)
if isinstance(expr.struct_info_, TupleStructInfo):
diff --git a/python/tvm/relax/frontend/nn/exporter.py b/python/tvm/relax/frontend/nn/exporter.py
index 525d689f49..1a7dcd6a64 100644
--- a/python/tvm/relax/frontend/nn/exporter.py
+++ b/python/tvm/relax/frontend/nn/exporter.py
@@ -111,8 +111,7 @@ class Exporter:
return result
# pylint: enable=protected-access
-
- params = _params()
+ params = None
effects = _effects()
ext_mods = self.extern_mods
with self:
@@ -122,6 +121,7 @@ class Exporter:
outputs = _emit_effect_init(self.builder, effects)
self.builder.emit_func_output(outputs, params=[])
for method_name, method_spec in zip(spec.method_names,
spec.method_specs):
+ params = _params() # Re-initialize so symbolic shapes not
shared across methods
len_args = len(method_spec.arg_specs)
len_effects = {
"packed": 1,
@@ -135,18 +135,9 @@ class Exporter:
with self.builder.dataflow():
outputs, inputs = _emit_method(self.builder,
method_spec, params, effects)
self.builder.emit_func_output(outputs, inputs)
-
- # TODO(Lunderberg): Make a `ir.transform.ConvertSSA`,
- # similar to the existing `tir.transform.ConvertSSA`,
- # that converts an entire module to SSA, including TIR
- # variable definitions used in either TIR or Relax.
- mod = self.builder.get()
- mod[method_name] =
rx.utils.copy_with_new_vars(mod[method_name])
-
mod = self.builder.finalize()
assert rx.analysis.well_formed(mod)
- mod = rx.transform.CanonicalizeBindings()(mod)
return mod, params, ext_mods
@@ -170,6 +161,8 @@ def _emit_method( # pylint:
disable=too-many-locals,too-many-branches,too-many-
effects: typing.Optional[typing.List[typing.Tuple[str, core.Effect]]],
):
# pylint: disable=protected-access
+ # symbolic shape's name mapping to its tir.Var for reuse
+ str2var_params: typing.Dict[str, tir.Var] = {}
def _unwrap_ret(expr: typing.Any) -> typing.Any:
if isinstance(expr, (core.Tensor, core.Object)):
@@ -183,26 +176,35 @@ def _emit_method( # pylint:
disable=too-many-locals,too-many-branches,too-many-
def _convert_input(arg):
if isinstance(arg, tir.Var):
return rx.Var(arg.name, struct_info=ShapeStructInfo(values=[arg]))
- elif isinstance(arg, (core.Tensor, core.Object)):
+ if isinstance(arg, (core.Tensor, core.Object)):
return arg._expr # pylint: disable=protected-access
- elif isinstance(arg, _spec.Tuple):
+ if isinstance(arg, _spec.Tuple):
return rx.Var(
arg.name,
struct_info=TupleStructInfo(
[_convert_input(arg_i).struct_info for arg_i in
arg.elements]
),
)
- elif isinstance(arg, rx.Expr):
- return arg
- else:
- raise TypeError(f"Unsupported input type: {type(arg)}")
+ raise TypeError(f"Unsupported input type: {type(arg)}")
def _params(mode: str) -> typing.List[rx.Var]:
inputs: typing.List[rx.Var] = []
- for name, param in params:
- inputs.append(param._expr)
+ def _get_var(shape_var: tir.Var) -> tir.Var:
+ name = shape_var.name
+ if name in str2var_params:
+ return str2var_params[name]
+ var = tir.Var(name, "int64")
+ str2var_params[name] = var
+ return var
+ for name, param in params:
+ # Make sure the a symbolic shape is not re-registered (same as
_method_spec_to_inputs)
+ # e.g. we do not see `vocab_size` for `lm_head` and `vocab_size_1`
for `embed_tokens`
+ new_shape = [_get_var(x) if isinstance(x, tir.Var) else x for x in
param.shape]
+ var = core.Tensor.placeholder(new_shape, param.dtype, name)._expr
+ inputs.append(var)
+ param._expr = var
if mode == "none":
return []
if mode == "plain":
diff --git a/tests/python/relax/test_frontend_nn_exporter.py b/tests/python/relax/test_frontend_nn_exporter.py
deleted file mode 100644
index de8900238b..0000000000
--- a/tests/python/relax/test_frontend_nn_exporter.py
+++ /dev/null
@@ -1,443 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-
-import tvm
-import tvm.testing
-
-from tvm import relax, tir
-from tvm.ir import assert_structural_equal
-from tvm.relax.frontend import nn
-from tvm.script import ir as I, relax as R, tir as T
-
-
-def test_simple():
- """A module may be exported from nn.Module to Relax"""
-
- slm_mod = nn.modules.ReLU()
- exported_mod, _ = slm_mod.export_tvm(
- spec={"forward": {"x": nn.spec.Tensor((3, 3), "float32")}},
- debug=False,
- )
-
- @I.ir_module
- class Expected:
- @R.function
- def forward(x: R.Tensor([3, 3], dtype="float32")):
- R.func_attr({"num_input": 1})
- with R.dataflow():
- relu = R.nn.relu(x)
- R.output(relu)
- return relu
-
- assert_structural_equal(exported_mod, Expected)
-
-
-def test_custom_module():
- """A module may be exported from nn.Module to Relax"""
-
- class Before(nn.Module):
- def forward(self, x: R.Tensor):
- return nn.op.relu(x)
-
- slm_mod = Before()
- exported_mod, _ = slm_mod.export_tvm(
- spec={"forward": {"x": nn.spec.Tensor((3, 3), "float32")}},
- debug=False,
- )
-
- @I.ir_module
- class Expected:
- @R.function
- def forward(x: R.Tensor([3, 3], dtype="float32")):
- R.func_attr({"num_input": 1})
- with R.dataflow():
- relu = R.nn.relu(x)
- R.output(relu)
- return relu
-
- assert_structural_equal(exported_mod, Expected)
-
-
-def test_debug_effect():
- """Passing debug=True provides an argument for IO effect"""
-
- slm_mod = nn.modules.ReLU()
- exported_mod, _ = slm_mod.export_tvm(
- spec={"forward": {"x": nn.spec.Tensor((3, 3), "float32")}},
- debug=True,
- )
-
- @I.ir_module
- class Expected:
- @R.function
- def forward(
- x: R.Tensor([3, 3], dtype="float32"),
- _io: R.Object,
- ):
- R.func_attr({"num_input": 2})
- with R.dataflow():
- relu = R.nn.relu(x)
- output = relu, (_io,)
- R.output(output)
- return output
-
- @R.function
- def _initialize_effect():
- with R.dataflow():
- _io = R.null_value()
- output = (_io,)
- R.output(output)
- return output
-
- assert_structural_equal(exported_mod, Expected)
-
-
-def test_dynamic_shape():
- """An argument may have a dynamic shape"""
-
- slm_mod = nn.modules.ReLU()
- exported_mod, _ = slm_mod.export_tvm(
- spec={"forward": {"x": nn.spec.Tensor([tir.Var("batch_size", "int64"),
8], "float32")}},
- debug=False,
- )
-
- @I.ir_module
- class Expected:
- @R.function
- def forward(x: R.Tensor(["batch_size", 8], dtype="float32")):
- R.func_attr({"num_input": 1})
- with R.dataflow():
- relu = R.nn.relu(x)
- R.output(relu)
- return relu
-
- assert_structural_equal(exported_mod, Expected)
-
-
-def test_dynamic_shape_in_multiple_functions():
- """A dynamic shape may be used in multiple functions"""
-
- class Before(nn.Module):
- def forward_relu(self, x: nn.Tensor):
- return nn.relu(x)
-
- def forward_silu(self, x: nn.Tensor):
- return nn.silu(x)
-
- slm_mod = Before()
- exported_mod, _ = slm_mod.export_tvm(
- spec={
- "forward_relu": {"x": nn.spec.Tensor((tir.Var("batch_size",
"int64"), 8), "float32")},
- "forward_silu": {"x": nn.spec.Tensor((tir.Var("batch_size",
"int64"), 8), "float32")},
- },
- debug=False,
- )
-
- @I.ir_module
- class Expected:
- @R.function
- def forward_relu(x: R.Tensor(["batch_size", 8], dtype="float32")):
- R.func_attr({"num_input": 1})
- with R.dataflow():
- relu = R.nn.relu(x)
- R.output(relu)
- return relu
-
- @R.function
- def forward_silu(x: R.Tensor(["batch_size", 8], dtype="float32")):
- R.func_attr({"num_input": 1})
- with R.dataflow():
- silu = R.nn.silu(x)
- R.output(silu)
- return silu
-
- assert_structural_equal(exported_mod, Expected)
-
-
-def test_export_nested_module():
- """nn.Module instances may contain other nn.Module
-
- When exporting to a Relax IRModule, all `nn.Parameter` instances
- within the `nn.Module` become Relax function parameters.
- """
-
- class LlamaMLP(nn.Module):
- def __init__(self, hidden_size: int, intermediate_size: int):
- super().__init__()
- self.gate_proj = nn.Linear(
- in_features=hidden_size,
- out_features=intermediate_size,
- dtype="float16",
- bias=False,
- )
- self.up_proj = nn.Linear(
- in_features=hidden_size,
- out_features=intermediate_size,
- dtype="float16",
- bias=False,
- )
- self.down_proj = nn.Linear(
- intermediate_size,
- hidden_size,
- dtype="float16",
- bias=False,
- )
-
- def forward(self, x: nn.Tensor):
- gate = self.gate_proj(x)
- up = self.up_proj(x)
- return self.down_proj(nn.op.silu(gate) * up)
-
- hidden_size = 4096
- intermediate_size = 11008
- slm_mod = LlamaMLP(hidden_size=hidden_size,
intermediate_size=intermediate_size)
- exported_mod, _ = slm_mod.export_tvm(
- spec={
- "forward": {
- "x": nn.spec.Tensor((tir.Var("batch_size", "int64"),
hidden_size), "float16")
- },
- },
- debug=False,
- )
-
- @I.ir_module
- class Expected:
- @R.function
- def forward(
- x: R.Tensor(["batch_size", hidden_size], "float16"),
- gate_proj_weights: R.Tensor([intermediate_size, hidden_size],
"float16"),
- up_proj_weights: R.Tensor([intermediate_size, hidden_size],
"float16"),
- down_proj_weights: R.Tensor([hidden_size, intermediate_size],
"float16"),
- ):
- R.func_attr({"num_input": 1})
- batch_size = T.int64()
- with R.dataflow():
- gate: R.Tensor([batch_size, intermediate_size]) = R.matmul(
- x, R.permute_dims(gate_proj_weights)
- )
- up: R.Tensor([batch_size, intermediate_size]) = R.matmul(
- x, R.permute_dims(up_proj_weights)
- )
- down: R.Tensor([batch_size, hidden_size]) = R.matmul(
- R.nn.silu(gate) * up, R.permute_dims(down_proj_weights)
- )
- R.output(down)
- return down
-
- assert_structural_equal(exported_mod, Expected)
-
-
-def test_generate_parameters():
- """Weights may be expressions in terms of other parameters
-
- Optimizations often require preprocessing of the model weights.
-
- 1. Declare the `nn.Module` members that contain the original model
- weights. These are used to define the parameter names when
- reading from a Pytorch or Safetensors file.
-
- 2. Declare the `nn.Module` members, with the `weight` field
- in terms of the un-optimized weights. These `nn.Module`
- do not generate any parameters in the Relax function.
-
- 3. Define the `forward` function in terms of the `nn.Module`
- members for the updated weight tensors.
-
- The exported Relax function accepts the original model parameters,
- computes the pre-processed weights, and then performs computations
- using the pre-processed weights.
-
- In this example, the `LiftTransformParams` transform is applied
- immediately, splitting the Relax function into a pre-processing
- step and an execution step. In practice, this transform would be
- applied much later in an optimization pipeline, to allow optimized
- compute kernels to be recognized. For example, in some cases
- `R.matmul(x, R.permute_dims(weight))` may be computed more
- efficiently than `R.matmul(x, weight_transpose)`. For this
- reason, we do *not* apply `LiftTransformParams` as part of the
- export from `nn.Module` to Relax.
-
- """
-
- class LlamaMLP(nn.Module):
- def __init__(self, hidden_size: int, intermediate_size: int):
- super().__init__()
- # The nn.Linear for the original parameters are present in
- # the model definition, and are still found when
- # collecting a function's parameters.
- self.gate_proj = nn.Linear(
- in_features=hidden_size,
- out_features=intermediate_size,
- dtype="float16",
- bias=False,
- )
- self.up_proj = nn.Linear(
- in_features=hidden_size,
- out_features=intermediate_size,
- dtype="float16",
- bias=False,
- )
- self.down_proj = nn.Linear(
- intermediate_size,
- hidden_size,
- dtype="float16",
- bias=False,
- )
-
- # At runtime, we'd like to have a single concatenated
- # tensor containing both the gate and up projection
- # weights. We also want to use it in the `forward`
- # function as if it owned its own weights.
- self.gate_up_proj = nn.Linear(
- in_features=hidden_size,
- out_features=intermediate_size,
- dtype="float16",
- bias=False,
- )
-
- # The weight tensor of `gate_up_proj` can be overwritten
- # in terms of the original `gate_proj` and `up_proj`
- # tensors.
- self.gate_up_proj.weight = nn.op.concat(
- [self.gate_proj.weight, self.up_proj.weight], dim=0,
name="gate_up_proj_weights"
- )
-
- def forward(self, x: nn.Tensor):
- # Even though the `gate_up_proj` weights are defined as an
- # expression rather than a `nn.Parameter`, the `forward`
- # function does not require any special handling for it.
- concat_gate_up = self.gate_up_proj(x)
- gate, up = nn.op.split(concat_gate_up, 2, axis=-1)
- return self.down_proj(nn.op.silu(gate) * up)
-
- hidden_size = 4096
- intermediate_size = 11008
- slm_mod = LlamaMLP(hidden_size=hidden_size,
intermediate_size=intermediate_size)
- exported_mod, _ = slm_mod.export_tvm(
- spec={
- "forward": {
- "x": nn.spec.Tensor((tir.Var("batch_size", "int64"),
hidden_size), "float16")
- },
- },
- debug=False,
- )
-
- @I.ir_module
- class Expected:
- @R.function
- def forward(
- x: R.Tensor(["batch_size", hidden_size], "float16"),
- # The function's parameters are defined by the
- # `nn.Parameter` instances, and still reference the
- # original `gate_proj` and `up_proj` weights. This
- # maintains compatibility with named model weights in a
- # Pytorch or Safetensors file.
- gate_proj_weights: R.Tensor([intermediate_size, hidden_size],
"float16"),
- up_proj_weights: R.Tensor([intermediate_size, hidden_size],
"float16"),
- down_proj_weights: R.Tensor([hidden_size, intermediate_size],
"float16"),
- ):
- R.func_attr({"num_input": 1})
- batch_size = T.int64()
- with R.dataflow():
- # At this stage of compilation, the concatenation is
- # written within the body of the function. This will
- # later be extracted into a pre-processing step using
- # `relax.transform.LiftTransformParams`.
- gate_up_proj_weights: R.Tensor(
- [intermediate_size * 2, hidden_size], "float16"
- ) = R.concat([gate_proj_weights, up_proj_weights], axis=0)
- gate_up: R.Tensor([batch_size, intermediate_size * 2],
"float16") = R.matmul(
- x, R.permute_dims(gate_up_proj_weights)
- )
- gate_up_split = R.split(gate_up, 2, axis=-1)
- gate = gate_up_split[0]
- up = gate_up_split[1]
- down: R.Tensor([batch_size, hidden_size], "float16") =
R.matmul(
- R.nn.silu(gate) * up, R.permute_dims(down_proj_weights)
- )
- R.output(down)
- return down
-
- assert_structural_equal(exported_mod, Expected)
-
- @I.ir_module
- class ExpectedAfterLift:
- @R.function
- def forward(
- x: R.Tensor(["batch_size", hidden_size], "float16"),
- # After `relax.transform.LiftTransformParams`, the
- # `gate_proj` and `up_proj` weights have been concatenated
- # together.
- gate_up_proj_weights_transpose: R.Tensor(
- [hidden_size, intermediate_size * 2], "float16"
- ),
- down_proj_weights_transpose: R.Tensor([intermediate_size,
hidden_size], "float16"),
- ):
- R.func_attr({"num_input": 1})
- batch_size = T.int64()
- with R.dataflow():
- gate_up: R.Tensor([batch_size, intermediate_size * 2],
"float16") = R.matmul(
- x, gate_up_proj_weights_transpose
- )
- gate_up_split = R.split(gate_up, 2, axis=-1)
- gate = gate_up_split[0]
- up = gate_up_split[1]
- down: R.Tensor([batch_size, hidden_size], "float16") =
R.matmul(
- R.nn.silu(gate) * up, down_proj_weights_transpose
- )
- R.output(down)
- return down
-
- @R.function
- def transform_params(
- model_params: R.Tuple(
- R.Tensor([intermediate_size, hidden_size], "float16"),
- R.Tensor([intermediate_size, hidden_size], "float16"),
- R.Tensor([hidden_size, intermediate_size], "float16"),
- )
- ):
- R.func_attr({"num_input": 0})
- with R.dataflow():
- gate_proj_weights: R.Tensor(
- [intermediate_size, hidden_size], "float16"
- ) = model_params[0]
- up_proj_weights: R.Tensor(
- [intermediate_size, hidden_size], "float16"
- ) = model_params[1]
- gate_up_proj_weights: R.Tensor(
- [intermediate_size * 2, hidden_size], "float16"
- ) = R.concat([gate_proj_weights, up_proj_weights], axis=0)
- gate_up_proj_weights_transpose: R.Tensor(
- [hidden_size, intermediate_size * 2], "float16"
- ) = R.permute_dims(gate_up_proj_weights)
- down_proj_weights: R.Tensor(
- [hidden_size, intermediate_size], "float16"
- ) = model_params[2]
- down_proj_weights_transpose: R.Tensor(
- [intermediate_size, hidden_size], "float16"
- ) = R.permute_dims(down_proj_weights)
- output = (gate_up_proj_weights_transpose,
down_proj_weights_transpose)
- R.output(output)
- return output
-
- lifted_mod =
relax.transform.LiftTransformParams(shared_transform=True)(exported_mod)
- assert_structural_equal(lifted_mod, ExpectedAfterLift)
-
-
-if __name__ == "__main__":
- tvm.testing.main()
diff --git a/tests/python/relax/test_frontend_nn_extern_module.py b/tests/python/relax/test_frontend_nn_extern_module.py
index 6ca7742422..6eaf1fbfc8 100644
--- a/tests/python/relax/test_frontend_nn_extern_module.py
+++ b/tests/python/relax/test_frontend_nn_extern_module.py
@@ -94,8 +94,9 @@ def _check_ir_equality(mod):
ext_scalar_add = R.call_dps_packed(
"ext_scalar_add", (a, b), out_sinfo=R.Tensor((),
dtype="float32")
)
- R.output(ext_scalar_add)
- return ext_scalar_add
+ gv: R.Tensor((), dtype="float32") = ext_scalar_add
+ R.output(gv)
+ return gv
@R.function
def test_sym(
@@ -109,8 +110,9 @@ def _check_ir_equality(mod):
ext_test_sym = R.call_dps_packed(
"ext_test_sym", (a, b), out_sinfo=R.Tensor((x, y, z, 9),
dtype="float32")
)
- R.output(ext_test_sym)
- return ext_test_sym
+ gv1: R.Tensor((x, y, z, 9), dtype="float32") = ext_test_sym
+ R.output(gv1)
+ return gv1
tvm.ir.assert_structural_equal(ExpectedModule, mod)
diff --git a/tests/python/relax/test_frontend_nn_modules.py b/tests/python/relax/test_frontend_nn_modules.py
index 45128749e2..5ddc105055 100644
--- a/tests/python/relax/test_frontend_nn_modules.py
+++ b/tests/python/relax/test_frontend_nn_modules.py
@@ -493,7 +493,8 @@ def test_kv_cache():
R.prim_value(0),
sinfo_args=[R.Object()],
)
- gv = _io, cache
+ lv1 = _io, cache
+ gv = lv1
R.output(gv)
return gv
diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py
index 68f86bba50..7d78e47c94 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -538,7 +538,8 @@ def test_tensor_expr_op():
def _initialize_effect() -> R.Tuple(R.Object):
with R.dataflow():
_io: R.Object = R.null_value()
- gv = (_io,)
+ lv: R.Tuple(R.Object) = (_io,)
+ gv: R.Tuple(R.Object) = lv
R.output(gv)
return gv
@@ -610,7 +611,8 @@ def test_tensor_ir_op():
def _initialize_effect() -> R.Tuple(R.Object):
with R.dataflow():
_io: R.Object = R.null_value()
- gv = (_io,)
+ lv: R.Tuple(R.Object) = (_io,)
+ gv: R.Tuple(R.Object) = lv
R.output(gv)
return gv
@@ -697,7 +699,8 @@ def test_tensor_ir_inplace_op():
def _initialize_effect() -> R.Tuple(R.Object):
with R.dataflow():
_io: R.Object = R.null_value()
- gv = (_io,)
+ lv: R.Tuple(R.Object) = (_io,)
+ gv: R.Tuple(R.Object) = lv
R.output(gv)
return gv
@@ -714,12 +717,13 @@ def test_tensor_ir_inplace_op():
R.func_attr({"num_input": 4})
cls = Expected
with R.dataflow():
- gv1 = R.call_tir(
+ lv1 = R.call_tir(
cls.inplace_take,
(embedding_table, input_ids, embedding_dst),
out_sinfo=R.Tensor((total_seq_len, hidden_size), dtype),
tir_vars=R.shape([offset_1]),
)
+ gv1: R.Tensor((total_seq_len, hidden_size), dtype) = lv1
R.output(gv1)
return gv1
@@ -768,7 +772,8 @@ def test_tensor_ir_op_no_tir_var():
R.func_attr({"num_input": 1})
cls = Expected
with R.dataflow():
- gv = R.call_tir(cls.tir_func, (A,), out_sinfo=R.Tensor((16,
16), dtype="float32"))
+ lv = R.call_tir(cls.tir_func, (A,), out_sinfo=R.Tensor((16,
16), dtype="float32"))
+ gv: R.Tensor((16, 16), dtype="float32") = lv
R.output(gv)
return gv
@@ -795,7 +800,8 @@ def test_extern():
def _initialize_effect() -> R.Tuple(R.Object):
with R.dataflow():
_io: R.Object = R.null_value()
- gv = (_io,)
+ lv: R.Tuple(R.Object) = (_io,)
+ gv: R.Tuple(R.Object) = lv
R.output(gv)
return gv
@@ -882,7 +888,8 @@ def test_multinomial_from_uniform():
def _initialize_effect() -> R.Tuple(R.Object):
with R.dataflow():
_io: R.Object = R.null_value()
- gv = (_io,)
+ lv: R.Tuple(R.Object) = (_io,)
+ gv: R.Tuple(R.Object) = lv
R.output(gv)
return gv
@@ -1008,7 +1015,8 @@ def test_sample_top_p_top_k_from_sorted_prob():
def _initialize_effect() -> R.Tuple(R.Object):
with R.dataflow():
_io: R.Object = R.null_value()
- gv: R.Tuple(R.Object) = (_io,)
+ lv: R.Tuple(R.Object) = (_io,)
+ gv: R.Tuple(R.Object) = lv
R.output(gv)
return gv
@@ -1122,7 +1130,8 @@ def test_renormalize_top_p_top_k_prob():
def _initialize_effect() -> R.Tuple(R.Object):
with R.dataflow():
_io: R.Object = R.null_value()
- gv: R.Tuple(R.Object) = (_io,)
+ lv: R.Tuple(R.Object) = (_io,)
+ gv: R.Tuple(R.Object) = lv
R.output(gv)
return gv
diff --git a/tests/python/relax/test_frontend_nn_packing.py b/tests/python/relax/test_frontend_nn_packing.py
index c2cc22c17d..56b614a807 100644
--- a/tests/python/relax/test_frontend_nn_packing.py
+++ b/tests/python/relax/test_frontend_nn_packing.py
@@ -59,7 +59,8 @@ def test_nn_export_to_relax():
matmul = R.matmul(x, matmul_1_weight)
matmul_2_weight = R.permute_dims(linear_2_weight)
matmul1 = R.matmul(x, matmul_2_weight)
- gv = R.add(matmul, matmul1)
+ add = R.add(matmul, matmul1)
+ gv = add
R.output(gv)
return gv
diff --git a/tests/python/relax/test_frontend_nn_subroutines.py b/tests/python/relax/test_frontend_nn_subroutines.py
index 32ae967916..6bbf57aead 100644
--- a/tests/python/relax/test_frontend_nn_subroutines.py
+++ b/tests/python/relax/test_frontend_nn_subroutines.py
@@ -61,7 +61,8 @@ def test_linear():
def _initialize_effect() -> R.Tuple(R.Object):
with R.dataflow():
_io: R.Object = R.null_value()
- gv = (_io,)
+ lv: R.Tuple(R.Object) = (_io,)
+ gv: R.Tuple(R.Object) = lv
R.output(gv)
return gv
@@ -74,8 +75,9 @@ def test_linear():
with R.dataflow():
state = R.matmul(state, weights)
state = Expected.activation(state)
- R.output(state)
- return state
+ dataflow_output = state
+ R.output(dataflow_output)
+ return dataflow_output
@R.function(private=True)
def activation(
@@ -83,8 +85,9 @@ def test_linear():
) -> R.Tensor(("batch_size", 32), dtype="float32"):
with R.dataflow():
state = R.nn.silu(state)
- R.output(state)
- return state
+ dataflow_output = state
+ R.output(dataflow_output)
+ return dataflow_output
mod = Layer(64, 32)
batch_size = tvm.tir.Var("batch_size", "int64")