This is an automated email from the ASF dual-hosted git repository.

leandron pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ac9bf1f  [TVMC] Add configuration `tir.add_lower_pass` to option `--pass-config` (#9817)
ac9bf1f is described below

commit ac9bf1fb8f67c61f1f8005ae0064b9acd924f4ea
Author: Colin Y. Li <c...@live.com>
AuthorDate: Mon Feb 14 22:14:24 2022 +0800

    [TVMC] Add configuration `tir.add_lower_pass` to option `--pass-config` (#9817)
---
 python/tvm/driver/tvmc/compiler.py           |  3 +-
 python/tvm/driver/tvmc/pass_config.py        | 82 ++++++++++++++++++++++++--
 tests/python/driver/tvmc/test_pass_config.py | 88 ++++++++++++++++++++++++++++
 3 files changed, 167 insertions(+), 6 deletions(-)

diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py
index d260c98..df56e3b 100644
--- a/python/tvm/driver/tvmc/compiler.py
+++ b/python/tvm/driver/tvmc/compiler.py
@@ -95,7 +95,8 @@ def add_compile_parser(subparsers, _):
         metavar=("name=value"),
         help="configurations to be used at compile time. This option can be provided multiple "
         "times, each one to set one configuration value, "
-        "e.g. '--pass-config relay.backend.use_auto_scheduler=0'.",
+        "e.g. '--pass-config relay.backend.use_auto_scheduler=0', "
+        "e.g. '--pass-config tir.add_lower_pass=opt_level1,pass1,opt_level2,pass2'.",
     )
 
     generate_target_args(parser)
diff --git a/python/tvm/driver/tvmc/pass_config.py b/python/tvm/driver/tvmc/pass_config.py
index 7cf0f01..dde5b9c 100644
--- a/python/tvm/driver/tvmc/pass_config.py
+++ b/python/tvm/driver/tvmc/pass_config.py
@@ -18,10 +18,41 @@
 TVMC PassContext Interface
 """
 
+import importlib
+
 import tvm
 from tvm.driver.tvmc import TVMCException
 
 
+def load_function(full_name):
+    """Dynamic loading a function by the full name.
+    Parameters
+    ----------
+    full_name: str
+        The name of a PackedFunc or a string of the form "path.to.module.func"
+        that indicates the module that can be imported.
+        You must be aware of the load order here, it first tries to find it via
+        TVM global function, if not found, try to import it by "importlib.import_module".
+    Returns
+    -------
+    func: function or PackedFunc
+        The loaded function.
+    """
+    global_func = tvm.get_global_func(full_name, allow_missing=True)
+    if global_func is not None:
+        return global_func
+
+    # split full name "path.to.module.func" into two parts ["path.to.module", "func"]
+    module_name, func_name = full_name.rsplit(".", 1)
+
+    # import module and find the function
+    module = importlib.import_module(module_name)
+    if hasattr(module, func_name):
+        return getattr(module, func_name)
+
+    raise TVMCException(f"No function '{func_name}' found in module '{module_name}'.")
+
+
 def get_pass_config_value(name, value, config_type):
     """Get a PassContext configuration value, based on its config data type.
 
@@ -41,6 +72,8 @@ def get_pass_config_value(name, value, config_type):
         specified by config_type.
     """
 
+    parsed_value = None
+
     if config_type == "IntImm":
         # "Bool" configurations in the PassContext are recognized as
         # IntImm, so deal with this case here
@@ -56,11 +89,44 @@ def get_pass_config_value(name, value, config_type):
             parsed_value = mapping_values.get(value.lower(), None)
 
         if parsed_value is None:
-            raise TVMCException(f"Invalid value '{value}' for configuration '{name}'. ")
+            raise TVMCException(f"Invalid value '{value}' for configuration '{name}'.")
 
-    if config_type == "runtime.String":
+    elif config_type == "runtime.String":
         parsed_value = value
 
+    elif config_type == "Array":
+        if name == "tir.add_lower_pass":
+            pass_list = value.split(",")
+            if len(pass_list) % 2 != 0:
+                raise TVMCException(
+                    f"The configuration of '{name}' must be of the form "
+                    "'tir.add_lower_pass=opt_level1,pass1,opt_level2,pass2'"
+                )
+
+            parsed_value = []
+            for i in range(0, len(pass_list), 2):
+                level, pass_func = pass_list[i].strip(), pass_list[i + 1].strip()
+                try:
+                    level = int(level)
+                except ValueError:
+                    raise TVMCException(f"Only integers are allowed for configuration '{name}'.")
+
+                # TODO (@leeexyz) We should parse configurations of each tir Pass.
+                #     For now, we only use the defaults. Currently, there are four config nodes:
+                #     `tir.transform.LoopPartitionConfig`
+                #     `tir.transform.UnrollLoopConfig`
+                #     `tir.transform.HoistIfThenElseConfig`
+                #     `tir.transform.InjectDoubleBufferConfig`
+                # loading pass func and calling it to get the Pass
+                pass_func = load_function(pass_func)()
+                parsed_value.append((level, pass_func))
+        else:
+            raise TVMCException(f"Unsupported configuration '{name}' for '{config_type}' type.")
+
+    else:
+        # do not raise here because the type was already checked before calling this function
+        pass
+
     return parsed_value
 
 
@@ -81,7 +147,7 @@ def parse_configs(input_configs):
         return {}
 
     all_configs = tvm.ir.transform.PassContext.list_configs()
-    supported_config_types = ("IntImm", "runtime.String")
+    supported_config_types = ("IntImm", "runtime.String", "Array")
     supported_configs = [
         name for name in all_configs.keys() if all_configs[name]["type"] in supported_config_types
     ]
@@ -116,7 +182,13 @@ def parse_configs(input_configs):
                 f"The following configurations are supported: {', '.join(supported_configs)}"
             )
 
-        parsed_value = get_pass_config_value(name, value, all_configs[name]["type"])
-        pass_context_configs[name] = parsed_value
+        config_type = all_configs[name]["type"]
+        parsed_value = get_pass_config_value(name, value, config_type)
+
+        if config_type == "Array" and name in pass_context_configs:
+            # merge configs if the configuration exists
+            pass_context_configs[name].extend(parsed_value)
+        else:
+            pass_context_configs[name] = parsed_value
 
     return pass_context_configs
diff --git a/tests/python/driver/tvmc/test_pass_config.py b/tests/python/driver/tvmc/test_pass_config.py
index bb815e1..f928c8a 100644
--- a/tests/python/driver/tvmc/test_pass_config.py
+++ b/tests/python/driver/tvmc/test_pass_config.py
@@ -16,11 +16,13 @@
 # under the License.
 
 import pytest
+from unittest import mock
 
 from tvm.contrib.target.vitis_ai import vitis_ai_available
 
 from tvm.driver.tvmc import TVMCException
 from tvm.driver.tvmc.pass_config import parse_configs
+from tvm.tir.transform import PrimFuncPass
 
 
 def test_config_invalid_format():
@@ -71,3 +73,89 @@ def test_config_valid_multiple_configs():
     assert configs["tir.detect_global_barrier"] == 10
     assert "relay.ext.vitis_ai.options.build_dir" in configs.keys()
     assert configs["relay.ext.vitis_ai.options.build_dir"] == "mystring"
+
+
+def test_add_lower_pass_multi_built_in_pass():
+    configs = parse_configs(
+        [
+            "tir.add_lower_pass=1,tir.transform.UnrollLoop",
+            "tir.add_lower_pass=1,tir.transform.HoistIfThenElse,2,tir.transform.LoopPartition",
+        ]
+    )
+
+    assert len(configs["tir.add_lower_pass"]) == 3
+    # opt_level: 1, pass: tir.transform.UnrollLoop
+    assert configs["tir.add_lower_pass"][0][0] == 1
+    assert isinstance(configs["tir.add_lower_pass"][0][1], PrimFuncPass)
+    # opt_level: 1, pass: tir.transform.HoistIfThenElse
+    assert configs["tir.add_lower_pass"][1][0] == 1
+    assert isinstance(configs["tir.add_lower_pass"][1][1], PrimFuncPass)
+    # opt_level: 2, pass: tir.transform.LoopPartition
+    assert configs["tir.add_lower_pass"][2][0] == 2
+    assert isinstance(configs["tir.add_lower_pass"][2][1], PrimFuncPass)
+
+
+def test_add_lower_pass_multi_external_pass():
+    fake_pass_1 = mock.MagicMock()
+    fake_pass_2 = mock.MagicMock()
+    fake_pass_3 = mock.MagicMock()
+    with mock.patch.dict(
+        "sys.modules",
+        {"fake_module": fake_pass_1, "fake_module": fake_pass_2, "fake_module": fake_pass_3},
+    ):
+        configs = parse_configs(
+            [
+                "tir.add_lower_pass=1,fake_module.fake_pass_1,2,fake_module.fake_pass_2",
+                "tir.add_lower_pass=3,fake_module.fake_pass_3",
+            ]
+        )
+        assert len(configs["tir.add_lower_pass"]) == 3
+        # opt_level: 1, pass: fake_module.fake_pass_1
+        assert configs["tir.add_lower_pass"][0][0] == 1
+        # opt_level: 2, pass: fake_module.fake_pass_2
+        assert configs["tir.add_lower_pass"][1][0] == 2
+        # opt_level: 3, pass: fake_module.fake_pass_3
+        assert configs["tir.add_lower_pass"][2][0] == 3
+
+
+def test_add_lower_pass_multi_mix_pass():
+    fake_pass_1 = mock.MagicMock()
+    fake_pass_2 = mock.MagicMock()
+    with mock.patch.dict("sys.modules", {"fake_module": fake_pass_1, "fake_module": fake_pass_2}):
+        configs = parse_configs(
+            [
+                "tir.add_lower_pass=1,fake_module.fake_pass_1,1,tir.transform.UnrollLoop",
+                "tir.add_lower_pass=2,fake_module.fake_pass_2,2,tir.transform.LoopPartition",
+            ]
+        )
+        assert len(configs["tir.add_lower_pass"]) == 4
+        # opt_level: 1, pass: fake_module.fake_pass_1
+        assert configs["tir.add_lower_pass"][0][0] == 1
+        # opt_level: 1, pass: tir.transform.UnrollLoop
+        assert configs["tir.add_lower_pass"][1][0] == 1
+        assert isinstance(configs["tir.add_lower_pass"][1][1], PrimFuncPass)
+        # opt_level: 2, pass: fake_module.fake_pass_2
+        assert configs["tir.add_lower_pass"][2][0] == 2
+        # opt_level: 2, pass: tir.transform.LoopPartition
+        assert configs["tir.add_lower_pass"][3][0] == 2
+        assert isinstance(configs["tir.add_lower_pass"][3][1], PrimFuncPass)
+
+
+def test_add_lower_pass_invalid_format():
+    # wrong format
+    with pytest.raises(TVMCException):
+        _ = parse_configs(["tir.add_lower_pass=tir.transform.UnrollLoop,1"])
+    # missing pass name
+    with pytest.raises(TVMCException):
+        _ = parse_configs(["tir.add_lower_pass=1,tir.transform.UnrollLoop,3"])
+    # wrong opt level
+    with pytest.raises(TVMCException):
+        _ = parse_configs(["tir.add_lower_pass=a,tir.transform.UnrollLoop"])
+    # fake module
+    with pytest.raises(ModuleNotFoundError):
+        _ = parse_configs(
+            ["tir.add_lower_pass=1,tir.transform.UnrollLoop,2,path.to.module.fake_func"]
+        )
+    # real module and fake func
+    with pytest.raises(TVMCException):
+        _ = parse_configs(["tir.add_lower_pass=1,tir.transform.UnrollLoop,2,tvm.tir.fake_func"])

Reply via email to