This is an automated email from the ASF dual-hosted git repository.
leandron pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ac9bf1f [TVMC] Add configuration `tir.add_lower_pass` to option
`--pass-config` (#9817)
ac9bf1f is described below
commit ac9bf1fb8f67c61f1f8005ae0064b9acd924f4ea
Author: Colin Y. Li <[email protected]>
AuthorDate: Mon Feb 14 22:14:24 2022 +0800
[TVMC] Add configuration `tir.add_lower_pass` to option `--pass-config`
(#9817)
---
python/tvm/driver/tvmc/compiler.py | 3 +-
python/tvm/driver/tvmc/pass_config.py | 82 ++++++++++++++++++++++++--
tests/python/driver/tvmc/test_pass_config.py | 88 ++++++++++++++++++++++++++++
3 files changed, 167 insertions(+), 6 deletions(-)
diff --git a/python/tvm/driver/tvmc/compiler.py
b/python/tvm/driver/tvmc/compiler.py
index d260c98..df56e3b 100644
--- a/python/tvm/driver/tvmc/compiler.py
+++ b/python/tvm/driver/tvmc/compiler.py
@@ -95,7 +95,8 @@ def add_compile_parser(subparsers, _):
metavar=("name=value"),
help="configurations to be used at compile time. This option can be
provided multiple "
"times, each one to set one configuration value, "
- "e.g. '--pass-config relay.backend.use_auto_scheduler=0'.",
+ "e.g. '--pass-config relay.backend.use_auto_scheduler=0', "
+ "e.g. '--pass-config
tir.add_lower_pass=opt_level1,pass1,opt_level2,pass2'.",
)
generate_target_args(parser)
diff --git a/python/tvm/driver/tvmc/pass_config.py
b/python/tvm/driver/tvmc/pass_config.py
index 7cf0f01..dde5b9c 100644
--- a/python/tvm/driver/tvmc/pass_config.py
+++ b/python/tvm/driver/tvmc/pass_config.py
@@ -18,10 +18,41 @@
TVMC PassContext Interface
"""
+import importlib
+
import tvm
from tvm.driver.tvmc import TVMCException
def load_function(full_name):
    """Dynamically load a function by its full name.

    The lookup order matters: the name is first looked up as a TVM global
    function (PackedFunc). Only if no such global function exists is it
    treated as a Python path of the form "path.to.module.func", whose module
    part is imported with ``importlib.import_module``.

    Parameters
    ----------
    full_name: str
        The name of a registered PackedFunc, or a string of the form
        "path.to.module.func" naming an attribute of an importable module.

    Returns
    -------
    func: function or PackedFunc
        The loaded function.

    Raises
    ------
    TVMCException
        If ``full_name`` is not a registered global function and is not a
        dotted "module.func" path, or if the module has no such attribute.
    ModuleNotFoundError
        Propagated from ``importlib.import_module`` when the module part of
        ``full_name`` cannot be found.
    """
    global_func = tvm.get_global_func(full_name, allow_missing=True)
    if global_func is not None:
        return global_func

    # Split "path.to.module.func" into ("path.to.module", "func"). A name without
    # a dot, or with an empty module/function part, cannot be imported. Report it
    # as a configuration error rather than an opaque unpacking/import ValueError.
    module_name, separator, func_name = full_name.rpartition(".")
    if not separator or not module_name or not func_name:
        raise TVMCException(
            f"Function '{full_name}' is not a registered TVM global function "
            "and is not of the form 'path.to.module.func'."
        )

    # import module and find the function
    module = importlib.import_module(module_name)
    if hasattr(module, func_name):
        return getattr(module, func_name)

    raise TVMCException(f"No function '{func_name}' found in module '{module_name}'.")
+
+
def get_pass_config_value(name, value, config_type):
"""Get a PassContext configuration value, based on its config data type.
@@ -41,6 +72,8 @@ def get_pass_config_value(name, value, config_type):
specified by config_type.
"""
+ parsed_value = None
+
if config_type == "IntImm":
# "Bool" configurations in the PassContext are recognized as
# IntImm, so deal with this case here
@@ -56,11 +89,44 @@ def get_pass_config_value(name, value, config_type):
parsed_value = mapping_values.get(value.lower(), None)
if parsed_value is None:
- raise TVMCException(f"Invalid value '{value}' for configuration
'{name}'. ")
+ raise TVMCException(f"Invalid value '{value}' for configuration
'{name}'.")
- if config_type == "runtime.String":
+ elif config_type == "runtime.String":
parsed_value = value
+ elif config_type == "Array":
+ if name == "tir.add_lower_pass":
+ pass_list = value.split(",")
+ if len(pass_list) % 2 != 0:
+ raise TVMCException(
+ f"The configuration of '{name}' must be of the form "
+ "'tir.add_lower_pass=opt_level1,pass1,opt_evel2,pass2'"
+ )
+
+ parsed_value = []
+ for i in range(0, len(pass_list), 2):
+ level, pass_func = pass_list[i].strip(), pass_list[i +
1].strip()
+ try:
+ level = int(level)
+ except ValueError:
+ raise TVMCException(f"Only integer is allow for
configuration '{name}'.")
+
+ # TODO (@leeexyz) We should parse configurations of each tir
Pass.
+ # For now, we only use the defaults. Currently, There are
four config nodes:
+ # `tir.transform.LoopPartitionConfig`
+ # `tir.transform.UnrollLoopConfig`
+ # `tir.transform.HoistIfThenElseConfig`
+ # `tir.transform.InjectDoubleBufferConfig`
+ # loading pass func and calling it to get the Pass
+ pass_func = load_function(pass_func)()
+ parsed_value.append((level, pass_func))
+ else:
+ raise TVMCException(f"Unsupported configuration '{name}' for
'{config_type}' type.")
+
+ else:
+ # not raise here cause we alreay checked before calling this function
+ pass
+
return parsed_value
@@ -81,7 +147,7 @@ def parse_configs(input_configs):
return {}
all_configs = tvm.ir.transform.PassContext.list_configs()
- supported_config_types = ("IntImm", "runtime.String")
+ supported_config_types = ("IntImm", "runtime.String", "Array")
supported_configs = [
name for name in all_configs.keys() if all_configs[name]["type"] in
supported_config_types
]
@@ -116,7 +182,13 @@ def parse_configs(input_configs):
f"The following configurations are supported: {',
'.join(supported_configs)}"
)
- parsed_value = get_pass_config_value(name, value,
all_configs[name]["type"])
- pass_context_configs[name] = parsed_value
+ config_type = all_configs[name]["type"]
+ parsed_value = get_pass_config_value(name, value, config_type)
+
+ if config_type == "Array" and name in pass_context_configs:
+ # merge configs if the configuration exists
+ pass_context_configs[name].extend(parsed_value)
+ else:
+ pass_context_configs[name] = parsed_value
return pass_context_configs
diff --git a/tests/python/driver/tvmc/test_pass_config.py
b/tests/python/driver/tvmc/test_pass_config.py
index bb815e1..f928c8a 100644
--- a/tests/python/driver/tvmc/test_pass_config.py
+++ b/tests/python/driver/tvmc/test_pass_config.py
@@ -16,11 +16,13 @@
# under the License.
import pytest
+from unittest import mock
from tvm.contrib.target.vitis_ai import vitis_ai_available
from tvm.driver.tvmc import TVMCException
from tvm.driver.tvmc.pass_config import parse_configs
+from tvm.tir.transform import PrimFuncPass
def test_config_invalid_format():
@@ -71,3 +73,89 @@ def test_config_valid_multiple_configs():
assert configs["tir.detect_global_barrier"] == 10
assert "relay.ext.vitis_ai.options.build_dir" in configs.keys()
assert configs["relay.ext.vitis_ai.options.build_dir"] == "mystring"
+
+
def test_add_lower_pass_multi_built_in_pass():
    """Built-in TIR passes given across repeated options are merged in order."""
    configs = parse_configs(
        [
            "tir.add_lower_pass=1,tir.transform.UnrollLoop",
            "tir.add_lower_pass=1,tir.transform.HoistIfThenElse,2,tir.transform.LoopPartition",
        ]
    )
    lower_passes = configs["tir.add_lower_pass"]

    assert len(lower_passes) == 3
    # Expected opt levels, in order, for:
    #   tir.transform.UnrollLoop        -> 1
    #   tir.transform.HoistIfThenElse   -> 1
    #   tir.transform.LoopPartition     -> 2
    expected_levels = (1, 1, 2)
    for (opt_level, tir_pass), expected_level in zip(lower_passes, expected_levels):
        assert opt_level == expected_level
        assert isinstance(tir_pass, PrimFuncPass)
+
+
def test_add_lower_pass_multi_external_pass():
    """External passes are imported from a module, instantiated, and merged in order."""
    # A single fake module provides every pass. MagicMock creates the
    # ``fake_pass_*`` attributes on first access. Each loaded attribute is
    # called with no arguments to build the pass, so the expected pass objects
    # are the attributes' ``return_value``s.
    # (A dict literal with the same "fake_module" key repeated would keep only
    # the last value, silently dropping the other fakes.)
    fake_module = mock.MagicMock()
    with mock.patch.dict("sys.modules", {"fake_module": fake_module}):
        configs = parse_configs(
            [
                "tir.add_lower_pass=1,fake_module.fake_pass_1,2,fake_module.fake_pass_2",
                "tir.add_lower_pass=3,fake_module.fake_pass_3",
            ]
        )

    lower_passes = configs["tir.add_lower_pass"]
    assert len(lower_passes) == 3
    # opt_level: 1, pass: fake_module.fake_pass_1
    assert lower_passes[0] == (1, fake_module.fake_pass_1.return_value)
    # opt_level: 2, pass: fake_module.fake_pass_2
    assert lower_passes[1] == (2, fake_module.fake_pass_2.return_value)
    # opt_level: 3, pass: fake_module.fake_pass_3
    assert lower_passes[2] == (3, fake_module.fake_pass_3.return_value)

    # each pass factory is invoked exactly once, with no arguments
    fake_module.fake_pass_1.assert_called_once_with()
    fake_module.fake_pass_2.assert_called_once_with()
    fake_module.fake_pass_3.assert_called_once_with()
+
+
def test_add_lower_pass_multi_mix_pass():
    """External and built-in passes can be mixed; order follows the options."""
    # One fake module serves both external passes. A dict literal repeating the
    # "fake_module" key would keep only its last value.
    fake_module = mock.MagicMock()
    with mock.patch.dict("sys.modules", {"fake_module": fake_module}):
        configs = parse_configs(
            [
                "tir.add_lower_pass=1,fake_module.fake_pass_1,1,tir.transform.UnrollLoop",
                "tir.add_lower_pass=2,fake_module.fake_pass_2,2,tir.transform.LoopPartition",
            ]
        )

    lower_passes = configs["tir.add_lower_pass"]
    assert len(lower_passes) == 4
    # opt_level: 1, pass: fake_module.fake_pass_1
    assert lower_passes[0] == (1, fake_module.fake_pass_1.return_value)
    # opt_level: 1, pass: tir.transform.UnrollLoop
    assert lower_passes[1][0] == 1
    assert isinstance(lower_passes[1][1], PrimFuncPass)
    # opt_level: 2, pass: fake_module.fake_pass_2
    assert lower_passes[2] == (2, fake_module.fake_pass_2.return_value)
    # opt_level: 2, pass: tir.transform.LoopPartition
    assert lower_passes[3][0] == 2
    assert isinstance(lower_passes[3][1], PrimFuncPass)
+
+
def test_add_lower_pass_invalid_format():
    """Malformed 'tir.add_lower_pass' values are rejected with a clear error."""
    rejected_values = [
        # wrong format: pass name and opt level swapped
        "tir.add_lower_pass=tir.transform.UnrollLoop,1",
        # missing pass name after the last opt level
        "tir.add_lower_pass=1,tir.transform.UnrollLoop,3",
        # wrong opt level: not an integer
        "tir.add_lower_pass=a,tir.transform.UnrollLoop",
        # real module and fake func
        "tir.add_lower_pass=1,tir.transform.UnrollLoop,2,tvm.tir.fake_func",
    ]
    for rejected in rejected_values:
        with pytest.raises(TVMCException):
            _ = parse_configs([rejected])

    # fake module: the import itself fails
    with pytest.raises(ModuleNotFoundError):
        _ = parse_configs(
            ["tir.add_lower_pass=1,tir.transform.UnrollLoop,2,path.to.module.fake_func"]
        )