This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 11d5f677ec [Unity][Frontend][NN] Make effects optional in nn module.
(#15650)
11d5f677ec is described below
commit 11d5f677ecd5cd107db6a0dfa42db3381ab2aa53
Author: Josh Fromm <[email protected]>
AuthorDate: Fri Sep 1 23:03:19 2023 -0700
[Unity][Frontend][NN] Make effects optional in nn module. (#15650)
This PR adds the `debug` argument to `export_tvm`. When `debug` is `False`,
effects are not included in the output graph. This can make deploying models
less cumbersome since it's not often they'll use effects. I also added automatic
annotation of the `num_inputs` function attribute since it is useful for a few
passes.
---
python/tvm/relax/frontend/nn/core.py | 26 +++++-
python/tvm/relax/frontend/nn/op.py | 2 +
python/tvm/relax/frontend/nn/spec.py | 39 +++++---
.../python/relax/test_frontend_nn_extern_module.py | 10 +-
tests/python/relax/test_frontend_nn_modules.py | 104 +++++++++++++++------
tests/python/relax/test_frontend_nn_op.py | 43 ++++++---
tests/python/relax/test_frontend_nn_subroutines.py | 5 +-
tests/python/relax/test_frontend_nn_tensor.py | 15 ++-
8 files changed, 173 insertions(+), 71 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/core.py
b/python/tvm/relax/frontend/nn/core.py
index ea3744b6ca..718fb45e31 100644
--- a/python/tvm/relax/frontend/nn/core.py
+++ b/python/tvm/relax/frontend/nn/core.py
@@ -360,12 +360,31 @@ class Module(SubroutineMixin):
def export_tvm(
self,
spec: "_spec.ModuleSpecType",
+ debug: bool = False,
) -> Tuple[IRModule, List[Tuple[str, Parameter]]]:
- """Export the module to TVM IRModule and parameters"""
+ """Export the module to TVM IRModule and parameters
+
+ Parameters
+ ----------
+ spec : _spec.ModuleSpecType
+ A dictionary mapping each input name to a specification
+ that defines the inputs shape and dtype.
+ debug : bool
+ If set to True, then the exported module will support
+ effects. This enables things like printing in the graph.
+
+ Returns
+ -------
+ irmodule : tvm.ir.IRModule
+ The converted tvm IR representation of the model.
+ params : Dict[str, tvm.nd.array]
+ A dictionary of parameters corresponding to the weights of
+ the model.
+ """
from . import spec as _spec # pylint: disable=import-outside-toplevel
spec = _spec.ModuleSpec.from_raw(spec, self)
- mod, params = _spec.SpecBuilder().build(spec)
+ mod, params = _spec.SpecBuilder().build(spec, debug=debug)
return mod, params
def jit( # pylint: disable=too-many-arguments
@@ -375,6 +394,7 @@ class Module(SubroutineMixin):
device: str = "cpu",
pipeline: str = "zero",
out_format: str = "torch",
+ debug: bool = False,
) -> Callable:
"""Just-in-time compilation of a nn.model to an executable"""
from tvm import relax # pylint: disable=import-outside-toplevel
@@ -383,7 +403,7 @@ class Module(SubroutineMixin):
# Convert nn.Module to IRModule
spec = _spec.ModuleSpec.from_raw(spec, self)
- mod, params = _spec.SpecBuilder().build(spec)
+ mod, params = _spec.SpecBuilder().build(spec, debug=debug)
# Convert parameters
device = _str_to_device(device)
diff --git a/python/tvm/relax/frontend/nn/op.py
b/python/tvm/relax/frontend/nn/op.py
index 5473fcb499..66bf0e44b5 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -1290,4 +1290,6 @@ def tensor_expr_op(
def print_(array: Tensor):
+ if SpecBuilder.current().io_effect is None:
+ raise RuntimeError("Printing is only supported when debug mode is on.")
SpecBuilder.current().io_effect.print_(array)
diff --git a/python/tvm/relax/frontend/nn/spec.py
b/python/tvm/relax/frontend/nn/spec.py
index 983e0adc52..aeecfee782 100644
--- a/python/tvm/relax/frontend/nn/spec.py
+++ b/python/tvm/relax/frontend/nn/spec.py
@@ -18,7 +18,7 @@
from collections import defaultdict
import inspect
import threading
-from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, List, Sequence, Tuple, Union, Optional
from tvm import tir
from tvm.ir import IRModule
@@ -288,7 +288,9 @@ class SpecBuilder:
assert hasattr(SpecBuilder._tls, "current")
delattr(SpecBuilder._tls, "current")
- def build(self, spec: ModuleSpec) -> Tuple[IRModule, List[Tuple[str,
core.Parameter]]]:
+ def build(
+ self, spec: ModuleSpec, debug: bool = False
+ ) -> Tuple[IRModule, List[Tuple[str, core.Parameter]]]:
"""Build the ModuleSpec to TVM IRModule. Returns the IRModule and the
parameters."""
# pylint: disable=protected-access
@@ -301,7 +303,9 @@ class SpecBuilder:
return params
def _effects() -> List[Tuple[str, core.Effect]]:
- result = [("", self.io_effect)]
+ result = []
+ if self.io_effect is not None:
+ result.append(("", self.io_effect))
for name, effect in core._attribute_finder(
spec.module, "", condition_yield=lambda x: isinstance(x,
core.Effect)
):
@@ -321,16 +325,22 @@ class SpecBuilder:
# pylint: enable=protected-access
+ # Disable IO effects if not in debug mode.
+ if not debug:
+ self.io_effect = None
params = _params()
effects = _effects()
extern_modules = _extern_modules()
with self:
- with self.builder.function("_initialize_effect"):
- with self.builder.dataflow():
- outputs = _emit_effect_init(self.builder, effects)
- self.builder.emit_func_output(outputs, params=[])
+ if effects:
+ with self.builder.function("_initialize_effect"):
+ with self.builder.dataflow():
+ outputs = _emit_effect_init(self.builder, effects)
+ self.builder.emit_func_output(outputs, params=[])
for method_name, method_spec in zip(spec.method_names,
spec.method_specs):
- with self.builder.function(method_name):
+ with self.builder.function(
+ method_name, attrs={"num_input":
len(method_spec.arg_specs) + len(effects)}
+ ):
with self.builder.dataflow():
outputs, inputs = _emit_method(self.builder,
method_spec, params, effects)
self.builder.emit_func_output(outputs, inputs)
@@ -363,7 +373,7 @@ def _emit_method(
builder: BlockBuilder,
spec: MethodSpec,
params: List[Tuple[str, core.Parameter]],
- effects: List[Tuple[str, core.Effect]],
+ effects: Optional[List[Tuple[str, core.Effect]]],
):
# pylint: disable=protected-access
def _unwrap_ret(expr: Any) -> Any:
@@ -386,16 +396,19 @@ def _emit_method(
inputs = []
for arg in explicit_inputs:
inputs.append(_convert_input(arg))
+ for name, effect in effects:
+ inputs.extend(effect.create(name))
for name, param in params:
param._expr = core._tensor_placeholder(name, param.shape,
param.dtype)._expr
inputs.append(param._expr)
- for name, effect in effects:
- inputs.extend(effect.create(name))
- # pylint: enable=protected-access
+ # pylint: enable=protected-access
outputs = spec.method(*explicit_inputs)
effect_outputs = []
for _, effect in effects:
effect_outputs.extend(effect.finalize())
- outputs = builder.emit_output(rx.Tuple([_unwrap_ret(outputs),
rx.Tuple(effect_outputs)]))
+ if effect_outputs:
+ outputs = builder.emit_output(rx.Tuple([_unwrap_ret(outputs),
rx.Tuple(effect_outputs)]))
+ else:
+ outputs = builder.emit_output(_unwrap_ret(outputs))
return outputs, inputs
diff --git a/tests/python/relax/test_frontend_nn_extern_module.py
b/tests/python/relax/test_frontend_nn_extern_module.py
index a5753c19ab..ba8584a85e 100644
--- a/tests/python/relax/test_frontend_nn_extern_module.py
+++ b/tests/python/relax/test_frontend_nn_extern_module.py
@@ -92,25 +92,21 @@ def test_extern_module():
def forward(
a_1: R.Tensor(("a", "b", "c", "d", 1, 2, 3, 4), dtype="float32"),
b_1: R.Tensor(("c", "d", "e", "f", 5, 6, 7, 8), dtype="float32"),
- _io: R.Object,
- ) -> R.Tuple(
- R.Tensor(("a", "b", "c", "d", "e", "f", 9, 10), dtype="float32"),
R.Tuple(R.Object)
- ):
+ ) -> R.Tensor(("a", "b", "c", "d", "e", "f", 9, 10), dtype="float32"):
a = T.int64()
b = T.int64()
c = T.int64()
d = T.int64()
e = T.int64()
f = T.int64()
+ R.func_attr({"num_input": 2})
with R.dataflow():
matmul = R.call_dps_packed(
"matmul",
(a_1, b_1),
out_sinfo=R.Tensor((a, b, c, d, e, f, 9, 10), dtype="float32"),
)
- gv1: R.Tuple(
- R.Tensor((a, b, c, d, e, f, 9, 10), dtype="float32"),
R.Tuple(R.Object)
- ) = matmul, (_io,)
+ gv1: R.Tensor((a, b, c, d, e, f, 9, 10), dtype="float32") = matmul
R.output(gv1)
return gv1
diff --git a/tests/python/relax/test_frontend_nn_modules.py
b/tests/python/relax/test_frontend_nn_modules.py
index 86f018383f..dc69e2c51e 100644
--- a/tests/python/relax/test_frontend_nn_modules.py
+++ b/tests/python/relax/test_frontend_nn_modules.py
@@ -32,6 +32,7 @@ def test_silu():
x: R.Tensor((3, 3), dtype="float32"),
_io: R.Object,
) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tuple(R.Object)):
+ R.func_attr({"num_input": 2})
with R.dataflow():
silu: R.Tensor((3, 3), dtype="float32") = R.nn.silu(x)
gv1: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tuple(R.Object))
= silu, (_io,)
@@ -39,7 +40,7 @@ def test_silu():
return gv1
mod = modules.SiLU()
- tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((3, 3),
"float32")}})
+ tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((3, 3),
"float32")}}, debug=True)
assert_structural_equal(tvm_mod["forward"], forward, True)
@@ -49,13 +50,14 @@ def test_identity():
x: R.Tensor((3, 3), dtype="float32"),
_io: R.Object,
) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tuple(R.Object)):
+ R.func_attr({"num_input": 2})
with R.dataflow():
gv1: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tuple(R.Object))
= x, (_io,)
R.output(gv1)
return gv1
mod = modules.Identity()
- tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((3, 3),
"float32")}})
+ tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((3, 3),
"float32")}}, debug=True)
assert_structural_equal(tvm_mod["forward"], forward, True)
@@ -63,10 +65,11 @@ def test_linear():
@R.function
def forward(
x: R.Tensor((1, 4), dtype="float32"),
+ _io: R.Object,
weight: R.Tensor((8, 4), dtype="float32"),
bias: R.Tensor((8,), dtype="float32"),
- _io: R.Object,
) -> R.Tuple(R.Tensor((1, 8), dtype="float32"), R.Tuple(R.Object)):
+ R.func_attr({"num_input": 2})
with R.dataflow():
permute_dims: R.Tensor((4, 8), dtype="float32") =
R.permute_dims(weight, axes=None)
matmul: R.Tensor((1, 8), dtype="float32") = R.matmul(x,
permute_dims, out_dtype="void")
@@ -76,22 +79,37 @@ def test_linear():
return gv1
mod = modules.Linear(4, 8)
- tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((1, 4),
"float32")}})
+ tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((1, 4),
"float32")}}, debug=True)
assert_structural_equal(tvm_mod["forward"], forward, True)
def test_conv1d():
- # fmt: off
@R.function
- def forward(x: R.Tensor((1, 3, 32), dtype="float32"), weight:
R.Tensor((32, 3, 3), dtype="float32"), bias: R.Tensor((32,), dtype="float32"),
_io: R.Object) -> R.Tuple(R.Tensor((1, 32, 30), dtype="float32"),
R.Tuple(R.Object)):
+ def forward(
+ x: R.Tensor((1, 3, 32), dtype="float32"),
+ _io: R.Object,
+ weight: R.Tensor((32, 3, 3), dtype="float32"),
+ bias: R.Tensor((32,), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((1, 32, 30), dtype="float32"), R.Tuple(R.Object)):
+ R.func_attr({"num_input": 2})
with R.dataflow():
- lv1: R.Tensor((1, 32, 30), dtype="float32") = R.nn.conv1d(x,
weight, strides=[1], padding=[0, 0], dilation=[1], groups=1, data_layout="NCW",
kernel_layout="OIW", out_layout="NCW", out_dtype="void")
+ lv1: R.Tensor((1, 32, 30), dtype="float32") = R.nn.conv1d(
+ x,
+ weight,
+ strides=[1],
+ padding=[0, 0],
+ dilation=[1],
+ groups=1,
+ data_layout="NCW",
+ kernel_layout="OIW",
+ out_layout="NCW",
+ out_dtype="void",
+ )
lv2: R.Tensor((1, 32, 1), dtype="float32") = R.reshape(bias,
R.shape([1, 32, 1]))
conv1d: R.Tensor((1, 32, 30), dtype="float32") = R.add(lv1, lv2)
gv1: R.Tuple(R.Tensor((1, 32, 30), dtype="float32"),
R.Tuple(R.Object)) = conv1d, (_io,)
R.output(gv1)
return gv1
- # fmt: on
mod = modules.Conv1D(3, 32, 3, bias=True)
tvm_mod, _ = mod.export_tvm(
@@ -99,24 +117,35 @@ def test_conv1d():
"forward": {
"x": spec.Tensor([1, 3, 32], "float32"),
}
- }
+ },
+ debug=True,
)
assert_structural_equal(tvm_mod["forward"], forward, True)
def test_layer_norm():
- # fmt: off
@R.function
- def forward(x: R.Tensor((2, 4, 8), dtype="float32"), weight:
R.Tensor((8,), dtype="float32"), bias: R.Tensor((8,), dtype="float32"), _io:
R.Object) -> R.Tuple(R.Tensor((2, 4, 8), dtype="float32"), R.Tuple(R.Object)):
+ def forward(
+ x: R.Tensor((2, 4, 8), dtype="float32"),
+ _io: R.Object,
+ weight: R.Tensor((8,), dtype="float32"),
+ bias: R.Tensor((8,), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((2, 4, 8), dtype="float32"), R.Tuple(R.Object)):
+ R.func_attr({"num_input": 2})
with R.dataflow():
- layer_norm: R.Tensor((2, 4, 8), dtype="float32") =
R.nn.layer_norm(x, weight, bias, axes=[-1], epsilon=1.0000000000000001e-05,
center=True, scale=True)
- gv1: R.Tuple(R.Tensor((2, 4, 8), dtype="float32"),
R.Tuple(R.Object)) = layer_norm, (_io,)
+ layer_norm: R.Tensor((2, 4, 8), dtype="float32") = R.nn.layer_norm(
+ x, weight, bias, axes=[-1], epsilon=1.0000000000000001e-05,
center=True, scale=True
+ )
+ gv1: R.Tuple(R.Tensor((2, 4, 8), dtype="float32"),
R.Tuple(R.Object)) = layer_norm, (
+ _io,
+ )
R.output(gv1)
return gv1
- # fmt: on
mod = modules.LayerNorm(8)
- tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((2, 4, 8),
"float32")}})
+ tvm_mod, _ = mod.export_tvm(
+ spec={"forward": {"x": spec.Tensor((2, 4, 8), "float32")}}, debug=True
+ )
assert_structural_equal(tvm_mod["forward"], forward, True)
@@ -124,10 +153,11 @@ def test_conv2d():
@R.function
def forward(
x: R.Tensor((1, 3, 32, 32), dtype="float32"),
+ _io: R.Object,
weight: R.Tensor((32, 3, 3, 3), dtype="float32"),
bias: R.Tensor((32,), dtype="float32"),
- _io: R.Object,
) -> R.Tuple(R.Tensor((1, 32, 30, 30), dtype="float32"),
R.Tuple(R.Object)):
+ R.func_attr({"num_input": 2})
with R.dataflow():
lv1: R.Tensor((1, 32, 30, 30), dtype="float32") = R.nn.conv2d(x,
weight)
lv2: R.Tensor((1, 32, 1, 1), dtype="float32") = R.reshape(bias,
R.shape([1, 32, 1, 1]))
@@ -144,7 +174,8 @@ def test_conv2d():
"forward": {
"x": spec.Tensor([1, 3, 32, 32], "float32"),
}
- }
+ },
+ debug=True,
)
assert_structural_equal(tvm_mod["forward"], forward, True)
@@ -153,9 +184,10 @@ def test_rms_norm():
@R.function
def forward(
x: R.Tensor((2, 4, 8), dtype="float32"),
- weight: R.Tensor((8,), dtype="float32"),
_io: R.Object,
+ weight: R.Tensor((8,), dtype="float32"),
) -> R.Tuple(R.Tensor((2, 4, 8), dtype="float32"), R.Tuple(R.Object)):
+ R.func_attr({"num_input": 2})
with R.dataflow():
rms_norm: R.Tensor((2, 4, 8), dtype="float32") = R.nn.rms_norm(
x, weight, axes=[2], epsilon=1.0000000000000001e-05
@@ -165,7 +197,9 @@ def test_rms_norm():
return gv1
mod = modules.RMSNorm(8, [2], bias=False)
- tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((2, 4, 8),
"float32")}})
+ tvm_mod, _ = mod.export_tvm(
+ spec={"forward": {"x": spec.Tensor((2, 4, 8), "float32")}}, debug=True
+ )
assert_structural_equal(tvm_mod["forward"], forward, True)
@@ -173,10 +207,11 @@ def test_group_norm():
@R.function
def forward(
x: R.Tensor((2, 4, 8), dtype="float32"),
+ _io: R.Object,
weight: R.Tensor((4,), dtype="float32"),
bias: R.Tensor((4,), dtype="float32"),
- _io: R.Object,
) -> R.Tuple(R.Tensor((2, 4, 8), dtype="float32"), R.Tuple(R.Object)):
+ R.func_attr({"num_input": 2})
with R.dataflow():
group_norm: R.Tensor((2, 4, 8), dtype="float32") = R.nn.group_norm(
x, weight, bias, num_groups=2, channel_axis=1, axes=[2]
@@ -188,7 +223,9 @@ def test_group_norm():
return gv1
mod = modules.GroupNorm(num_groups=2, num_channels=4)
- tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((2, 4, 8),
"float32")}})
+ tvm_mod, _ = mod.export_tvm(
+ spec={"forward": {"x": spec.Tensor((2, 4, 8), "float32")}}, debug=True
+ )
assert_structural_equal(tvm_mod["forward"], forward, True)
@@ -196,9 +233,10 @@ def test_embedding():
@R.function
def forward(
x: R.Tensor((1, 4), dtype="int32"),
- weight: R.Tensor((4, 8), dtype="float32"),
_io: R.Object,
+ weight: R.Tensor((4, 8), dtype="float32"),
) -> R.Tuple(R.Tensor((1, 4, 8), dtype="float32"), R.Tuple(R.Object)):
+ R.func_attr({"num_input": 2})
with R.dataflow():
reshape: R.Tensor((4,), dtype="int32") = R.reshape(x, R.shape([4]))
take: R.Tensor((4, 8), dtype="float32") = R.take(weight, reshape,
axis=0)
@@ -208,7 +246,7 @@ def test_embedding():
return gv1
mod = modules.Embedding(4, 8, "float32")
- tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((1, 4),
"int32")}})
+ tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((1, 4),
"int32")}}, debug=True)
assert_structural_equal(tvm_mod["forward"], forward, True)
@@ -217,13 +255,14 @@ def test_timestep_embedding():
def forward(
sample: R.Tensor((32, 32), dtype="float32"),
condition: R.Tensor((32, 16), dtype="float32"),
+ _io: R.Object,
linear_1_weight: R.Tensor((32, 32), dtype="float32"),
linear_1_bias: R.Tensor((32,), dtype="float32"),
cond_proj_weight: R.Tensor((32, 16), dtype="float32"),
linear_2_weight: R.Tensor((32, 32), dtype="float32"),
linear_2_bias: R.Tensor((32,), dtype="float32"),
- _io: R.Object,
) -> R.Tuple(R.Tensor((32, 32), dtype="float32"), R.Tuple(R.Object)):
+ R.func_attr({"num_input": 3})
with R.dataflow():
permute_dims: R.Tensor((16, 32), dtype="float32") = R.permute_dims(
cond_proj_weight, axes=None
@@ -258,7 +297,8 @@ def test_timestep_embedding():
"sample": spec.Tensor((32, 32), "float32"),
"condition": spec.Tensor((32, 16), "float32"),
}
- }
+ },
+ debug=True,
)
assert_structural_equal(tvm_mod["forward"], forward, True)
@@ -268,6 +308,7 @@ def test_timesteps():
def forward(
x: R.Tensor((3,), dtype="float32"), _io: R.Object
) -> R.Tuple(R.Tensor((3, 10), dtype="float32"), R.Tuple(R.Object)):
+ R.func_attr({"num_input": 2})
with R.dataflow():
lv1: R.Tensor((3,), dtype="float32") = R.astype(x, dtype="float32")
lv2: R.Tensor((3, 1), dtype="float32") = R.expand_dims(lv1,
axis=[1])
@@ -293,7 +334,7 @@ def test_timesteps():
return gv1
mod = modules.Timesteps(10)
- tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((3,),
"float32")}})
+ tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((3,),
"float32")}}, debug=True)
assert_structural_equal(tvm_mod["forward"], forward, True)
@@ -323,6 +364,7 @@ def test_kv_cache():
def forward(
x: R.Tensor((2, 4), dtype="float32"), _io: R.Object, cache:
R.Object
) -> R.Tuple(R.Tensor((4, 2, 4), dtype="float32"), R.Tuple(R.Object,
R.Object)):
+ R.func_attr({"num_input": 3})
with R.dataflow():
lv2: R.Object = R.call_packed(
"vm.builtin.attention_kv_cache_append", cache, x,
sinfo_args=(R.Object,)
@@ -347,7 +389,9 @@ def test_kv_cache():
self.cache.append(x)
return self.cache.view(4)
- tvm_mod, _ = KVCacheTest().export_tvm(spec={"forward": {"x":
spec.Tensor((2, 4), "float32")}})
+ tvm_mod, _ = KVCacheTest().export_tvm(
+ spec={"forward": {"x": spec.Tensor((2, 4), "float32")}}, debug=True
+ )
assert_structural_equal(tvm_mod, Module, True)
@@ -356,6 +400,7 @@ def test_attention():
def forward(
hidden_states: R.Tensor((2, 4096, 640), dtype="float32"),
encoder_hidden_states: R.Tensor((2, 77, 2048), dtype="float32"),
+ _io: R.Object,
to_q_weight: R.Tensor((640, 640), dtype="float32"),
to_k_weight: R.Tensor((640, 2048), dtype="float32"),
to_v_weight: R.Tensor((640, 2048), dtype="float32"),
@@ -363,8 +408,8 @@ def test_attention():
group_norm_bias: R.Tensor((640,), dtype="float32"),
to_out_0_weight: R.Tensor((640, 640), dtype="float32"),
to_out_0_bias: R.Tensor((640,), dtype="float32"),
- _io: R.Object,
) -> R.Tuple(R.Tensor((2, 4096, 640), dtype="float32"), R.Tuple(R.Object)):
+ R.func_attr({"num_input": 3})
with R.dataflow():
group_norm: R.Tensor((2, 4096, 640), dtype="float32") =
R.nn.group_norm(
hidden_states,
@@ -428,7 +473,8 @@ def test_attention():
"hidden_states": spec.Tensor((2, 4096, 640), "float32"),
"encoder_hidden_states": spec.Tensor((2, 77, 2048), "float32"),
}
- }
+ },
+ debug=True,
)
assert_structural_equal(tvm_mod["forward"], forward, True)
diff --git a/tests/python/relax/test_frontend_nn_op.py
b/tests/python/relax/test_frontend_nn_op.py
index 255f8ab3ba..f6cb29a87b 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -43,6 +43,7 @@ def test_binary():
# fmt: off
@R.function
def test(x: R.Tensor((1, 10), dtype="float32"), y: R.Tensor((10, 1),
dtype="float32"), _io: R.Object) -> R.Tuple(R.Tuple(R.Tensor((10, 10),
dtype="float32"), R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10),
dtype="float32"), R.Tensor((1, 1), dtype="float32"), R.Tensor((10, 10),
dtype="float32"), R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10),
dtype="float32")), R.Tuple(R.Object)):
+ R.func_attr({"num_input": 3})
with R.dataflow():
add: R.Tensor((10, 10), dtype="float32") = R.add(x, y)
mul: R.Tensor((10, 10), dtype="float32") = R.multiply(x, y)
@@ -58,7 +59,8 @@ def test_binary():
m = Model()
irmodule, _ = m.export_tvm(
- spec={"test": {"x": spec.Tensor([1, 10], "float32"), "y":
spec.Tensor([10, 1], "float32")}}
+ spec={"test": {"x": spec.Tensor([1, 10], "float32"), "y":
spec.Tensor([10, 1], "float32")}},
+ debug=True,
)
tvm.ir.assert_structural_equal(irmodule["test"], test)
@@ -78,6 +80,7 @@ def test_manipulate():
# fmt: off
@R.function
def test(x: R.Tensor((1, 5, 2), dtype="float32"), _io: R.Object) ->
R.Tuple(R.Tuple(R.Tensor((2, 5, 2), dtype="float32"), R.Tensor((2, 5, 1),
dtype="float32"), R.Tensor((1, 10), dtype="float32"), R.Tensor((1, 10, 2),
dtype="float32"), R.Tensor((5, 2), dtype="float32"), R.Tensor((1, 1, 5, 2),
dtype="float32"), R.Tensor((2, 5, 2), dtype="float32")), R.Tuple(R.Object)):
+ R.func_attr({"num_input": 2})
with R.dataflow():
broadcast_to: R.Tensor((2, 5, 2), dtype="float32") =
R.broadcast_to(x, R.shape([2, 5, 2]))
permute_dims: R.Tensor((2, 5, 1), dtype="float32") =
R.permute_dims(x, axes=[2, 1, 0])
@@ -92,7 +95,7 @@ def test_manipulate():
# fmt: on
m = Model()
- irmodule, _ = m.export_tvm(spec={"test": {"x": spec.Tensor([1, 5, 2],
"float32")}})
+ irmodule, _ = m.export_tvm(spec={"test": {"x": spec.Tensor([1, 5, 2],
"float32")}}, debug=True)
tvm.ir.assert_structural_equal(irmodule["test"], test)
@@ -106,6 +109,7 @@ def test_index():
# fmt: off
@R.function
def test(x: R.Tensor((2, 1, 10), dtype="float32"), y: R.Tensor((5,),
dtype="int32"), _io: R.Object) -> R.Tuple(R.Tensor((2, 1, 5), dtype="float32"),
R.Tuple(R.Object)):
+ R.func_attr({"num_input": 3})
with R.dataflow():
take: R.Tensor((2, 1, 5), dtype="float32") = R.take(x, y, axis=2)
gv1: R.Tuple(R.Tensor((2, 1, 5), dtype="float32"),
R.Tuple(R.Object)) = take, (_io,)
@@ -115,7 +119,8 @@ def test_index():
m = Model()
irmodule, params = m.export_tvm(
- spec={"test": {"x": spec.Tensor([2, 1, 10], "float32"), "y":
spec.Tensor([5], "int32")}}
+ spec={"test": {"x": spec.Tensor([2, 1, 10], "float32"), "y":
spec.Tensor([5], "int32")}},
+ debug=True,
)
tvm.ir.assert_structural_equal(irmodule["test"], test)
@@ -130,6 +135,7 @@ def test_datatype():
# fmt: off
@R.function
def test(x: R.Tensor((2, 1, 10), dtype="float32"), _io: R.Object) ->
R.Tuple(R.Tensor((2, 1, 10), dtype="float16"), R.Tuple(R.Object)):
+ R.func_attr({"num_input": 2})
with R.dataflow():
astype: R.Tensor((2, 1, 10), dtype="float16") = R.astype(x,
dtype="float16")
gv1: R.Tuple(R.Tensor((2, 1, 10), dtype="float16"),
R.Tuple(R.Object)) = astype, (_io,)
@@ -138,7 +144,7 @@ def test_datatype():
# fmt: on
m = Model()
- irmodule, params = m.export_tvm(spec={"test": {"x": spec.Tensor([2, 1,
10], "float32")}})
+ irmodule, _ = m.export_tvm(spec={"test": {"x": spec.Tensor([2, 1, 10],
"float32")}}, debug=True)
tvm.ir.assert_structural_equal(irmodule["test"], test)
@@ -163,6 +169,7 @@ def test_image():
),
R.Tuple(R.Object),
):
+ R.func_attr({"num_input": 4})
with R.dataflow():
lv0: R.Tensor((1, 3, 34, 34), dtype="float32") = R.nn.pad(x, (0,
0, 0, 0, 1, 1, 1, 1))
lv1: R.Tensor((1, 32, 32, 32), dtype="float32") = R.nn.conv2d(
@@ -210,7 +217,8 @@ def test_image():
"weight": spec.Tensor([32, 3, 3, 3], "float32"),
"bias": spec.Tensor([32], "float32"),
}
- }
+ },
+ debug=True,
)
tvm.ir.assert_structural_equal(irmodule["test"], test)
@@ -233,6 +241,7 @@ def test_chunk():
),
R.Tuple(R.Object),
):
+ R.func_attr({"num_input": 2})
with R.dataflow():
chunk: R.Tuple(
R.Tensor((2,), dtype="float32"),
@@ -257,7 +266,7 @@ def test_chunk():
return gv1
m = Model()
- irmodule, _ = m.export_tvm(spec={"test": {"x": spec.Tensor([8],
"float32")}})
+ irmodule, _ = m.export_tvm(spec={"test": {"x": spec.Tensor([8],
"float32")}}, debug=True)
tvm.ir.assert_structural_equal(irmodule["test"], test)
@@ -279,6 +288,7 @@ def test_nn():
bias: R.Tensor((3,), dtype="float32"),
_io: R.Object,
) -> R.Tuple(R.Tensor((2, 3, 4, 5), dtype="float32"), R.Tuple(R.Object)):
+ R.func_attr({"num_input": 4})
with R.dataflow():
silu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.silu(x)
gelu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.gelu(x)
@@ -304,7 +314,8 @@ def test_nn():
"weight": spec.Tensor([4, 5], "float32"),
"bias": spec.Tensor([3], "float32"),
}
- }
+ },
+ debug=True,
)
tvm.ir.assert_structural_equal(irmodule["test"], test)
@@ -328,6 +339,7 @@ def test_create():
# fmt: off
@R.function
def test(x: R.Tensor((10, 10), dtype="float32"), _io: R.Object) ->
R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tuple(R.Object)):
+ R.func_attr({"num_input": 2})
with R.dataflow():
triu: R.Tensor((10, 10), dtype="float32") = R.triu(x, k=0)
full: R.Tensor((10, 10), dtype="float32") = R.full(R.shape([10,
10]), R.const(10, "float32"), dtype="float32")
@@ -341,7 +353,9 @@ def test_create():
# fmt: on
m = Model()
- irmodule, params = m.export_tvm(spec={"test": {"x": spec.Tensor([10, 10],
"float32")}})
+ irmodule, params = m.export_tvm(
+ spec={"test": {"x": spec.Tensor([10, 10], "float32")}}, debug=True
+ )
tvm.ir.assert_structural_equal(irmodule["test"], test)
@@ -356,6 +370,7 @@ def test_timestep_embedding():
def test(
x: R.Tensor((3,), dtype="float32"), _io: R.Object
) -> R.Tuple(R.Tensor((3, 10), dtype="float32"), R.Tuple(R.Object)):
+ R.func_attr({"num_input": 2})
with R.dataflow():
lv1: R.Tensor((3,), dtype="float32") = R.astype(x, dtype="float32")
lv2: R.Tensor((3, 1), dtype="float32") = R.expand_dims(lv1,
axis=[1])
@@ -381,7 +396,7 @@ def test_timestep_embedding():
return gv1
m = Model()
- irmodule, _ = m.export_tvm(spec={"test": {"x": spec.Tensor([3],
"float32")}})
+ irmodule, _ = m.export_tvm(spec={"test": {"x": spec.Tensor([3],
"float32")}}, debug=True)
tvm.ir.assert_structural_equal(irmodule["test"], test)
@@ -398,6 +413,7 @@ def test_scaled_dot_product_attention():
value: R.Tensor((1, 32, 32, 32), dtype="float32"),
_io: R.Object,
) -> R.Tuple(R.Tensor((1, 32, 32, 32), dtype="float32"),
R.Tuple(R.Object)):
+ R.func_attr({"num_input": 4})
with R.dataflow():
scaled_dot_product_attention: R.Tensor(
(1, 32, 32, 32), dtype="float32"
@@ -416,7 +432,8 @@ def test_scaled_dot_product_attention():
"key": spec.Tensor([1, 32, 32, 32], "float32"),
"value": spec.Tensor([1, 32, 32, 32], "float32"),
}
- }
+ },
+ debug=True,
)
tvm.ir.assert_structural_equal(irmodule["test"], test)
@@ -455,6 +472,7 @@ def test_tensor_expr_op():
@R.function
def test(x: R.Tensor((10, 10), dtype="float32"), _io: R.Object) ->
R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tuple(R.Object)):
cls = Expected
+ R.func_attr({"num_input": 2})
with R.dataflow():
lv1 = R.call_tir(cls.add_one, (x,), out_sinfo=R.Tensor((10,
10), dtype="float32"))
add_one1: R.Tensor((10, 10), dtype="float32") = lv1
@@ -464,7 +482,7 @@ def test_tensor_expr_op():
# fmt: on
m = Model()
- irmodule, params = m.export_tvm(spec={"test": {"x": spec.Tensor([10, 10],
"float32")}})
+ irmodule, _ = m.export_tvm(spec={"test": {"x": spec.Tensor([10, 10],
"float32")}}, debug=True)
tvm.ir.assert_structural_equal(irmodule, Expected)
@@ -490,6 +508,7 @@ def test_print():
@R.function
def test(x: R.Tensor((10, 10), dtype="float32"), _io: R.Object) ->
R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tuple(R.Object)):
+ R.func_attr({"num_input": 2})
with R.dataflow():
add: R.Tensor((10, 10), dtype="float32") = R.add(x, x)
_io1: R.Object = R.call_pure_packed("effect.print", _io, add,
sinfo_args=(R.Object(),))
@@ -499,7 +518,7 @@ def test_print():
# fmt: on
m = Model()
- irmodule, params = m.export_tvm(spec={"test": {"x": spec.Tensor([10, 10],
"float32")}})
+ irmodule, _ = m.export_tvm(spec={"test": {"x": spec.Tensor([10, 10],
"float32")}}, debug=True)
tvm.ir.assert_structural_equal(irmodule["test"], Expected["test"])
diff --git a/tests/python/relax/test_frontend_nn_subroutines.py
b/tests/python/relax/test_frontend_nn_subroutines.py
index 2075e7c768..6bbf57aead 100644
--- a/tests/python/relax/test_frontend_nn_subroutines.py
+++ b/tests/python/relax/test_frontend_nn_subroutines.py
@@ -47,9 +47,10 @@ def test_linear():
@R.function
def forward(
state: R.Tensor(("batch_size", 64), dtype="float32"),
- weights: R.Tensor((64, 32), dtype="float32"),
_io: R.Object,
+ weights: R.Tensor((64, 32), dtype="float32"),
) -> R.Tuple(R.Tensor(("batch_size", 32), dtype="float32"),
R.Tuple(R.Object)):
+ R.func_attr({"num_input": 2})
with R.dataflow():
state = Expected.layer(state, weights)
dataflow_output = (state, (_io,))
@@ -91,7 +92,7 @@ def test_linear():
mod = Layer(64, 32)
batch_size = tvm.tir.Var("batch_size", "int64")
tvm_mod, _ = mod.export_tvm(
- spec={"forward": {"input": nn.spec.Tensor((batch_size, 64),
"float32")}}
+ spec={"forward": {"input": nn.spec.Tensor((batch_size, 64),
"float32")}}, debug=True
)
assert_structural_equal(Expected, tvm_mod, True)
diff --git a/tests/python/relax/test_frontend_nn_tensor.py
b/tests/python/relax/test_frontend_nn_tensor.py
index 63d756e637..2d2a23cc46 100644
--- a/tests/python/relax/test_frontend_nn_tensor.py
+++ b/tests/python/relax/test_frontend_nn_tensor.py
@@ -55,6 +55,7 @@ def test_tensor_op_binary_tensor_tensor():
# fmt: off
@R.function
def test(x: R.Tensor((1, 10), dtype="float32"), y: R.Tensor((2, 1),
dtype="float32"), _io: R.Object) -> R.Tuple(R.Tuple(R.Tensor((2, 10),
dtype="float32"), R.Tensor((2, 10), dtype="float32"), R.Tensor((2, 10),
dtype="float32"), R.Tensor((2, 10), dtype="float32"), R.Tensor((2, 10),
dtype="float32")), R.Tuple(R.Object)):
+ R.func_attr({"num_input": 3})
with R.dataflow():
add: R.Tensor((2, 10), dtype="float32") = R.add(x, y)
mul: R.Tensor((2, 10), dtype="float32") = R.multiply(x, y)
@@ -68,13 +69,14 @@ def test_tensor_op_binary_tensor_tensor():
m = Model()
irmodule, _ = m.export_tvm(
- spec={"test": {"x": spec.Tensor([1, 10], "float32"), "y":
spec.Tensor([2, 1], "float32")}}
+ spec={"test": {"x": spec.Tensor([1, 10], "float32"), "y":
spec.Tensor([2, 1], "float32")}},
+ debug=True,
)
tvm.ir.assert_structural_equal(irmodule["test"], test)
-def test_tensor_op_binary_tensor_saclar():
+def test_tensor_op_binary_tensor_scalar():
class Model(Module):
def test(self, x: Tensor):
y = 10
@@ -89,6 +91,7 @@ def test_tensor_op_binary_tensor_saclar():
# fmt: off
@R.function
def test(x: R.Tensor((1, 10), dtype="float32"), _io: R.Object) ->
R.Tuple(R.Tuple(R.Tensor((1, 10), dtype="float32"), R.Tensor((1, 10),
dtype="float32"), R.Tensor((1, 10), dtype="float32"), R.Tensor((1, 10),
dtype="float32"), R.Tensor((1, 10), dtype="float32"), R.Tensor((1, 10),
dtype="float32")), R.Tuple(R.Object)):
+ R.func_attr({"num_input": 2})
with R.dataflow():
add: R.Tensor((1, 10), dtype="float32") = R.add(x, R.const(10,
"float32"))
add1: R.Tensor((1, 10), dtype="float32") = R.add(x, R.const(10,
"float32"))
@@ -102,7 +105,7 @@ def test_tensor_op_binary_tensor_saclar():
# fmt: on
m = Model()
- irmodule, _ = m.export_tvm(spec={"test": {"x": spec.Tensor([1, 10],
"float32")}})
+ irmodule, _ = m.export_tvm(spec={"test": {"x": spec.Tensor([1, 10],
"float32")}}, debug=True)
tvm.ir.assert_structural_equal(irmodule["test"], test)
@@ -116,6 +119,7 @@ def test_tensor_op_datatype():
# fmt: off
@R.function
def test(x: R.Tensor((1, 10), dtype="float32"), _io: R.Object) ->
R.Tuple(R.Tensor((1, 10), dtype="float16"), R.Tuple(R.Object)):
+ R.func_attr({"num_input": 2})
with R.dataflow():
astype: R.Tensor((1, 10), dtype="float16") = R.astype(x,
dtype="float16")
gv1: R.Tuple(R.Tensor((1, 10), dtype="float16"),
R.Tuple(R.Object)) = astype, (_io,)
@@ -124,7 +128,7 @@ def test_tensor_op_datatype():
# fmt: on
m = Model()
- irmodule, _ = m.export_tvm(spec={"test": {"x": spec.Tensor([1, 10],
"float32")}})
+ irmodule, _ = m.export_tvm(spec={"test": {"x": spec.Tensor([1, 10],
"float32")}}, debug=True)
tvm.ir.assert_structural_equal(irmodule["test"], test)
@@ -140,6 +144,7 @@ def test_tensor_op_manipulate():
# fmt: off
@R.function
def test(x: R.Tensor((2, 1, 10), dtype="float32"), _io: R.Object) ->
R.Tuple(R.Tuple(R.Tensor((2, 5, 2), dtype="float32"), R.Tensor((10, 1, 2),
dtype="float32"), R.Tensor((2, 2, 10), dtype="float32")), R.Tuple(R.Object)):
+ R.func_attr({"num_input": 2})
with R.dataflow():
reshape: R.Tensor((2, 5, 2), dtype="float32") = R.reshape(x,
R.shape([2, 5, 2]))
permute_dims: R.Tensor((10, 1, 2), dtype="float32") =
R.permute_dims(x, axes=[2, 1, 0])
@@ -150,7 +155,7 @@ def test_tensor_op_manipulate():
# fmt: on
m = Model()
- irmodule, params = m.export_tvm(spec={"test": {"x": spec.Tensor([2, 1,
10], "float32")}})
+ irmodule, _ = m.export_tvm(spec={"test": {"x": spec.Tensor([2, 1, 10],
"float32")}}, debug=True)
tvm.ir.assert_structural_equal(irmodule["test"], test)