This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new f215a417bf [Unity][NN] Use Linear name for nn.op.permute_dims (#16303)
f215a417bf is described below

commit f215a417bf86ad8d5eb2c74b3a98719c86a915ed
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Jan 5 12:19:05 2024 -0600

    [Unity][NN] Use Linear name for nn.op.permute_dims (#16303)
    
    The `relax::op::linear` is implemented as `permute_dims`, followed by
    `matmul`.  In this case, readability can be improved by naming the
    transposed weight tensor after the `matmul` it feeds into (e.g.
    `matmul_1_weight` for the weight of `linear_1`) rather than using the
    generic `permute_dims` name.
---
 python/tvm/relax/frontend/nn/op.py             |  9 ++++-
 tests/python/relax/test_frontend_nn_packing.py | 51 +++++++++++++-------------
 2 files changed, 33 insertions(+), 27 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/op.py 
b/python/tvm/relax/frontend/nn/op.py
index 1d3454fc88..ac5858d5cd 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -577,7 +577,7 @@ def broadcast_to(x: Tensor, shape: Sequence[IntExpr], name: 
str = "broadcast_to"
     return wrap_nested(_op.broadcast_to(x._expr, shape), name)
 
 
-def permute_dims(x: Tensor, axes: Optional[List[int]] = None, name: str = 
"permute_dims") -> Tensor:
+def permute_dims(x: Tensor, axes: Optional[List[int]] = None, name: str = 
None) -> Tensor:
     """Permutes the dimensions of an array.
 
     Parameters
@@ -596,6 +596,13 @@ def permute_dims(x: Tensor, axes: Optional[List[int]] = 
None, name: str = "permu
     result : Tensor
         The transposed result.
     """
+    if name is None:
+        x_name = getattr(getattr(x, "_expr", None), "name_hint", None)
+        if x_name is not None and "linear" in x_name:
+            name = x_name.replace("linear", "matmul")
+        else:
+            name = "permute_dims"
+
     return wrap_nested(_op.permute_dims(x._expr, axes=axes), name)
 
 
diff --git a/tests/python/relax/test_frontend_nn_packing.py 
b/tests/python/relax/test_frontend_nn_packing.py
index 00f981d1d4..56b614a807 100644
--- a/tests/python/relax/test_frontend_nn_packing.py
+++ b/tests/python/relax/test_frontend_nn_packing.py
@@ -21,7 +21,14 @@ from tvm.script import ir as I
 from tvm.script import relax as R
 
 
-def main():
+def _iter_binding_names(mod):
+    """Helper function to compare the names of relax variables"""
+    for block in mod["forward"].body.blocks:
+        for binding in block.bindings:
+            yield binding.var.name_hint
+
+
+def test_nn_export_to_relax():
     class TestModule(nn.Module):
         def __init__(self, in_features: int, out_features: int):
             super().__init__()
@@ -35,39 +42,28 @@ def main():
             x2 = self.linear_2(x)
             return x1 + x2
 
-    # pylint: disable=line-too-long
     @I.ir_module
-    class ExpectedModule:  # pylint: disable=too-few-public-methods
+    class ExpectedModule:
         @R.function
         def forward(
             x: R.Tensor((1, 10), dtype="float32"),
             packed_params: R.Tuple(
                 R.Tensor((20, 10), dtype="float32"), R.Tensor((20, 10), 
dtype="float32")
             ),
-        ) -> R.Tensor((1, 20), dtype="float32"):
-            R.func_attr({"num_input": 1})  # type: ignore[attr-defined]
-            with R.dataflow():  # type: ignore[attr-defined]
-                linear_1_weight: R.Tensor((20, 10), dtype="float32") = 
packed_params[0]  # type: ignore[valid-type]
-                linear_2_weight: R.Tensor((20, 10), dtype="float32") = 
packed_params[1]  # type: ignore[valid-type]
-                permute_dims: R.Tensor((10, 20), dtype="float32") = 
R.permute_dims(  # type: ignore[attr-defined,valid-type]
-                    linear_1_weight, axes=None
-                )
-                matmul: R.Tensor((1, 20), dtype="float32") = R.matmul(  # 
type: ignore[attr-defined,valid-type]
-                    x, permute_dims, out_dtype="void"
-                )
-                permute_dims1: R.Tensor((10, 20), dtype="float32") = 
R.permute_dims(  # type: ignore[attr-defined,valid-type]
-                    linear_2_weight, axes=None
-                )
-                matmul1: R.Tensor((1, 20), dtype="float32") = R.matmul(  # 
type: ignore[attr-defined,valid-type]
-                    x, permute_dims1, out_dtype="void"
-                )
-                add: R.Tensor((1, 20), dtype="float32") = R.add(matmul, 
matmul1)  # type: ignore[attr-defined,valid-type]
-                gv: R.Tensor((1, 20), dtype="float32") = add  # type: 
ignore[attr-defined,valid-type]
-                R.output(gv)  # type: ignore[attr-defined,valid-type]
+        ):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                linear_1_weight = packed_params[0]
+                linear_2_weight = packed_params[1]
+                matmul_1_weight = R.permute_dims(linear_1_weight)
+                matmul = R.matmul(x, matmul_1_weight)
+                matmul_2_weight = R.permute_dims(linear_2_weight)
+                matmul1 = R.matmul(x, matmul_2_weight)
+                add = R.add(matmul, matmul1)
+                gv = add
+                R.output(gv)
             return gv
 
-    # pylint: enable=line-too-long
-
     model = TestModule(10, 20)
     mod, _ = model.export_tvm(
         spec={
@@ -82,6 +78,9 @@ def main():
     )
     tvm.ir.assert_structural_equal(mod, ExpectedModule)
 
+    for name, expected_name in zip(_iter_binding_names(mod), 
_iter_binding_names(ExpectedModule)):
+        assert name == expected_name
+
 
 if __name__ == "__main__":
-    main()
+    tvm.testing.main()

Reply via email to