This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new f215a417bf [Unity][NN] Use Linear name for nn.op.permute_dims (#16303)
f215a417bf is described below
commit f215a417bf86ad8d5eb2c74b3a98719c86a915ed
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Jan 5 12:19:05 2024 -0600
[Unity][NN] Use Linear name for nn.op.permute_dims (#16303)
The `relax::op::linear` operator is implemented as a `permute_dims` followed by a
`matmul`. In this case, readability can be improved by naming the transposed
weights after the matmul they feed.
---
python/tvm/relax/frontend/nn/op.py | 9 ++++-
tests/python/relax/test_frontend_nn_packing.py | 51 +++++++++++++-------------
2 files changed, 33 insertions(+), 27 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/op.py
b/python/tvm/relax/frontend/nn/op.py
index 1d3454fc88..ac5858d5cd 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -577,7 +577,7 @@ def broadcast_to(x: Tensor, shape: Sequence[IntExpr], name: str = "broadcast_to"
    return wrap_nested(_op.broadcast_to(x._expr, shape), name)
-def permute_dims(x: Tensor, axes: Optional[List[int]] = None, name: str = "permute_dims") -> Tensor:
+def permute_dims(x: Tensor, axes: Optional[List[int]] = None, name: str = None) -> Tensor:
"""Permutes the dimensions of an array.
Parameters
@@ -596,6 +596,13 @@ def permute_dims(x: Tensor, axes: Optional[List[int]] = None, name: str = "permu
result : Tensor
The transposed result.
"""
+ if name is None:
+ x_name = getattr(getattr(x, "_expr", None), "name_hint", None)
+ if x_name is not None and "linear" in x_name:
+ name = x_name.replace("linear", "matmul")
+ else:
+ name = "permute_dims"
+
return wrap_nested(_op.permute_dims(x._expr, axes=axes), name)
diff --git a/tests/python/relax/test_frontend_nn_packing.py
b/tests/python/relax/test_frontend_nn_packing.py
index 00f981d1d4..56b614a807 100644
--- a/tests/python/relax/test_frontend_nn_packing.py
+++ b/tests/python/relax/test_frontend_nn_packing.py
@@ -21,7 +21,14 @@ from tvm.script import ir as I
from tvm.script import relax as R
-def main():
+def _iter_binding_names(mod):
+ """Helper function to compare the names of relax variables"""
+ for block in mod["forward"].body.blocks:
+ for binding in block.bindings:
+ yield binding.var.name_hint
+
+
+def test_nn_export_to_relax():
class TestModule(nn.Module):
def __init__(self, in_features: int, out_features: int):
super().__init__()
@@ -35,39 +42,28 @@ def main():
x2 = self.linear_2(x)
return x1 + x2
- # pylint: disable=line-too-long
@I.ir_module
- class ExpectedModule: # pylint: disable=too-few-public-methods
+ class ExpectedModule:
@R.function
def forward(
x: R.Tensor((1, 10), dtype="float32"),
packed_params: R.Tuple(
R.Tensor((20, 10), dtype="float32"), R.Tensor((20, 10),
dtype="float32")
),
- ) -> R.Tensor((1, 20), dtype="float32"):
- R.func_attr({"num_input": 1}) # type: ignore[attr-defined]
- with R.dataflow(): # type: ignore[attr-defined]
- linear_1_weight: R.Tensor((20, 10), dtype="float32") =
packed_params[0] # type: ignore[valid-type]
- linear_2_weight: R.Tensor((20, 10), dtype="float32") =
packed_params[1] # type: ignore[valid-type]
- permute_dims: R.Tensor((10, 20), dtype="float32") =
R.permute_dims( # type: ignore[attr-defined,valid-type]
- linear_1_weight, axes=None
- )
- matmul: R.Tensor((1, 20), dtype="float32") = R.matmul( #
type: ignore[attr-defined,valid-type]
- x, permute_dims, out_dtype="void"
- )
- permute_dims1: R.Tensor((10, 20), dtype="float32") =
R.permute_dims( # type: ignore[attr-defined,valid-type]
- linear_2_weight, axes=None
- )
- matmul1: R.Tensor((1, 20), dtype="float32") = R.matmul( #
type: ignore[attr-defined,valid-type]
- x, permute_dims1, out_dtype="void"
- )
- add: R.Tensor((1, 20), dtype="float32") = R.add(matmul,
matmul1) # type: ignore[attr-defined,valid-type]
- gv: R.Tensor((1, 20), dtype="float32") = add # type:
ignore[attr-defined,valid-type]
- R.output(gv) # type: ignore[attr-defined,valid-type]
+ ):
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ linear_1_weight = packed_params[0]
+ linear_2_weight = packed_params[1]
+ matmul_1_weight = R.permute_dims(linear_1_weight)
+ matmul = R.matmul(x, matmul_1_weight)
+ matmul_2_weight = R.permute_dims(linear_2_weight)
+ matmul1 = R.matmul(x, matmul_2_weight)
+ add = R.add(matmul, matmul1)
+ gv = add
+ R.output(gv)
return gv
- # pylint: enable=line-too-long
-
model = TestModule(10, 20)
mod, _ = model.export_tvm(
spec={
@@ -82,6 +78,9 @@ def main():
)
tvm.ir.assert_structural_equal(mod, ExpectedModule)
+ for name, expected_name in zip(_iter_binding_names(mod),
_iter_binding_names(ExpectedModule)):
+ assert name == expected_name
+
if __name__ == "__main__":
- main()
+ tvm.testing.main()