This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2066ce9612 [Unity][MSC][M4.2][Step2] Enable plugin with manager, test plugins in compile pipeline (#16581)
2066ce9612 is described below

commit 2066ce9612871d1b6ce4f9fb0db85e547449fe6f
Author: Archermmt <[email protected]>
AuthorDate: Tue Feb 20 15:49:56 2024 +0800

    [Unity][MSC][M4.2][Step2] Enable plugin with manager, test plugins in compile pipeline (#16581)
    
    enable plugin with manager
---
 .../msc/framework/torch/frontend/translate.py      |  9 +++-
 python/tvm/relax/frontend/torch/fx_translator.py   | 33 +++++++++++-
 tests/python/contrib/test_msc/test_plugin.py       | 58 ++++++++++++++++++++++
 3 files changed, 96 insertions(+), 4 deletions(-)

diff --git a/python/tvm/contrib/msc/framework/torch/frontend/translate.py 
b/python/tvm/contrib/msc/framework/torch/frontend/translate.py
index 3ac1b81a2c..2509f1abfc 100644
--- a/python/tvm/contrib/msc/framework/torch/frontend/translate.py
+++ b/python/tvm/contrib/msc/framework/torch/frontend/translate.py
@@ -70,6 +70,7 @@ def from_torch(
     build_config: Optional[Dict[str, str]] = None,
     opt_config: Optional[Dict[str, str]] = None,
     as_msc: bool = True,
+    custom_convert_map: dict = None,
 ) -> Tuple[Union[MSCGraph, tvm.IRModule], Dict[str, tvm.nd.array]]:
     """Change torch nn.Module to MSCGraph.
 
@@ -91,6 +92,8 @@ def from_torch(
         The config for optimize the relay before translate.
     as_msc: bool
         Set to to return msc graph, otherwise relax mod
+    custom_convert_map: dict
+        The convert map for plugin
 
     Returns
     -------
@@ -103,7 +106,7 @@ def from_torch(
     if via_relax:
         graph_model, params = torch.fx.symbolic_trace(model), None
         with torch.no_grad():
-            relax_mod = from_fx(graph_model, input_info)
+            relax_mod = from_fx(graph_model, input_info, 
custom_convert_map=custom_convert_map)
     else:
         datas = [np.random.rand(*i[0]).astype(i[1]) for i in input_info]
         torch_datas = [torch.from_numpy(i) for i in datas]
@@ -116,7 +119,9 @@ def from_torch(
             shape_list = list(zip(input_names, input_info))
         else:
             shape_list = [("input" + str(idx), i_info) for idx, i_info in 
enumerate(input_info)]
-        relay_mod, params = tvm.relay.frontend.from_pytorch(scripted_model, 
shape_list)
+        relay_mod, params = tvm.relay.frontend.from_pytorch(
+            scripted_model, shape_list, custom_convert_map=custom_convert_map
+        )
         relax_mod = relay_to_relax(relay_mod, params, trans_config, 
build_config, opt_config)
     if not as_msc:
         return relax_mod, params
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py 
b/python/tvm/relax/frontend/torch/fx_translator.py
index 5e581e81f3..49e9fc4495 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -1459,6 +1459,17 @@ class TorchFXImporter:
             "scaled_dot_product_attention": self._scaled_dot_product_attention,
         }
 
+    def update_convert_map(self, custom_convert_map: dict):
+        """Update self.convert_map with custom convert map
+
+        Parameters
+        ----------
+        custom_convert_map : Dictionary of str to Relax op
+            A custom op conversion map in the same format as self.convert_map
+        """
+
+        self.convert_map.update(custom_convert_map)
+
     def from_fx(
         self,
         model,
@@ -1466,10 +1477,16 @@ class TorchFXImporter:
         keep_params_as_input: bool,
         unwrap_unit_return_tuple: bool,
         no_bind_return_tuple: bool,
+        custom_convert_map: dict = None,
     ) -> tvm.IRModule:
         """Convert a PyTorch FX GraphModule to a Relax program."""
         from torch import fx
 
+        if custom_convert_map:
+            custom_ops = set(custom_convert_map.keys())
+            self.update_convert_map(custom_convert_map)
+        else:
+            custom_ops = set()
         self.named_modules = dict(model.named_modules())
 
         graph: fx.Graph = model.graph
@@ -1548,7 +1565,10 @@ class TorchFXImporter:
                         assert (
                             func_name in self.convert_map
                         ), f"Unsupported function type {func_name}"
-                        self.env[node] = self.convert_map[func_name](node)
+                        if func_name in custom_ops:
+                            self.env[node] = self.convert_map[func_name](node, 
self)
+                        else:
+                            self.env[node] = self.convert_map[func_name](node)
                     elif node.op == "call_method":
                         assert (
                             node.target in self.convert_map
@@ -1572,6 +1592,7 @@ def from_fx(
     keep_params_as_input: bool = False,
     unwrap_unit_return_tuple: bool = False,
     no_bind_return_tuple: bool = False,
+    custom_convert_map: dict = None,
 ) -> tvm.IRModule:
     """Convert a PyTorch FX GraphModule to a Relax program
 
@@ -1594,6 +1615,9 @@ def from_fx(
         A boolean flag indicating whether to bind the return tuple as a relax 
var.
         If the flag is true and the return value is a tuple, it will not bind 
it to a var.
 
+    custom_convert_map : Dictionary of str to Relax op
+        A custom op conversion map in the same format as 
TorchFXImporter.convert_map
+
     Returns
     -------
     output : tvm.IRModule
@@ -1662,5 +1686,10 @@ def from_fx(
     check the placeholder rows in the beginning of the tabular.
     """
     return TorchFXImporter().from_fx(
-        model, input_info, keep_params_as_input, unwrap_unit_return_tuple, 
no_bind_return_tuple
+        model,
+        input_info,
+        keep_params_as_input,
+        unwrap_unit_return_tuple,
+        no_bind_return_tuple,
+        custom_convert_map=custom_convert_map,
     )
diff --git a/tests/python/contrib/test_msc/test_plugin.py 
b/tests/python/contrib/test_msc/test_plugin.py
index 277268f8ae..e2d3b5fcd3 100644
--- a/tests/python/contrib/test_msc/test_plugin.py
+++ b/tests/python/contrib/test_msc/test_plugin.py
@@ -26,6 +26,7 @@ import tvm.testing
 from tvm import relax
 from tvm.relax.transform import BindParams
 from tvm.script import relax as R
+from tvm.contrib.msc.pipeline import MSCManager
 from tvm.contrib.msc.plugin import build_plugins
 from tvm.contrib.msc.core.utils.namespace import MSCFramework
 from tvm.contrib.msc.core import utils as msc_utils
@@ -287,6 +288,39 @@ def _test_torch_plugin(manager):
     assert outputs.min() >= 0 and outputs.max() <= 0.5
 
 
+def _test_with_manager(plugins, compile_type, expected_info):
+    """Test the plugin with manager"""
+
+    path = "test_plugin_" + compile_type
+    model = _get_torch_model(plugins[MSCFramework.TORCH])
+    if torch.cuda.is_available():
+        model = model.to(torch.device("cuda:0"))
+    config = {
+        "workspace": msc_utils.msc_dir(path),
+        "model_type": MSCFramework.TORCH,
+        "verbose": "critical",
+        "inputs": [["input_0", [1, 3, 224, 224], "float32"]],
+        "outputs": ["output"],
+        "dataset": {"prepare": {"loader": "from_random", "max_iter": 5}},
+        "prepare": {"profile": {"benchmark": {"repeat": 10}}},
+        "baseline": {
+            "profile": {"check": {"atol": 1e-2, "rtol": 1e-2}, "benchmark": 
{"repeat": 10}},
+        },
+        "compile": {
+            "run_type": compile_type,
+            "profile": {"check": {"atol": 1e-2, "rtol": 1e-2}, "benchmark": 
{"repeat": 10}},
+        },
+    }
+    manager = MSCManager(model, config, plugins=plugins)
+    report = manager.run_pipe()
+    model_info = manager.runner.model_info
+    manager.destory()
+    assert report["success"], "Failed to run pipe for torch -> 
{}".format(compile_type)
+    assert msc_utils.dict_equal(
+        model_info, expected_info
+    ), "Model info {} mismatch with expected {}".format(model_info, 
expected_info)
+
+
 def test_plugin():
     """Test the plugins"""
 
@@ -302,6 +336,30 @@ def test_plugin():
         _test_tvm_plugin(managers[MSCFramework.TVM], "cuda")
     _test_torch_plugin(managers[MSCFramework.TORCH])
 
+    # test the plugin with manager
+    model_info = {
+        "inputs": [
+            {"name": "input_0", "shape": [1, 3, 224, 224], "dtype": "float32", 
"layout": "NCHW"}
+        ],
+        "outputs": [
+            {"name": "output", "shape": [1, 6, 218, 218], "dtype": "float32", 
"layout": "NCHW"}
+        ],
+        "nodes": {"total": 4, "input": 1, "msc.conv2d_bias": 1, "MyRelu": 1, 
"nn.max_pool2d": 1},
+    }
+    _test_with_manager(managers, MSCFramework.TORCH, model_info)
+    _test_with_manager(managers, MSCFramework.TVM, model_info)
+    if tvm.get_global_func("relax.ext.tensorrt", True) is not None:
+        byoc_info = {
+            "inputs": [
+                {"name": "input_0", "shape": [1, 3, 224, 224], "dtype": 
"float32", "layout": "NCHW"}
+            ],
+            "outputs": [
+                {"name": "output", "shape": [1, 6, 218, 218], "dtype": 
"float32", "layout": ""}
+            ],
+            "nodes": {"total": 2, "input": 1, "msc_tensorrt": 1},
+        }
+        _test_with_manager(managers, MSCFramework.TENSORRT, byoc_info)
+
     plugin_root.destory()
 
 

Reply via email to