This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2066ce9612 [Unity][MSC][M4.2][Step2] Enable plugin with manager, test plugins in compile pipeline (#16581)
2066ce9612 is described below
commit 2066ce9612871d1b6ce4f9fb0db85e547449fe6f
Author: Archermmt <[email protected]>
AuthorDate: Tue Feb 20 15:49:56 2024 +0800
[Unity][MSC][M4.2][Step2] Enable plugin with manager, test plugins in compile pipeline (#16581)
enable plugin with manager
---
.../msc/framework/torch/frontend/translate.py | 9 +++-
python/tvm/relax/frontend/torch/fx_translator.py | 33 +++++++++++-
tests/python/contrib/test_msc/test_plugin.py | 58 ++++++++++++++++++++++
3 files changed, 96 insertions(+), 4 deletions(-)
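For readers new to the MSC plugin flow: a custom convert map maps an op name to a converter callable, and (per the fx_translator change below) entries coming from custom_convert_map are invoked as fn(node, importer) rather than fn(node), so the converter can reach the importer's state. A minimal sketch, reusing the "MyRelu" plugin name from the tests below; the converter body and the name my_relu_converter are illustrative, not part of this commit:

    from tvm import relax

    def my_relu_converter(node, importer):
        # Custom entries receive the TorchFXImporter as a second argument,
        # so converted inputs can be looked up in importer.env and new ops
        # emitted through importer.block_builder.
        data = importer.env[node.args[0]]
        return importer.block_builder.emit(relax.op.nn.relu(data))

    custom_convert_map = {"MyRelu": my_relu_converter}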
diff --git a/python/tvm/contrib/msc/framework/torch/frontend/translate.py b/python/tvm/contrib/msc/framework/torch/frontend/translate.py
index 3ac1b81a2c..2509f1abfc 100644
--- a/python/tvm/contrib/msc/framework/torch/frontend/translate.py
+++ b/python/tvm/contrib/msc/framework/torch/frontend/translate.py
@@ -70,6 +70,7 @@ def from_torch(
build_config: Optional[Dict[str, str]] = None,
opt_config: Optional[Dict[str, str]] = None,
as_msc: bool = True,
+ custom_convert_map: dict = None,
) -> Tuple[Union[MSCGraph, tvm.IRModule], Dict[str, tvm.nd.array]]:
"""Change torch nn.Module to MSCGraph.
@@ -91,6 +92,8 @@ def from_torch(
The config for optimizing the relay before translation.
as_msc: bool
Set to True to return the MSCGraph, otherwise return the Relax mod
+ custom_convert_map: dict
+ The convert map for plugins
Returns
-------
@@ -103,7 +106,7 @@ def from_torch(
if via_relax:
graph_model, params = torch.fx.symbolic_trace(model), None
with torch.no_grad():
- relax_mod = from_fx(graph_model, input_info)
+ relax_mod = from_fx(graph_model, input_info, custom_convert_map=custom_convert_map)
else:
datas = [np.random.rand(*i[0]).astype(i[1]) for i in input_info]
torch_datas = [torch.from_numpy(i) for i in datas]
@@ -116,7 +119,9 @@ def from_torch(
shape_list = list(zip(input_names, input_info))
else:
shape_list = [("input" + str(idx), i_info) for idx, i_info in
enumerate(input_info)]
- relay_mod, params = tvm.relay.frontend.from_pytorch(scripted_model, shape_list)
+ relay_mod, params = tvm.relay.frontend.from_pytorch(
+ scripted_model, shape_list, custom_convert_map=custom_convert_map
+ )
relax_mod = relay_to_relax(relay_mod, params, trans_config, build_config, opt_config)
if not as_msc:
return relax_mod, params
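A hedged usage sketch of the extended entry point, reusing the hypothetical my_relu_converter from the sketch above (the model and the input shape are placeholders, not part of this commit):

    from tvm.contrib.msc.framework.torch.frontend import translate

    # With as_msc left at its default of True, this returns the MSCGraph
    # plus its weights; custom_convert_map flows through to from_fx.
    graph, weights = translate.from_torch(
        model,
        input_info=[([1, 3, 224, 224], "float32")],
        via_relax=True,
        custom_convert_map={"MyRelu": my_relu_converter},
    )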
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 5e581e81f3..49e9fc4495 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -1459,6 +1459,17 @@ class TorchFXImporter:
"scaled_dot_product_attention": self._scaled_dot_product_attention,
}
+ def update_convert_map(self, custom_convert_map: dict):
+ """Update self.convert_map with custom convert map
+
+ Parameters
+ ----------
+ custom_convert_map : Dictionary of str to Relax op
+ A custom op conversion map in the same format as self.convert_map
+ """
+
+ self.convert_map.update(custom_convert_map)
+
def from_fx(
self,
model,
@@ -1466,10 +1477,16 @@ class TorchFXImporter:
keep_params_as_input: bool,
unwrap_unit_return_tuple: bool,
no_bind_return_tuple: bool,
+ custom_convert_map: dict = None,
) -> tvm.IRModule:
"""Convert a PyTorch FX GraphModule to a Relax program."""
from torch import fx
+ if custom_convert_map:
+ custom_ops = set(custom_convert_map.keys())
+ self.update_convert_map(custom_convert_map)
+ else:
+ custom_ops = set()
self.named_modules = dict(model.named_modules())
graph: fx.Graph = model.graph
@@ -1548,7 +1565,10 @@ class TorchFXImporter:
assert (
func_name in self.convert_map
), f"Unsupported function type {func_name}"
- self.env[node] = self.convert_map[func_name](node)
+ if func_name in custom_ops:
+ self.env[node] = self.convert_map[func_name](node, self)
+ else:
+ self.env[node] = self.convert_map[func_name](node)
elif node.op == "call_method":
assert (
node.target in self.convert_map
@@ -1572,6 +1592,7 @@ def from_fx(
keep_params_as_input: bool = False,
unwrap_unit_return_tuple: bool = False,
no_bind_return_tuple: bool = False,
+ custom_convert_map: dict = None,
) -> tvm.IRModule:
"""Convert a PyTorch FX GraphModule to a Relax program
@@ -1594,6 +1615,9 @@ def from_fx(
A boolean flag indicating whether to bind the return tuple as a relax var.
If the flag is true and the return value is a tuple, it will not bind it to a var.
+ custom_convert_map : Dictionary of str to Relax op
+ A custom op conversion map in the same format as TorchFXImporter.convert_map
+
Returns
-------
output : tvm.IRModule
@@ -1662,5 +1686,10 @@ def from_fx(
check the placeholder rows in the beginning of the tabular.
"""
return TorchFXImporter().from_fx(
- model, input_info, keep_params_as_input, unwrap_unit_return_tuple, no_bind_return_tuple
+ model,
+ input_info,
+ keep_params_as_input,
+ unwrap_unit_return_tuple,
+ no_bind_return_tuple,
+ custom_convert_map=custom_convert_map,
)
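And a matching sketch for the public from_fx wrapper, which simply threads the new keyword through to TorchFXImporter.from_fx (the traced model and the converter are again illustrative placeholders):

    import torch
    from tvm.relax.frontend.torch import from_fx

    graph_model = torch.fx.symbolic_trace(model)
    mod = from_fx(
        graph_model,
        input_info=[([1, 3, 224, 224], "float32")],
        custom_convert_map={"MyRelu": my_relu_converter},
    )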
diff --git a/tests/python/contrib/test_msc/test_plugin.py b/tests/python/contrib/test_msc/test_plugin.py
index 277268f8ae..e2d3b5fcd3 100644
--- a/tests/python/contrib/test_msc/test_plugin.py
+++ b/tests/python/contrib/test_msc/test_plugin.py
@@ -26,6 +26,7 @@ import tvm.testing
from tvm import relax
from tvm.relax.transform import BindParams
from tvm.script import relax as R
+from tvm.contrib.msc.pipeline import MSCManager
from tvm.contrib.msc.plugin import build_plugins
from tvm.contrib.msc.core.utils.namespace import MSCFramework
from tvm.contrib.msc.core import utils as msc_utils
@@ -287,6 +288,39 @@ def _test_torch_plugin(manager):
assert outputs.min() >= 0 and outputs.max() <= 0.5
+def _test_with_manager(plugins, compile_type, expected_info):
+ """Test the plugin with manager"""
+
+ path = "test_plugin_" + compile_type
+ model = _get_torch_model(plugins[MSCFramework.TORCH])
+ if torch.cuda.is_available():
+ model = model.to(torch.device("cuda:0"))
+ config = {
+ "workspace": msc_utils.msc_dir(path),
+ "model_type": MSCFramework.TORCH,
+ "verbose": "critical",
+ "inputs": [["input_0", [1, 3, 224, 224], "float32"]],
+ "outputs": ["output"],
+ "dataset": {"prepare": {"loader": "from_random", "max_iter": 5}},
+ "prepare": {"profile": {"benchmark": {"repeat": 10}}},
+ "baseline": {
+ "profile": {"check": {"atol": 1e-2, "rtol": 1e-2}, "benchmark":
{"repeat": 10}},
+ },
+ "compile": {
+ "run_type": compile_type,
+ "profile": {"check": {"atol": 1e-2, "rtol": 1e-2}, "benchmark":
{"repeat": 10}},
+ },
+ }
+ manager = MSCManager(model, config, plugins=plugins)
+ report = manager.run_pipe()
+ model_info = manager.runner.model_info
+ manager.destory()
+ assert report["success"], "Failed to run pipe for torch ->
{}".format(compile_type)
+ assert msc_utils.dict_equal(
+ model_info, expected_info
+ ), "Model info {} mismatch with expected {}".format(model_info,
expected_info)
+
+
def test_plugin():
"""Test the plugins"""
@@ -302,6 +336,30 @@ def test_plugin():
_test_tvm_plugin(managers[MSCFramework.TVM], "cuda")
_test_torch_plugin(managers[MSCFramework.TORCH])
+ # test the plugin with manager
+ model_info = {
+ "inputs": [
+ {"name": "input_0", "shape": [1, 3, 224, 224], "dtype": "float32",
"layout": "NCHW"}
+ ],
+ "outputs": [
+ {"name": "output", "shape": [1, 6, 218, 218], "dtype": "float32",
"layout": "NCHW"}
+ ],
+ "nodes": {"total": 4, "input": 1, "msc.conv2d_bias": 1, "MyRelu": 1,
"nn.max_pool2d": 1},
+ }
+ _test_with_manager(managers, MSCFramework.TORCH, model_info)
+ _test_with_manager(managers, MSCFramework.TVM, model_info)
+ if tvm.get_global_func("relax.ext.tensorrt", True) is not None:
+ byoc_info = {
+ "inputs": [
+ {"name": "input_0", "shape": [1, 3, 224, 224], "dtype":
"float32", "layout": "NCHW"}
+ ],
+ "outputs": [
+ {"name": "output", "shape": [1, 6, 218, 218], "dtype":
"float32", "layout": ""}
+ ],
+ "nodes": {"total": 2, "input": 1, "msc_tensorrt": 1},
+ }
+ _test_with_manager(managers, MSCFramework.TENSORRT, byoc_info)
+
plugin_root.destory()
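A side note on the TensorRT gating in the test: the second argument of tvm.get_global_func("relax.ext.tensorrt", True) is allow_missing, so the lookup returns None instead of raising when TVM was built without TensorRT support, e.g.:

    import tvm

    # Returns None (rather than raising) when the BYOC hook is absent.
    if tvm.get_global_func("relax.ext.tensorrt", True) is not None:
        pass  # safe to exercise the MSCFramework.TENSORRT compile path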