This is an automated email from the ASF dual-hosted git repository.
yaxingcai pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new ea70b02961 [Unity] `nn.Mutator` for `nn.Module` level transform (#15958)
ea70b02961 is described below
commit ea70b029619bf7480b965f44a54c6e3590819268
Author: Yaxing Cai <[email protected]>
AuthorDate: Sun Oct 22 00:28:32 2023 -0700
[Unity] `nn.Mutator` for `nn.Module` level transform (#15958)
* [Unity] `nn.Mutator` for `nn.Module` level transform
This PR introduces the `nn.Mutator` for the `nn.Module` level transform.
* apply code review suggestions
* pylint
---
python/tvm/relax/frontend/nn/__init__.py | 1 +
python/tvm/relax/frontend/nn/visitor.py | 135 +++++++++++++++
tests/python/relax/test_frontend_nn_mutator.py | 225 +++++++++++++++++++++++++
3 files changed, 361 insertions(+)
diff --git a/python/tvm/relax/frontend/nn/__init__.py b/python/tvm/relax/frontend/nn/__init__.py
index 59cf32eaa8..d6db56d6e5 100644
--- a/python/tvm/relax/frontend/nn/__init__.py
+++ b/python/tvm/relax/frontend/nn/__init__.py
@@ -33,3 +33,4 @@ from .modules import (
)
from .op import *
from .subroutine import SubroutineMixin
+from .visitor import Mutator
diff --git a/python/tvm/relax/frontend/nn/visitor.py b/python/tvm/relax/frontend/nn/visitor.py
new file mode 100644
index 0000000000..2e89a6672c
--- /dev/null
+++ b/python/tvm/relax/frontend/nn/visitor.py
@@ -0,0 +1,135 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""The visitor and mutator infra for nn.Module."""
+from typing import Any
+from . import core as nn
+
+
class Mutator:
    """The mutator for nn.Module transform.

    Users can override the ``visit_*`` methods to apply a transform to a particular
    kind of node, or override :meth:`visit` to change the traversal logic itself.

    Every hook receives ``name``, the dotted path of the node, and the node itself.
    The path is the parent's name joined by ``.`` with either the attribute name or,
    for entries of an ``nn.ModuleList``, the list index. The hook's return value
    replaces the node in its parent. The default hooks recurse through :meth:`visit`,
    which mutates the node in place and returns it.
    """

    def visit_module(self, name: str, node: nn.Module) -> Any:
        """The base visiting method for mutation of nn.Module nodes.

        Parameters
        ----------
        name : str
            The name of the current node in parent's attribute.

        node : nn.Module
            The current node of nn.Module to mutate.

        Returns
        -------
        ret_node : Any
            The new node to replace current node.
        """
        return self.visit(name, node)

    def visit_effect(self, name: str, node: nn.Effect) -> Any:
        """The base visiting method for mutation of nn.Effect nodes.

        Parameters
        ----------
        name : str
            The name of the current node in parent's attribute.

        node : nn.Effect
            The current node of nn.Effect to mutate.

        Returns
        -------
        ret_node : Any
            The new node to replace current node.
        """
        return self.visit(name, node)

    def visit_param(self, name: str, node: nn.Parameter) -> Any:
        """The base visiting method for mutation of nn.Parameter nodes.

        Parameters
        ----------
        name : str
            The name of the current node in parent's attribute.

        node : nn.Parameter
            The current node of nn.Parameter to mutate.

        Returns
        -------
        ret_node : Any
            The new node to replace current node.
        """
        return self.visit(name, node)

    def visit_modulelist(self, name: str, node: nn.ModuleList) -> Any:
        """The base visiting method for mutation of nn.ModuleList nodes.

        Parameters
        ----------
        name : str
            The name of the current node in parent's attribute.

        node : nn.ModuleList
            The current node of nn.ModuleList to mutate.

        Returns
        -------
        ret_node : Any
            The new node to replace current node.
        """
        return self.visit(name, node)

    def _hook_for(self, child: Any) -> Any:
        """Return the ``visit_*`` hook that handles ``child``, or None if it is not visited."""
        # nn.ModuleList is checked before nn.Module (the original dispatch order) so
        # that lists reach visit_modulelist even if ModuleList subclasses Module.
        if isinstance(child, nn.ModuleList):
            return self.visit_modulelist
        if isinstance(child, nn.Module):
            return self.visit_module
        if isinstance(child, nn.Effect):
            return self.visit_effect
        if isinstance(child, nn.Parameter):
            return self.visit_param
        return None

    def visit(self, name: str, node: Any) -> Any:
        """The base dispatching method for visiting of all nodes.

        If ``node`` is an ``nn.ModuleList``, each entry is dispatched to its hook and
        replaced by index. Otherwise, each instance attribute in ``node.__dict__`` is
        dispatched and replaced via ``setattr``. Children that are not an
        ``nn.ModuleList``, ``nn.Module``, ``nn.Effect`` or ``nn.Parameter`` are left
        untouched. Child names are built as ``f"{name}.{key}"``, so an empty root
        ``name`` yields child names with a leading ``.``.

        Parameters
        ----------
        name : str
            The name of the current node in parent's attribute.

        node : Any
            The current node to visit.

        Returns
        -------
        ret_node : Any
            The new node to replace current node. This base implementation mutates
            ``node`` in place and returns it.
        """
        if isinstance(node, nn.ModuleList):
            for i in range(len(node)):
                child = node[i]
                hook = self._hook_for(child)
                if hook is not None:
                    node[i] = hook(f"{name}.{i}", child)
        else:
            attrs = node.__dict__
            # Iterate a snapshot of the attribute names. A hook that adds or removes
            # attributes on ``node`` then cannot raise "dictionary changed size during
            # iteration". Each value is read at visit time, so a replacement made by
            # an earlier hook is honoured. A deleted key yields None, which is skipped.
            for key in list(attrs):
                value = attrs.get(key)
                hook = self._hook_for(value)
                if hook is not None:
                    setattr(node, key, hook(f"{name}.{key}", value))
        return node
diff --git a/tests/python/relax/test_frontend_nn_mutator.py b/tests/python/relax/test_frontend_nn_mutator.py
new file mode 100644
index 0000000000..ffb6586159
--- /dev/null
+++ b/tests/python/relax/test_frontend_nn_mutator.py
@@ -0,0 +1,225 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Any
+
+import tvm
+from tvm.relax.frontend import nn
+
+
def test_mutator_naming_basic():
    """Each nested nn.Parameter is visited exactly once, under its dotted attribute path."""

    class Module0(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.param0 = nn.Parameter((32, 128), "float64")

    class Module1(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.mod0 = Module0()
            self.param1 = nn.Parameter((32, 128), "float32")

    class Module2(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.mod1 = Module1()
            self.param2 = nn.Parameter((32, 128), "float16")

    class Module3(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.mod2 = Module2()
            self.param3 = nn.Parameter((32, 128), "float8")

    # Every parameter has a distinct dtype, so the dtype identifies which parameter
    # is being visited and therefore which name it must be reported under.
    expected_names = {
        "float8": "mod3.param3",
        "float16": "mod3.mod2.param2",
        "float32": "mod3.mod2.mod1.param1",
        "float64": "mod3.mod2.mod1.mod0.param0",
    }
    visited = []

    class Mutator(nn.Mutator):
        def visit_param(self, name: str, node: nn.Parameter) -> Any:
            assert name == expected_names[node.dtype]
            visited.append(name)
            return node

    mod3 = Module3()
    mutator = Mutator()
    mutator.visit("mod3", mod3)
    # Without this check the test would pass vacuously if the traversal skipped
    # some (or all) of the parameters.
    assert sorted(visited) == sorted(expected_names.values())
+
+
def test_mutator_naming_modulelist():
    """Parameters inside nested nn.ModuleLists are named with their list indices."""

    class Module(nn.Module):
        def __init__(self, dtype) -> None:
            super().__init__()
            self.param = nn.Parameter((32, 128), dtype)

    # The dtype uniquely identifies each parameter's position in the nested lists.
    expected_names = {
        "float64": "mod_list.0.0.param",
        "float32": "mod_list.0.1.param",
        "float16": "mod_list.1.0.param",
        "float8": "mod_list.1.1.param",
    }
    visited = []

    class Mutator(nn.Mutator):
        def visit_param(self, name: str, node: nn.Parameter) -> Any:
            assert name == expected_names[node.dtype]
            visited.append(name)
            return node

    mod_list = nn.ModuleList(
        [
            nn.ModuleList([Module("float64"), Module("float32")]),
            nn.ModuleList([Module("float16"), Module("float8")]),
        ]
    )
    mutator = Mutator()
    mutator.visit("mod_list", mod_list)
    # Guard against a vacuous pass: every parameter must have been reached once.
    assert sorted(visited) == sorted(expected_names.values())
+
+
def test_mutator_module():
    """A visit_module override can swap a child module for a different module."""

    class SubModule1(nn.Module):
        def __init__(self) -> None:
            super().__init__()

    class SubModule2(nn.Module):
        def __init__(self) -> None:
            super().__init__()

    class Module(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.mod = SubModule1()

    class SwapMutator(nn.Mutator):
        def visit_module(self, name: str, node: nn.Module) -> Any:
            # Replace SubModule1 instances; keep every other module as-is.
            return SubModule2() if isinstance(node, SubModule1) else node

    swapper = SwapMutator()
    root = Module()
    assert isinstance(root.mod, SubModule1)
    root = swapper.visit("", root)
    assert isinstance(root.mod, SubModule2)
+
+
def test_mutator_modulelist():
    """A visit_module override replaces a single entry of an nn.ModuleList by index."""

    class Module1(nn.Module):
        def __init__(self) -> None:
            super().__init__()

    class Module2(nn.Module):
        def __init__(self) -> None:
            super().__init__()

    class Module3(nn.Module):
        def __init__(self) -> None:
            super().__init__()

    class Mutator(nn.Mutator):
        def visit_module(self, name: str, node: nn.Module) -> Any:
            if isinstance(node, Module3):
                return Module1()
            return node

    mutator = Mutator()
    module_list = nn.ModuleList([Module1(), Module2(), Module3()])
    assert isinstance(module_list[0], Module1)
    assert isinstance(module_list[1], Module2)
    assert isinstance(module_list[2], Module3)
    module_list = mutator.visit("", module_list)
    assert isinstance(module_list[0], Module1)
    assert isinstance(module_list[1], Module2)
    # Only the Module3 entry is replaced.
    assert isinstance(module_list[2], Module1)
+
+
def test_mutator_effect():
    """A visit_effect override can replace an nn.Effect attribute of a module."""

    class Effect1(nn.Effect):
        pass

    class Effect2(nn.Effect):
        pass

    class Module(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.effect = Effect1()

    class Mutator(nn.Mutator):
        def visit_effect(self, name: str, node: nn.Effect) -> Any:
            if isinstance(node, Effect1):
                return Effect2()
            # Keep any other effect. Falling off the end would return None, and
            # Mutator.visit would store that None in the parent.
            return node

    mutator = Mutator()
    module = Module()
    assert isinstance(module.effect, Effect1)
    module = mutator.visit("", module)
    assert isinstance(module.effect, Effect2)
+
+
def test_mutator_param():
    """A visit_param override can replace a parameter with one of a different dtype."""

    class Module(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.weight = nn.Parameter((128, 64), "float16")

    class Mutator(nn.Mutator):
        def visit_param(self, name: str, node: nn.Parameter) -> Any:
            if node.dtype == "float16":
                return nn.Parameter(node.shape, "float32")
            # Keep parameters of any other dtype instead of replacing them with None.
            return node

    mutator = Mutator()
    module = Module()
    assert module.weight.dtype == "float16"
    module = mutator.visit("", module)
    assert module.weight.dtype == "float32"
+
+
def test_mutator_recursively():
    """visit_param reaches parameters of nested submodules via the default visit_module."""

    class SubModule(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.weight = nn.Parameter((128, 64), "float16")

    class Module(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.mod = SubModule()

    class Mutator(nn.Mutator):
        def visit_param(self, name: str, node: nn.Parameter) -> Any:
            if node.dtype == "float16":
                return nn.Parameter(node.shape, "float32")
            # Keep parameters of any other dtype instead of replacing them with None.
            return node

    mutator = Mutator()
    module = Module()
    assert module.mod.weight.dtype == "float16"
    module = mutator.visit("", module)
    assert module.mod.weight.dtype == "float32"
+
+
if __name__ == "__main__":
    # `import tvm` alone does not guarantee that the `tvm.testing` submodule is
    # loaded; import it explicitly before calling `tvm.testing.main()`.
    import tvm.testing

    tvm.testing.main()