This is an automated email from the ASF dual-hosted git repository.

yaxingcai pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new ea70b02961 [Unity] `nn.Mutator` for `nn.Module` level transform 
(#15958)
ea70b02961 is described below

commit ea70b029619bf7480b965f44a54c6e3590819268
Author: Yaxing Cai <[email protected]>
AuthorDate: Sun Oct 22 00:28:32 2023 -0700

    [Unity] `nn.Mutator` for `nn.Module` level transform (#15958)
    
    * [Unity] `nn.Mutator` for `nn.Module` level transform
    
    This PR introduces the `nn.Mutator` for the `nn.Module` level transform.
    
    * apply code review suggestions
    
    * pylint
---
 python/tvm/relax/frontend/nn/__init__.py       |   1 +
 python/tvm/relax/frontend/nn/visitor.py        | 135 +++++++++++++++
 tests/python/relax/test_frontend_nn_mutator.py | 225 +++++++++++++++++++++++++
 3 files changed, 361 insertions(+)

diff --git a/python/tvm/relax/frontend/nn/__init__.py 
b/python/tvm/relax/frontend/nn/__init__.py
index 59cf32eaa8..d6db56d6e5 100644
--- a/python/tvm/relax/frontend/nn/__init__.py
+++ b/python/tvm/relax/frontend/nn/__init__.py
@@ -33,3 +33,4 @@ from .modules import (
 )
 from .op import *
 from .subroutine import SubroutineMixin
+from .visitor import Mutator
diff --git a/python/tvm/relax/frontend/nn/visitor.py 
b/python/tvm/relax/frontend/nn/visitor.py
new file mode 100644
index 0000000000..2e89a6672c
--- /dev/null
+++ b/python/tvm/relax/frontend/nn/visitor.py
@@ -0,0 +1,135 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""The visitor and mutator infra for nn.Module."""
+from typing import Any
+from . import core as nn
+
+
+class Mutator:
+    """The mutator for nn.Module transform. Users can override the `visit_*` 
methods
+    to apply transform in different structures, or even override the `visit` 
method
+    to change the logic of traversal."""
+
+    def visit_module(self, name: str, node: nn.Module) -> Any:
+        """The base visiting method for mutation of nn.Module nodes.
+
+        Parameters
+        ----------
+        name : str
+            The name of the current node in parent's attribute.
+
+        node : nn.Module
+            The current node of nn.Module to mutate.
+
+        Returns
+        ------
+        ret_node: Any
+            The new node to replace current node.
+        """
+        return self.visit(name, node)
+
+    def visit_effect(self, name: str, node: nn.Parameter) -> Any:
+        """The base visiting method for mutation of nn.Parameter nodes.
+
+        Parameters
+        ----------
+        name : str
+            The name of the current node in parent's attribute.
+
+        node : nn.Parameter
+            The current node of nn.Parameter to mutate.
+
+        Returns
+        ------
+        ret_node: Any
+            The new node to replace current node.
+        """
+        return self.visit(name, node)
+
+    def visit_param(self, name: str, node: nn.Effect) -> Any:
+        """The base visiting method for mutation of nn.Effect nodes.
+
+        Parameters
+        ----------
+        name : str
+            The name of the current node in parent's attribute.
+
+        node : nn.Effect
+            The current node of nn.Effect to mutate.
+
+        Returns
+        ------
+        ret_node: Any
+            The new node to replace current node.
+        """
+        return self.visit(name, node)
+
+    def visit_modulelist(self, name: str, node: nn.ModuleList) -> Any:
+        """The base visiting method for mutation of nn.ModuleList nodes.
+
+        Parameters
+        ----------
+        name : str
+            The name of the current node in parent's attribute.
+
+        node : nn.ModuleList
+            The current node of nn.MoModuleListdule to mutate.
+
+        Returns
+        ------
+        ret_node: Any
+            The new node to replace current node.
+        """
+        return self.visit(name, node)
+
+    def visit(self, name: str, node: Any) -> Any:
+        """The base dispatching method for visiting of all nodes.
+
+        Parameters
+        ----------
+        name : str
+            The name of the current node in parent's attribute.
+
+        node : Any
+            The current node to visit.
+
+        Returns
+        ------
+        ret_node: Any
+            The new node to replace current node.
+        """
+        if isinstance(node, nn.ModuleList):
+            for i in range(len(node)):
+                if isinstance(node[i], nn.ModuleList):
+                    node[i] = self.visit_modulelist(f"{name}.{i}", node[i])
+                elif isinstance(node[i], nn.Module):
+                    node[i] = self.visit_module(f"{name}.{i}", node[i])
+                elif isinstance(node[i], nn.Effect):
+                    node[i] = self.visit_effect(f"{name}.{i}", node[i])
+                elif isinstance(node[i], nn.Parameter):
+                    node[i] = self.visit_param(f"{name}.{i}", node[i])
+        else:
+            for key, value in node.__dict__.items():
+                if isinstance(value, nn.ModuleList):
+                    setattr(node, key, self.visit_modulelist(f"{name}.{key}", 
value))
+                elif isinstance(value, nn.Module):
+                    setattr(node, key, self.visit_module(f"{name}.{key}", 
value))
+                elif isinstance(value, nn.Effect):
+                    setattr(node, key, self.visit_effect(f"{name}.{key}", 
value))
+                elif isinstance(value, nn.Parameter):
+                    setattr(node, key, self.visit_param(f"{name}.{key}", 
value))
+        return node
diff --git a/tests/python/relax/test_frontend_nn_mutator.py 
b/tests/python/relax/test_frontend_nn_mutator.py
new file mode 100644
index 0000000000..ffb6586159
--- /dev/null
+++ b/tests/python/relax/test_frontend_nn_mutator.py
@@ -0,0 +1,225 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Any
+
+import tvm
+from tvm.relax.frontend import nn
+
+
+def test_mutator_naming_basic():
+    class Module0(nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.param0 = nn.Parameter((32, 128), "float64")
+
+    class Module1(nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.mod0 = Module0()
+            self.param1 = nn.Parameter((32, 128), "float32")
+
+    class Module2(nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.mod1 = Module1()
+            self.param2 = nn.Parameter((32, 128), "float16")
+
+    class Module3(nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.mod2 = Module2()
+            self.param3 = nn.Parameter((32, 128), "float8")
+
+    class Mutator(nn.Mutator):
+        def visit_param(self, name: str, node: nn.Parameter) -> Any:
+            if node.dtype == "float8":
+                assert name == "mod3.param3"
+                return node
+            elif node.dtype == "float16":
+                assert name == "mod3.mod2.param2"
+                return node
+            elif node.dtype == "float32":
+                assert name == "mod3.mod2.mod1.param1"
+                return node
+            elif node.dtype == "float64":
+                assert name == "mod3.mod2.mod1.mod0.param0"
+                return node
+
+    mod3 = Module3()
+    mutator = Mutator()
+    mutator.visit("mod3", mod3)
+
+
+def test_mutator_naming_modulelist():
+    class Module(nn.Module):
+        def __init__(self, dtype) -> None:
+            super().__init__()
+            self.param = nn.Parameter((32, 128), dtype)
+
+    class Mutator(nn.Mutator):
+        def visit_param(self, name: str, node: nn.Parameter) -> Any:
+            if node.dtype == "float64":
+                assert name == "mod_list.0.0.param"
+                return node
+            elif node.dtype == "float32":
+                assert name == "mod_list.0.1.param"
+                return node
+            elif node.dtype == "float16":
+                assert name == "mod_list.1.0.param"
+                return node
+            elif node.dtype == "float8":
+                assert name == "mod_list.1.1.param"
+                return node
+
+    mod_list = nn.ModuleList(
+        [
+            nn.ModuleList([Module("float64"), Module("float32")]),
+            nn.ModuleList([Module("float16"), Module("float8")]),
+        ]
+    )
+    mutator = Mutator()
+    mutator.visit("mod_list", mod_list)
+
+
+def test_mutator_module():
+    class SubModule1(nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+
+    class SubModule2(nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+
+    class Module(nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.mod = SubModule1()
+
+    class Mutator(nn.Mutator):
+        def visit_module(self, name: str, node: nn.Module) -> Any:
+            if isinstance(node, SubModule1):
+                return SubModule2()
+            else:
+                return node
+
+    mutator = Mutator()
+    module = Module()
+    assert isinstance(module.mod, SubModule1)
+    module = mutator.visit("", module)
+    assert isinstance(module.mod, SubModule2)
+
+
+def test_mutator_modulelist():
+    class Module1(nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+
+    class Module2(nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+
+    class Module3(nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+
+    class Mutator(nn.Mutator):
+        def visit_module(self, name: str, node: nn.Module) -> Any:
+            if isinstance(node, Module3):
+                return Module1()
+            else:
+                return node
+
+    mutator = Mutator()
+    module_list = nn.ModuleList([Module1(), Module2(), Module3()])
+    assert isinstance(module_list[0], Module1)
+    assert isinstance(module_list[1], Module2)
+    assert isinstance(module_list[2], Module3)
+    module_list = mutator.visit("", module_list)
+    print(module_list[2])
+    assert isinstance(module_list[0], Module1)
+    assert isinstance(module_list[1], Module2)
+    assert isinstance(module_list[2], Module1)
+
+
+def test_mutator_effect():
+    class Effect1(nn.Effect):
+        pass
+
+    class Effect2(nn.Effect):
+        pass
+
+    class Module(nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.effect = Effect1()
+
+    class Mutator(nn.Mutator):
+        def visit_effect(self, name: str, node: nn.Effect) -> Any:
+            if isinstance(node, Effect1):
+                return Effect2()
+
+    mutator = Mutator()
+    module = Module()
+    assert isinstance(module.effect, Effect1)
+    module = mutator.visit("", module)
+    assert isinstance(module.effect, Effect2)
+
+
+def test_mutator_param():
+    class Module(nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.weight = nn.Parameter((128, 64), "float16")
+
+    class Mutator(nn.Mutator):
+        def visit_param(self, name: str, node: nn.Parameter) -> Any:
+            if node.dtype == "float16":
+                return nn.Parameter(node.shape, "float32")
+
+    mutator = Mutator()
+    module = Module()
+    assert module.weight.dtype == "float16"
+    module = mutator.visit("", module)
+    assert module.weight.dtype == "float32"
+
+
+def test_mutator_recursively():
+    class SubModule(nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.weight = nn.Parameter((128, 64), "float16")
+
+    class Module(nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.mod = SubModule()
+
+    class Mutator(nn.Mutator):
+        def visit_param(self, name: str, node: nn.Parameter) -> Any:
+            if node.dtype == "float16":
+                return nn.Parameter(node.shape, "float32")
+
+    mutator = Mutator()
+    module = Module()
+    assert module.mod.weight.dtype == "float16"
+    module = mutator.visit("", module)
+    assert module.mod.weight.dtype == "float32"
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to