This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new e37165f40f [Unity][MSC][M1.5-1.7] Add Runner and test with torch, 
tensorflow && tensorrt (#16072)
e37165f40f is described below

commit e37165f40f51d8cf0b8a9aef22314aed576760b1
Author: Archermmt <[email protected]>
AuthorDate: Thu Nov 9 10:50:13 2023 +0800

    [Unity][MSC][M1.5-1.7] Add Runner and test with torch, tensorflow && 
tensorrt (#16072)
    
    * update runners
    
    * format fix
    
    * format fix
---
 python/tvm/contrib/msc/core/frontend/translate.py  |   3 +-
 python/tvm/contrib/msc/core/runtime/runner.py      |  34 ++--
 python/tvm/contrib/msc/core/transform/transform.py |  18 ++
 python/tvm/contrib/msc/core/utils/file.py          |   8 +-
 python/tvm/contrib/msc/core/utils/info.py          |   8 +-
 .../msc/framework/tensorflow/frontend/translate.py |   3 +-
 .../msc/framework/tensorflow/runtime/__init__.py   |  19 ++
 .../msc/framework/tensorflow/runtime/runner.py     | 217 +++++++++++++++++++++
 .../msc/framework/tensorrt/codegen/codegen.py      |   4 +-
 .../msc/framework/tensorrt/runtime/__init__.py     |  19 ++
 .../msc/framework/tensorrt/runtime/runner.py       |  45 +++++
 .../msc/framework/torch/runtime/__init__.py        |  19 ++
 .../contrib/msc/framework/torch/runtime/runner.py  | 197 +++++++++++++++++++
 .../contrib/msc/framework/tvm/runtime/runner.py    |   8 +-
 src/contrib/msc/core/codegen/codegen_utils.cc      |   2 +-
 src/contrib/msc/core/ir/graph.cc                   |   3 +
 src/contrib/msc/core/ir/graph_builder.cc           |  22 ++-
 src/contrib/msc/core/printer/prototxt_printer.cc   |   6 +-
 src/contrib/msc/core/transform/set_byoc_attrs.cc   |  92 +++++++++
 src/contrib/msc/framework/tensorflow/codegen.cc    |   3 +
 .../msc/framework/tensorflow/tf_v1_opcode.cc       |  46 ++++-
 src/contrib/msc/framework/tensorrt/codegen.cc      |   7 +-
 src/contrib/msc/framework/tensorrt/codegen_utils.h |   2 +-
 tests/python/contrib/test_msc/test_runner.py       |  71 +++++++
 .../contrib/test_msc/test_translate_relay.py       |   2 +-
 25 files changed, 811 insertions(+), 47 deletions(-)

diff --git a/python/tvm/contrib/msc/core/frontend/translate.py 
b/python/tvm/contrib/msc/core/frontend/translate.py
index e1dce2ae28..1e7efe6489 100644
--- a/python/tvm/contrib/msc/core/frontend/translate.py
+++ b/python/tvm/contrib/msc/core/frontend/translate.py
@@ -315,6 +315,7 @@ def byoc_partition(
                 msc_transform.BindShape(),
                 msc_transform.FuseTuple(target),
                 tvm.relax.transform.MergeCompositeFunctions(),
+                msc_transform.SetBYOCAttrs(target),
                 msc_transform.SetExprName(target=target),
                 
msc_transform.SetExprLayout(trans_config.get("allow_layout_missing", True)),
             ]
@@ -335,7 +336,7 @@ def byoc_partition(
 
     graphs_info, all_weights = [], _ffi_api.GetRelaxWeights(msc_mod, entry)
     for name in func_names:
-        build_config.update({"graph_name": name, "byoc_entry": name})
+        build_config.update({"graph_name": msc_mod[name].attrs["byoc_name"], 
"byoc_entry": name})
         graph = _ffi_api.BuildFromRelax(msc_mod, entry, 
msc_utils.dump_dict(build_config))
         graphs_info.append((graph, normalize_weights(all_weights, graph)))
     return _partition_mod(mod, False), graphs_info
diff --git a/python/tvm/contrib/msc/core/runtime/runner.py 
b/python/tvm/contrib/msc/core/runtime/runner.py
index 65e86e4896..3c8212f02d 100644
--- a/python/tvm/contrib/msc/core/runtime/runner.py
+++ b/python/tvm/contrib/msc/core/runtime/runner.py
@@ -94,7 +94,7 @@ class BaseRunner(object):
         self._model, self._model_info = None, {}
         self._runnable = None
 
-    def build(self, cache_dir: msc_utils.MSCDirectory = None, build_graph: 
bool = False) -> object:
+    def build(self, cache_dir: msc_utils.MSCDirectory = None, build_graph: 
bool = False) -> Any:
         """Build the runnable object
 
         Parameters
@@ -106,7 +106,7 @@ class BaseRunner(object):
 
         Returns
         -------
-        runnable: object
+        runnable: Any
            The runnable object.
         """
 
@@ -319,18 +319,18 @@ class BaseRunner(object):
 
         raise NotImplementedError("_save_graphs is not implemented for " + 
str(self.__class__))
 
-    def _generate_model(self) -> object:
+    def _generate_model(self) -> Any:
         """Codegen the model according to framework
 
         Returns
         -------
-        model: object
+        model: Any
             The meta model
         """
 
         raise NotImplementedError("_load is not implemented for " + 
str(self.__class__))
 
-    def _load_model(self, cache_dir: msc_utils.MSCDirectory, cache_info: dict) 
-> object:
+    def _load_model(self, cache_dir: msc_utils.MSCDirectory, cache_info: dict) 
-> Any:
         """Load the model from cache
 
         Parameters
@@ -342,7 +342,7 @@ class BaseRunner(object):
 
         Returns
         -------
-        model: object
+        model: Any
             The meta model
         """
 
@@ -365,12 +365,12 @@ class BaseRunner(object):
         # disable save model by default
         return {}
 
-    def _to_runnable(self, model: object, device: str, is_training: bool) -> 
object:
+    def _to_runnable(self, model: Any, device: str, is_training: bool) -> Any:
         """Build runnable object
 
         Parameters
         -------
-        model: object
+        model: Any
             The meta model.
         device: str
             The device for place model
@@ -379,13 +379,13 @@ class BaseRunner(object):
 
         Returns
         -------
-        runnable: object
+        runnable: Any
             The runnable
         """
 
         raise NotImplementedError("_to_runnable is not implemented for " + 
str(self.__class__))
 
-    def _load_runnable(self, cache_dir: msc_utils.MSCDirectory, cache_info: 
dict) -> object:
+    def _load_runnable(self, cache_dir: msc_utils.MSCDirectory, cache_info: 
dict) -> Any:
         """Load the runnable from cache
 
         Parameters
@@ -397,7 +397,7 @@ class BaseRunner(object):
 
         Returns
         -------
-        runnable: object
+        runnable: Any
             The runnable
         """
 
@@ -432,7 +432,7 @@ class BaseRunner(object):
         raise NotImplementedError("_inspect_model is not implemented for " + 
str(self.__class__))
 
     def _call_runnable(
-        self, runnable: object, inputs: Dict[str, np.ndarray], device: str
+        self, runnable: Any, inputs: Dict[str, np.ndarray], device: str
     ) -> Union[List[np.ndarray], Dict[str, np.ndarray]]:
         """Call the runnable to get outputs
 
@@ -558,12 +558,12 @@ class ModelRunner(BaseRunner):
                 f_params.write(tvm.runtime.save_param_dict(self._weights[0]))
         return {"main": main_info}
 
-    def _generate_model(self) -> object:
+    def _generate_model(self) -> Any:
         """Codegen the model according to framework
 
         Returns
         -------
-        model: object
+        model: Any
             The runnable model
         """
 
@@ -719,12 +719,12 @@ class BYOCRunner(BaseRunner):
             output_folder=self._load_config.get("output_folder", 
msc_utils.get_output_dir()),
         )
 
-    def _to_runnable(self, model: object, device: str, is_training: bool) -> 
object:
+    def _to_runnable(self, model: Any, device: str, is_training: bool) -> Any:
         """Build runnable object
 
         Parameters
         -------
-        model: object
+        model: Any
             The runnable model on cpu.
         device: str
             The device for place model
@@ -733,7 +733,7 @@ class BYOCRunner(BaseRunner):
 
         Returns
         -------
-        runnable: object
+        runnable: Any
             The runnable
         """
 
diff --git a/python/tvm/contrib/msc/core/transform/transform.py 
b/python/tvm/contrib/msc/core/transform/transform.py
index 24f7d38426..8bd4ca9521 100644
--- a/python/tvm/contrib/msc/core/transform/transform.py
+++ b/python/tvm/contrib/msc/core/transform/transform.py
@@ -118,3 +118,21 @@ def FuseTuple(target, entry_name: str = "main") -> 
tvm.ir.transform.Pass:
     """
 
     return relax_api.FuseTuple(target, entry_name)  # type: ignore
+
+
+def SetBYOCAttrs(target, entry_name: str = "main") -> tvm.ir.transform.Pass:
+    """set attributes for byoc
+
+    Parameters
+    ----------
+    target: str
+        The byoc target name
+    entry_name: str
+        The entry name
+
+    Returns
+    -------
+    ret: tvm.ir.transform.Pass
+    """
+
+    return relax_api.SetBYOCAttrs(target, entry_name)  # type: ignore
diff --git a/python/tvm/contrib/msc/core/utils/file.py 
b/python/tvm/contrib/msc/core/utils/file.py
index 278d9d56b9..88808c61d2 100644
--- a/python/tvm/contrib/msc/core/utils/file.py
+++ b/python/tvm/contrib/msc/core/utils/file.py
@@ -21,7 +21,7 @@ import shutil
 import tempfile
 import types
 from functools import partial
-from typing import List
+from typing import List, Any
 from importlib.machinery import SourceFileLoader
 
 from .namespace import MSCMap, MSCKey, MSCFramework
@@ -109,7 +109,7 @@ class MSCDirectory(object):
             f.write(contains)
         return file_path
 
-    def move_file(self, src_file: str, dst_folder: object, dst_file: str = 
None):
+    def move_file(self, src_file: str, dst_folder: Any, dst_file: str = None):
         """Move a file to another folder
 
         Parameters
@@ -133,7 +133,7 @@ class MSCDirectory(object):
         os.rename(src_path, dst_path)
         return dst_path
 
-    def copy_file(self, src_file: str, dst_folder: object, dst_file: str = 
None):
+    def copy_file(self, src_file: str, dst_folder: Any, dst_file: str = None):
         """Copy a file to another folder
 
         Parameters
@@ -157,7 +157,7 @@ class MSCDirectory(object):
         shutil.copy2(src_path, dst_path)
         return dst_path
 
-    def create_dir(self, name: str, keep_history: bool = True, cleanup: bool = 
False) -> object:
+    def create_dir(self, name: str, keep_history: bool = True, cleanup: bool = 
False) -> Any:
         """Add a dir under the folder
 
         Parameters
diff --git a/python/tvm/contrib/msc/core/utils/info.py 
b/python/tvm/contrib/msc/core/utils/info.py
index 894447b169..6053d8ddc8 100644
--- a/python/tvm/contrib/msc/core/utils/info.py
+++ b/python/tvm/contrib/msc/core/utils/info.py
@@ -19,7 +19,7 @@
 import os
 import json
 import copy
-from typing import List, Tuple, Dict
+from typing import List, Tuple, Dict, Any
 from distutils.version import LooseVersion
 import numpy as np
 
@@ -36,13 +36,13 @@ class MSCArray(object):
         The data object.
     """
 
-    def __init__(self, data: object):
+    def __init__(self, data: Any):
         self._type, self._data = self._analysis(data)
 
     def __str__(self):
         return "<{}>{}".format(self._type, self.abstract())
 
-    def _analysis(self, data: object) -> Tuple[str, np.ndarray]:
+    def _analysis(self, data: Any) -> Tuple[str, np.ndarray]:
         if isinstance(data, np.ndarray):
             return "np", data
         if isinstance(data, tvm.runtime.NDArray):
@@ -76,7 +76,7 @@ class MSCArray(object):
         return self._data
 
 
-def cast_array(data: object):
+def cast_array(data: Any):
     """Cast array like object to np.ndarray
 
     Parameters
diff --git a/python/tvm/contrib/msc/framework/tensorflow/frontend/translate.py 
b/python/tvm/contrib/msc/framework/tensorflow/frontend/translate.py
index dc97a315d0..dab19ca81f 100644
--- a/python/tvm/contrib/msc/framework/tensorflow/frontend/translate.py
+++ b/python/tvm/contrib/msc/framework/tensorflow/frontend/translate.py
@@ -25,6 +25,7 @@ from tvm.contrib.msc.core import transform as msc_transform
 from tvm.contrib.msc.core.frontend import from_relax
 from tvm.contrib.msc.core.codegen import relay_to_relax
 from tvm.contrib.msc.framework.tensorflow import tf_v1
+from tvm.contrib.msc.core import utils as msc_utils
 
 
 def from_tensorflow(
@@ -75,7 +76,7 @@ def from_tensorflow(
     relax_mod = relay_to_relax(relay_mod, params, trans_config, build_config, 
opt_config)
     if not as_msc:
         return relax_mod, params
-    build_config = build_config or {}
+    build_config = msc_utils.copy_dict(build_config)
     build_config["use_var_name"] = True
     graph, weights = from_relax(relax_mod, trans_config=trans_config, 
build_config=build_config)
     return graph, weights
diff --git a/python/tvm/contrib/msc/framework/tensorflow/runtime/__init__.py 
b/python/tvm/contrib/msc/framework/tensorflow/runtime/__init__.py
new file mode 100644
index 0000000000..de9dea8e4b
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/tensorflow/runtime/__init__.py
@@ -0,0 +1,19 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.framework.tensorflow.runtime"""
+
+from .runner import *
diff --git a/python/tvm/contrib/msc/framework/tensorflow/runtime/runner.py 
b/python/tvm/contrib/msc/framework/tensorflow/runtime/runner.py
new file mode 100644
index 0000000000..e2c2e919ff
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/tensorflow/runtime/runner.py
@@ -0,0 +1,217 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=not-context-manager
+"""tvm.contrib.msc.framework.tensorflow.runtime.runner"""
+
+import time
+from typing import Dict, List, Union, Any
+import numpy as np
+
+from tensorflow.python.client import device_lib
+from tensorflow.python.ops import variables
+
+from tvm.contrib.msc.core.runtime import ModelRunner
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.framework.tensorflow.codegen import to_tensorflow
+from tvm.contrib.msc.framework.tensorflow import tf_v1
+
+
+class WrapSession(tf_v1.Session):
+    """Wrapped session for MSC"""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._inputs, self._outputs = None, None
+
+    def set_bindings(self, inputs: List[Dict[str, str]], outputs: 
List[Dict[str, str]]):
+        """Set inputs and outputs for session
+
+        Parameters
+        -------
+        inputs: list
+            The inputs info of the model.
+        outputs: list
+            The outputs info of the model.
+        """
+
+        self._inputs = inputs
+        self._outputs = outputs
+
+    def run(self, fetches, *args, **kwargs):  # pylint: 
disable=useless-parent-delegation
+        return super().run(fetches, *args, **kwargs)
+
+
+class TensorflowRunner(ModelRunner):
+    """Runner of Tensorflow"""
+
+    def setup(self):
+        """Setup the runner"""
+
+        super().setup()
+        self._tf_graph = None
+        self._tf_outputs = None
+        self._session = None
+
+    def destory(self):
+        """Destroy runner"""
+
+        self._session.close()
+        del self._tf_graph
+        del self._tf_outputs
+        del self._session
+        super().destory()
+
+    def _generate_model(self) -> Any:
+        """Codegen the model according to framework
+
+        Returns
+        -------
+        model: Any
+            The runnable model
+        """
+
+        if self._tf_graph:
+            del self._tf_graph
+        self._tf_graph = tf_v1.Graph()
+        with self._tf_graph.as_default():
+            self._tf_outputs = super()._generate_model()
+        return self._tf_graph
+
+    def _to_runnable(self, model: Any, device: str, is_training: bool) -> Any:
+        """Build runnable object
+
+        Parameters
+        -------
+        model: Any
+            The meta model.
+        device: str
+            The device for place model
+        is_training: bool
+            Whether to load model for training
+
+        Returns
+        -------
+        runnable: Any
+            The runnable
+        """
+
+        if self._session:
+            self._session.close()
+            del self._session
+        self._session = WrapSession(graph=self._tf_graph)
+        self._session.set_bindings(self.get_inputs(), self.get_outputs())
+        with self._tf_graph.as_default():
+            self._session.run(variables.global_variables_initializer())
+        return self._session
+
+    def _call_runnable(
+        self, runnable: WrapSession, inputs: Dict[str, np.ndarray], device: str
+    ) -> Union[List[np.ndarray], Dict[str, np.ndarray]]:
+        """Call the runnable to get outputs
+
+        Parameters
+        -------
+        runnable: WrapSession
+            The wrapped session.
+        inputs: dict<str, data>
+            The inputs in dict.
+        device: str
+            The device.
+
+        Returns
+        -------
+        outputs: list<data> or dict<str, data>
+            The outputs in list or dict.
+        """
+
+        feed_dict = {i["name"] + ":0": inputs[i["name"]] for i in 
self.get_inputs()}
+        return runnable.run(self._tf_outputs, feed_dict)
+
+    def _device_enabled(self, device: str) -> bool:
+        """Check if the device is enabled
+
+        Returns
+        -------
+        enabled: bool
+            Whether the device is enabled.
+        """
+
+        if device == "cpu":
+            return True
+        if device.startswith("cuda"):
+            device_protos = device_lib.list_local_devices()
+            return any(dev.device_type == "GPU" for dev in device_protos)
+        return False
+
+    @property
+    def codegen_func(self):
+        return to_tensorflow
+
+    @property
+    def framework(self):
+        return MSCFramework.TENSORFLOW
+
+    @classmethod
+    def run_native(
+        cls,
+        model: tf_v1.GraphDef,
+        inputs: Dict[str, np.ndarray],
+        input_names: List[str],
+        output_names: List[str],
+        warm_up: int = 10,
+        repeat: int = 0,
+    ) -> Dict[str, np.ndarray]:
+        """Run the data and get outputs
+
+        Parameters
+        -------
+        model: tf_v1.GraphDef
+            The graph def.
+        inputs: dict<str, data>
+            The inputs in dict.
+        input_names: list<str>
+            The input names.
+        output_names: list<str>
+            The output names.
+        warm_up: int
+            The warm_up num for profile.
+        repeat: int
+            The repeat num for profile.
+
+
+        Returns
+        -------
+        outputs: dict<str, np.array>
+            The outputs in dict.
+        """
+
+        feed_dict = {i_name + ":0": inputs[i_name] for i_name in input_names}
+        with tf_v1.Graph().as_default():
+            tf_v1.import_graph_def(model, name="")
+            with tf_v1.Session() as sess:
+                if repeat > 0:
+                    for _ in range(warm_up):
+                        outputs = sess.run(output_names, feed_dict)
+                    start = time.time()
+                    for _ in range(repeat):
+                        outputs = sess.run(output_names, feed_dict)
+                    avg_time = (time.time() - start) * 1000 / repeat
+                else:
+                    outputs = sess.run(output_names, feed_dict)
+                    avg_time = -1
+        outputs = dict(zip(output_names, outputs))
+        return outputs, avg_time
diff --git a/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py 
b/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py
index be233d9465..5539e614bf 100644
--- a/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py
+++ b/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py
@@ -24,8 +24,8 @@ import numpy as np
 import tvm
 from tvm.contrib.msc.core.ir import MSCGraph
 from tvm.contrib.msc.core.codegen import CodeGen
-from tvm.contrib.msc.core import utils as msc_utils
 from tvm.contrib.msc.core.utils import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
 from tvm.contrib.msc.framework.tensorrt import _ffi_api
 from .sources import get_trt_sources
 from .utils import write_weight
@@ -62,7 +62,7 @@ def to_sub_tensorrt(
         The engine file.
     """
 
-    codegen_config = codegen_config or {}
+    codegen_config = msc_utils.copy_dict(codegen_config)
     codegen_config["version"] = msc_utils.get_version(MSCFramework.TENSORRT)
     if "tensorrt_root" not in codegen_config:
         codegen_config["tensorrt_root"] = _ffi_api.GetTensorRTRoot()
diff --git a/python/tvm/contrib/msc/framework/tensorrt/runtime/__init__.py 
b/python/tvm/contrib/msc/framework/tensorrt/runtime/__init__.py
new file mode 100644
index 0000000000..56203292e4
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/tensorrt/runtime/__init__.py
@@ -0,0 +1,19 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.framework.tensorrt.runtime"""
+
+from .runner import *
diff --git a/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py 
b/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py
new file mode 100644
index 0000000000..88c45c786b
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py
@@ -0,0 +1,45 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.framework.tensorrt.runtime.runner"""
+
+from tvm.contrib.msc.core.runtime import BYOCRunner
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.framework.tensorrt.frontend import partition_for_tensorrt
+from tvm.contrib.msc.framework.tensorrt.codegen import to_tensorrt
+
+
+class TensorRTRunner(BYOCRunner):
+    """Runner of tensorrt"""
+
+    def setup(self):
+        """Setup the runner"""
+
+        super().setup()
+        if not self._device.startswith("cuda"):
+            self._device = "cuda"
+
+    @property
+    def codegen_func(self):
+        return to_tensorrt
+
+    @property
+    def partition_func(self):
+        return partition_for_tensorrt
+
+    @property
+    def framework(self):
+        return MSCFramework.TENSORRT
diff --git a/python/tvm/contrib/msc/framework/torch/runtime/__init__.py 
b/python/tvm/contrib/msc/framework/torch/runtime/__init__.py
new file mode 100644
index 0000000000..83a1830b29
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/torch/runtime/__init__.py
@@ -0,0 +1,19 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.framework.torch.runtime"""
+
+from .runner import *
diff --git a/python/tvm/contrib/msc/framework/torch/runtime/runner.py 
b/python/tvm/contrib/msc/framework/torch/runtime/runner.py
new file mode 100644
index 0000000000..a4f65b4fe1
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/torch/runtime/runner.py
@@ -0,0 +1,197 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.framework.torch.runtime.runner"""
+
+import time
+from typing import Dict, List, Union, Tuple, Any
+import numpy as np
+
+import torch
+import tvm
+from tvm.contrib.msc.core.runtime import ModelRunner
+from tvm.contrib.msc.core.ir import MSCGraph
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.framework.torch.codegen import to_torch
+from tvm.contrib.msc.framework.torch.frontend import set_weight_alias
+from tvm.contrib.msc.core import utils as msc_utils
+
+
+class TorchRunner(ModelRunner):
+    """Runner of Torch"""
+
+    def _translate(self) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]:
+        """Translate IRModule to MSCgraphs
+
+        Returns
+        -------
+        graph_list: list<MSCGraph>
+            The translated graphs
+        weights_list: list<dict<str, tvm.nd.array>>
+            The translated weights
+        """
+        graphs, weights = super()._translate()
+        return [set_weight_alias(graphs[0])], weights
+
+    def _to_runnable(self, model: Any, device: str, is_training: bool) -> Any:
+        """Build runnable object
+
+        Parameters
+        -------
+        model: Any
+            The meta model.
+        device: str
+            The device for place model
+        is_training: bool
+            Whether to load model for training
+
+        Returns
+        -------
+        runnable: Any
+            The runnable
+        """
+
+        if device == "cpu":
+            pass
+        elif device.startswith("cuda"):
+            model = model.to(torch.device(device))
+        else:
+            raise NotImplementedError("Unsupported device " + str(device))
+        if is_training:
+            model = model.train()
+        else:
+            model = model.eval()
+        return model
+
+    def _call_runnable(
+        self, runnable: torch.nn.Module, inputs: Dict[str, np.ndarray], 
device: str
+    ) -> Union[List[np.ndarray], Dict[str, np.ndarray]]:
+        """Call the runnable to get outputs
+
+        Parameters
+        -------
+        runnable: torch.nn.Module
+            The runnable model.
+        inputs: dict<str, data>
+            The inputs in dict.
+        device: str
+            The device.
+
+        Returns
+        -------
+        outputs: list<torch.Tensor>
+            The outputs in list.
+        """
+
+        model_inputs = self.get_inputs()
+        parameters = list(runnable.parameters())
+        if parameters:
+            in_dev = parameters[0].device
+        elif device == "cpu":
+            in_dev = torch.device(device)
+        elif device.startswith("cuda"):
+            in_dev = torch.device(device)
+        else:
+            raise NotImplementedError("Unsupported device " + str(device))
+        torch_inputs = [torch.from_numpy(inputs[i["name"]]).to(in_dev) for i 
in model_inputs]
+        return runnable(*torch_inputs)
+
+    def _device_enabled(self, device: str) -> bool:
+        """Check if the device is enabled
+
+        Returns
+        -------
+        enabled: bool
+            Whether the device is enabled.
+        """
+
+        if device == "cpu":
+            return True
+        if device.startswith("cuda"):
+            return torch.cuda.is_available()
+        return False
+
+    @property
+    def codegen_func(self):
+        return to_torch
+
+    @property
+    def framework(self):
+        return MSCFramework.TORCH
+
+    @classmethod
+    def run_native(
+        cls,
+        model: torch.nn.Module,
+        inputs: Dict[str, np.ndarray],
+        input_names: List[str],
+        output_names: List[str],
+        warm_up: int = 10,
+        repeat: int = 0,
+    ) -> Dict[str, np.ndarray]:
+        """Run the data and get outputs
+
+        Parameters
+        -------
+        model: torch.nn.Module
+            The runnable model.
+        inputs: dict<str, data>
+            The inputs in dict.
+        input_names: list<str>
+            The input names.
+        output_names: list<str>
+            The output names.
+        warm_up: int
+            The warm_up num for profile.
+        repeat: int
+            The repeat num for profile.
+
+        Returns
+        -------
+        outputs: dict<str, np.array>
+            The outputs in dict.
+        """
+
+        parameters = list(model.parameters())
+        if parameters:
+            device = parameters[0].device
+        else:
+            device = torch.device("cpu")
+
+        def _run_once():
+            torch_inputs = [torch.from_numpy(inputs[i_name]).to(device) for 
i_name in input_names]
+            return model(*torch_inputs)
+
+        if repeat > 0:
+            for _ in range(warm_up):
+                _run_once()
+            start = time.time()
+            for _ in range(repeat):
+                outputs = _run_once()
+            avg_time = (time.time() - start) * 1000 / repeat
+        else:
+            outputs = _run_once()
+            avg_time = -1
+        if isinstance(outputs, torch.Tensor):
+            assert len(output_names) == 1, "Expect 1 outputs, get " + 
str(output_names)
+            return {output_names[0]: msc_utils.cast_array(outputs)}, avg_time
+        assert len(output_names) == len(outputs), "Outputs mismatch, {} with 
{}".format(
+            output_names, len(outputs)
+        )
+        outputs = {
+            o_name: msc_utils.cast_array(o_data) for o_name, o_data in 
zip(output_names, outputs)
+        }
+        return outputs, avg_time
diff --git a/python/tvm/contrib/msc/framework/tvm/runtime/runner.py 
b/python/tvm/contrib/msc/framework/tvm/runtime/runner.py
index 90ba8e4cce..c5240ca229 100644
--- a/python/tvm/contrib/msc/framework/tvm/runtime/runner.py
+++ b/python/tvm/contrib/msc/framework/tvm/runtime/runner.py
@@ -16,7 +16,7 @@
 # under the License.
 """tvm.contrib.msc.framework.runtime.tvm.runner"""
 
-from typing import Dict, List, Union
+from typing import Dict, List, Union, Any
 import numpy as np
 
 import tvm
@@ -28,12 +28,12 @@ from tvm.contrib.msc.framework.tvm.codegen import to_relax
 class TVMRunner(ModelRunner):
     """Runner of Relax"""
 
-    def _to_runnable(self, model: object, device: str, is_training: bool) -> 
object:
+    def _to_runnable(self, model: Any, device: str, is_training: bool) -> Any:
         """Build runnable object
 
         Parameters
         -------
-        model: object
+        model: Any
             The meta model.
         device: str
             The device for place model
@@ -42,7 +42,7 @@ class TVMRunner(ModelRunner):
 
         Returns
         -------
-        runnable: object
+        runnable: Any
             The runnable
         """
 
diff --git a/src/contrib/msc/core/codegen/codegen_utils.cc 
b/src/contrib/msc/core/codegen/codegen_utils.cc
index bdc542994d..0c751a8fbf 100644
--- a/src/contrib/msc/core/codegen/codegen_utils.cc
+++ b/src/contrib/msc/core/codegen/codegen_utils.cc
@@ -51,7 +51,7 @@ const String CodeGenUtils::IdxInput(const MSCJoint& node, 
const String& prefix,
 
 const String CodeGenUtils::IdxWeight(const MSCJoint& node, const String& wtype,
                                      const String& suffix) {
-  return wtype + std::to_string(node->index) + suffix;
+  return wtype + "_" + std::to_string(node->index) + suffix;
 }
 
 const String CodeGenUtils::CommentNode(const MSCJoint& node, const String& 
prefix) {
diff --git a/src/contrib/msc/core/ir/graph.cc b/src/contrib/msc/core/ir/graph.cc
index 68f14376b9..abd581c797 100644
--- a/src/contrib/msc/core/ir/graph.cc
+++ b/src/contrib/msc/core/ir/graph.cc
@@ -578,6 +578,9 @@ const String MSCGraphNode::ToPrototxt() const {
     for (const auto& pair : node->weights) {
       param.Set("param_" + pair.first, pair.second);
     }
+    for (const auto& pair : node->attrs) {
+      param.Set(pair.first, pair.second);
+    }
     layer.push_back(std::make_pair("layer_param", 
PrototxtPrinter::ToDictDoc(param)));
     // Append the layer Map
     printer.Append(Map<String, ObjectRef>{{"layer", 
PrototxtPrinter::ToDictDoc(layer)}});
diff --git a/src/contrib/msc/core/ir/graph_builder.cc 
b/src/contrib/msc/core/ir/graph_builder.cc
index f3fb383f8f..dab4ae813e 100644
--- a/src/contrib/msc/core/ir/graph_builder.cc
+++ b/src/contrib/msc/core/ir/graph_builder.cc
@@ -205,7 +205,7 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, 
const Optional<Expr>
   String node_name = name.size() > 0 ? name : SpanUtils::GetAttr(expr->span, 
"name");
   const auto& shared_ref = SpanUtils::GetAttr(expr->span, "shared_ref");
 
-  // Get optype
+  // Get optype and node_name
   String optype;
   if (expr->IsInstance<relax::VarNode>()) {
     if (func_params_.count(expr) && 
func_params_[expr]->IsInstance<relax::ConstantNode>()) {
@@ -227,9 +227,18 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& 
expr, const Optional<Expr>
       optype = StringUtils::Replace(op_node->name, "relax.", "");
     } else if (const auto* v_node = call_node->op.as<GlobalVarNode>()) {
       const auto& func = 
Downcast<relax::Function>(ref_module_->Lookup(v_node->name_hint));
-      const auto& name_opt = 
func->GetAttr<runtime::String>(relax::attr::kComposite);
-      ICHECK(name_opt.defined()) << "Unexpected global func without composite";
-      optype = name_opt.value();
+      const auto& byoc_name_opt = func->GetAttr<runtime::String>("byoc_name");
+      if (byoc_name_opt.defined()) {
+        node_name = byoc_name_opt.value();
+      }
+      const auto& codegen_opt = 
func->GetAttr<runtime::String>(relax::attr::kCodegen);
+      if (codegen_opt.defined()) {
+        optype = codegen_opt.value();
+      } else {
+        const auto& name_opt = 
func->GetAttr<runtime::String>(relax::attr::kComposite);
+        ICHECK(name_opt.defined()) << "Unexpected global func without 
composite";
+        optype = name_opt.value();
+      }
     } else if (call_node->op->IsInstance<relax::VarNode>()) {
       ICHECK(target_funcs_.count(call_node->op)) << "Can not find target func: 
" << call_node->op;
       const auto& func = target_funcs_[call_node->op];
@@ -251,7 +260,10 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& 
expr, const Optional<Expr>
   if (const auto* call_node = expr.as<relax::CallNode>()) {
     if (const auto* v_node = call_node->op.as<GlobalVarNode>()) {
       const auto& func = 
Downcast<relax::Function>(ref_module_->Lookup(v_node->name_hint));
-      attrs = RelaxFuncAttrGetter().GetAttrs(func);
+      const auto& byoc_name_opt = func->GetAttr<runtime::String>("byoc_name");
+      if (!byoc_name_opt.defined()) {
+        attrs = RelaxFuncAttrGetter().GetAttrs(func);
+      }
     } else if (call_node->op->IsInstance<relax::VarNode>()) {
       ICHECK(target_funcs_.count(call_node->op)) << "Can not find target func: 
" << call_node->op;
       attrs = RelaxFuncAttrGetter().GetAttrs(target_funcs_[call_node->op]);
diff --git a/src/contrib/msc/core/printer/prototxt_printer.cc 
b/src/contrib/msc/core/printer/prototxt_printer.cc
index e15894a272..7e96c657a7 100644
--- a/src/contrib/msc/core/printer/prototxt_printer.cc
+++ b/src/contrib/msc/core/printer/prototxt_printer.cc
@@ -73,12 +73,12 @@ DictDoc PrototxtPrinter::ToDictDoc(const 
std::vector<std::pair<String, ObjectRef
 
 void PrototxtPrinter::Append(const Map<String, ObjectRef>& dict) {
   DictDoc doc = ToDictDoc(dict);
-  PrintDoc(doc);
+  PrintDoc(doc, false);
 }
 
 void PrototxtPrinter::Append(const std::vector<std::pair<String, ObjectRef>>& 
dict) {
   DictDoc doc = ToDictDoc(dict);
-  PrintDoc(doc);
+  PrintDoc(doc, false);
 }
 
 void PrototxtPrinter::AppendPair(const String& key, const ObjectRef& value) {
@@ -97,7 +97,7 @@ void PrototxtPrinter::PrintTypedDoc(const DictDoc& doc) {
     if (doc->values[i].as<DictDocNode>()) {
       output_ << " {";
       IncreaseIndent();
-      PrintDoc(doc->values[i]);
+      PrintDoc(doc->values[i], false);
       DecreaseIndent();
       NewLine() << "}";
     } else {
diff --git a/src/contrib/msc/core/transform/set_byoc_attrs.cc 
b/src/contrib/msc/core/transform/set_byoc_attrs.cc
new file mode 100644
index 0000000000..4fa8ab584e
--- /dev/null
+++ b/src/contrib/msc/core/transform/set_byoc_attrs.cc
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/contrib/msc/core/transform/set_byoc_attrs.cc
+ * \brief Pass for setting byoc attributes.
+ */
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+
+#include "../../../../relax/transform/utils.h"
+#include "../utils.h"
+
+namespace tvm {
+namespace relax {
+
+using namespace tvm::contrib::msc;
+
+/*!
+ * \brief Set byoc_name attribute for functions compiled with a BYOC codegen target
+ */
+class ByocNameSetter : public ExprMutator {
+ public:
+  explicit ByocNameSetter(IRModule ctx_module, const String& target, const 
String& entry_name)
+      : ExprMutator(ctx_module) {
+    mod_ = ctx_module;
+    target_ = target;
+    entry_name_ = entry_name;
+  }
+
+  IRModule SetAttrs() {
+    GlobalVar main_var;
+    size_t func_cnt = 0;
+    for (const auto& [gv, func] : mod_->functions) {
+      if (gv->name_hint == entry_name_) {
+        main_var = gv;
+      } else {
+        const auto& name_opt = func->GetAttr<runtime::String>(attr::kCodegen);
+        if (name_opt.defined() && name_opt.value() == target_) {
+          const auto& new_func = WithAttr(Downcast<Function>(func), 
"byoc_name",
+                                          target_ + "_" + 
std::to_string(func_cnt));
+          builder_->UpdateFunction(gv, new_func);
+          func_cnt += 1;
+        }
+      }
+    }
+    return builder_->GetContextIRModule();
+  }
+
+ private:
+  IRModule mod_;
+  String target_;
+  String entry_name_;
+  Map<Function, Function> new_funcs_;
+};
+
+IRModule SetBYOCAttrs(IRModule mod, const String& target, const String& 
entry_name) {
+  return ByocNameSetter(mod, target, entry_name).SetAttrs();
+}
+
+namespace transform {
+
+Pass SetBYOCAttrs(const String& target, const String& entry_name) {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+      [=](IRModule m, PassContext pc) { return relax::SetBYOCAttrs(m, target, 
entry_name); };
+  return CreateModulePass(pass_func, 0, "SetBYOCAttrs", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.SetBYOCAttrs").set_body_typed(SetBYOCAttrs);
+
+}  // namespace transform
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/contrib/msc/framework/tensorflow/codegen.cc 
b/src/contrib/msc/framework/tensorflow/codegen.cc
index e4b4feb8ca..30f06b43e7 100644
--- a/src/contrib/msc/framework/tensorflow/codegen.cc
+++ b/src/contrib/msc/framework/tensorflow/codegen.cc
@@ -64,6 +64,9 @@ void TensorflowCodeGen::CodeGenGraph() {
   stack_.comment("Define the weights");
   for (const auto& n : graph()->node_names) {
     const auto& node = graph()->FindNode(n);
+    if (node->optype == "nn.batch_norm") {
+      continue;
+    }
     for (const auto& pair : node->weights) {
       stack_.func_call("get_variable", IdxWeightBase(node, pair.first))
           .call_arg(DocUtils::ToStrDoc(pair.second->name))
diff --git a/src/contrib/msc/framework/tensorflow/tf_v1_opcode.cc 
b/src/contrib/msc/framework/tensorflow/tf_v1_opcode.cc
index d17a326424..56dba21ac8 100644
--- a/src/contrib/msc/framework/tensorflow/tf_v1_opcode.cc
+++ b/src/contrib/msc/framework/tensorflow/tf_v1_opcode.cc
@@ -157,6 +157,33 @@ class TFV1AxisCodeGen : public TFV1OpCode {
   String attr_name_;
 };
 
+class TFV1BatchnormCodeGen : public TFV1OpCode {
+  TFV1_OP_CODEGEN_METHODS(TFV1BatchnormCodeGen)
+
+ protected:
+  void CodeGenBuild() final {
+    stack_.op_call()
+        .op_input_arg()
+        .op_arg<bool>("scale")
+        .op_arg<bool>("center")
+        .op_arg<float>("momentum")
+        .op_arg<float>("epsilon")
+        .call_arg("tf_v1.constant_initializer(weights[\"" + 
node()->WeightAt("gamma")->name +
+                      "\"].asnumpy())",
+                  "gamma_initializer")
+        .call_arg("tf_v1.constant_initializer(weights[\"" + 
node()->WeightAt("beta")->name +
+                      "\"].asnumpy())",
+                  "beta_initializer")
+        .call_arg("tf_v1.constant_initializer(weights[\"" + 
node()->WeightAt("mean")->name +
+                      "\"].asnumpy())",
+                  "moving_mean_initializer")
+        .call_arg("tf_v1.constant_initializer(weights[\"" + 
node()->WeightAt("var")->name +
+                      "\"].asnumpy())",
+                  "moving_variance_initializer")
+        .op_name_arg();
+  }
+};
+
 class TFV1BroadcastToCodeGen : public TFV1OpCode {
   TFV1_OP_CODEGEN_METHODS(TFV1BroadcastToCodeGen)
 
@@ -166,6 +193,19 @@ class TFV1BroadcastToCodeGen : public TFV1OpCode {
   }
 };
 
+class TFV1ClipCodeGen : public TFV1OpCode {
+  TFV1_OP_CODEGEN_METHODS(TFV1ClipCodeGen)
+
+ protected:
+  void CodeGenBuild() final {
+    stack_.op_call()
+        .op_input_arg()
+        .op_arg<float>("min", "clip_value_min")
+        .op_arg<float>("max", "clip_value_max")
+        .op_name_arg();
+  }
+};
+
 class TFV1ConcatCodeGen : public TFV1OpCode {
   TFV1_OP_CODEGEN_METHODS(TFV1ConcatCodeGen)
 
@@ -542,6 +582,7 @@ const std::shared_ptr<std::unordered_map<String, 
std::shared_ptr<TFV1OpCode>>> G
   map->emplace("argmin", 
std::make_shared<TFV1ArgMaxMinCodeGen>("tf_v1.argmin"));
   map->emplace("astype", std::make_shared<TFV1AstypeCodeGen>("tf_v1.cast"));
   map->emplace("broadcast_to", 
std::make_shared<TFV1BroadcastToCodeGen>("tf_v1.broadcast_to"));
+  map->emplace("clip", 
std::make_shared<TFV1ClipCodeGen>("tf_v1.clip_by_value"));
   map->emplace("concat", 
std::make_shared<TFV1ConcatCodeGen>("ops.array_ops.concat_v2"));
   map->emplace("concatenate", 
std::make_shared<TFV1ConcatCodeGen>("ops.array_ops.concat_v2"));
   map->emplace("einsum", std::make_shared<TFV1EinsumCodeGen>("tf_v1.einsum"));
@@ -555,6 +596,8 @@ const std::shared_ptr<std::unordered_map<String, 
std::shared_ptr<TFV1OpCode>>> G
 
   // nn ops
   map->emplace("nn.avg_pool2d", 
std::make_shared<TFV1Pool2dCodeGen>("ops.nn_ops.pool"));
+  map->emplace("nn.batch_norm",
+               
std::make_shared<TFV1BatchnormCodeGen>("tf_v1.layers.batch_normalization"));
   map->emplace("nn.conv2d", 
std::make_shared<TFV1ConvCodeGen>("ops.nn_ops.conv2d", false));
   map->emplace("nn.max_pool2d", 
std::make_shared<TFV1Pool2dCodeGen>("ops.nn_ops.pool"));
   map->emplace("nn.pad", std::make_shared<TFV1PadCodeGen>("tf_v1.pad"));
@@ -569,7 +612,8 @@ const std::shared_ptr<std::unordered_map<String, 
std::shared_ptr<TFV1OpCode>>> G
   map->emplace("tuple", std::make_shared<TFV1TupleCodeGen>("tuple"));
 
   // msc ops
-  map->emplace("msc.conv2d", 
std::make_shared<TFV1ConvCodeGen>("ops.nn_ops.conv2d", true));
+  map->emplace("msc.conv2d", 
std::make_shared<TFV1ConvCodeGen>("ops.nn_ops.conv2d", false));
+  map->emplace("msc.conv2d_bias", 
std::make_shared<TFV1ConvCodeGen>("ops.nn_ops.conv2d", true));
 
   return map;
 }
diff --git a/src/contrib/msc/framework/tensorrt/codegen.cc 
b/src/contrib/msc/framework/tensorrt/codegen.cc
index b8b2335da1..a59697377a 100644
--- a/src/contrib/msc/framework/tensorrt/codegen.cc
+++ b/src/contrib/msc/framework/tensorrt/codegen.cc
@@ -494,9 +494,12 @@ Array<runtime::Module> MSCTensorRTCompiler(Array<Function> 
functions,
   Array<runtime::Module> compiled_functions;
   for (const auto& func : functions) {
     VLOG(1) << "MSC.TensorRT partition:" << std::endl << func;
+    const auto& byoc_name_opt = func->GetAttr<runtime::String>("byoc_name");
+    ICHECK(byoc_name_opt.defined()) << "Can not find byoc_name from attrs";
+    const auto& byoc_name = byoc_name_opt.value();
     std::string func_name = GetExtSymbol(func);
-    ICHECK(target_option.count(func_name)) << "Can not find target option for 
" << func_name;
-    const auto& options = Downcast<String>(target_option[func_name]);
+    ICHECK(target_option.count(byoc_name)) << "Can not find target option for 
" << byoc_name;
+    const auto& options = Downcast<String>(target_option[byoc_name]);
     MSCJSONSerializer serializer(constant_names, options);
     serializer.serialize(func);
     std::string graph_json = serializer.GetJSON();
diff --git a/src/contrib/msc/framework/tensorrt/codegen_utils.h 
b/src/contrib/msc/framework/tensorrt/codegen_utils.h
index 8249444d9d..d598396f6f 100644
--- a/src/contrib/msc/framework/tensorrt/codegen_utils.h
+++ b/src/contrib/msc/framework/tensorrt/codegen_utils.h
@@ -46,7 +46,7 @@ class TensorRTCodeGenHelper : public BaseCodeGenHelper {
       return "*" + IdxNodeBase(pair.first, prefix, suffix);
     }
     if (pair.first->optype == "tuple" || pair.first->optype == "get_item") {
-      return IdxNodeBase(pair.first, prefix, suffix);
+      return "*" + IdxNodeBase(pair.first, prefix, suffix);
     }
     return "*" + IdxOutputBase(pair.first, prefix, pair.second, suffix);
   }
diff --git a/tests/python/contrib/test_msc/test_runner.py 
b/tests/python/contrib/test_msc/test_runner.py
index 4653a2225b..3f2d0d0c90 100644
--- a/tests/python/contrib/test_msc/test_runner.py
+++ b/tests/python/contrib/test_msc/test_runner.py
@@ -22,10 +22,16 @@ import numpy as np
 
 import torch
 from torch import fx
+from tvm.contrib.msc.framework.tensorflow import tf_v1
 
 import tvm.testing
+import tvm.relay.testing.tf as tf_testing
 from tvm.relax.frontend.torch import from_fx
 from tvm.contrib.msc.framework.tvm.runtime import TVMRunner
+from tvm.contrib.msc.framework.torch.runtime import TorchRunner
+from tvm.contrib.msc.framework.tensorrt.runtime import TensorRTRunner
+from tvm.contrib.msc.framework.tensorflow.frontend import from_tensorflow
+from tvm.contrib.msc.framework.tensorflow.runtime import TensorflowRunner
 from tvm.contrib.msc.core import utils as msc_utils
 
 requires_tensorrt = pytest.mark.skipif(
@@ -51,11 +57,32 @@ def _get_torch_model(name, is_training=False):
         return None
 
 
+def _get_tf_graph():
+    """Get tensorflow graphdef"""
+
+    try:
+        tf_graph = tf_v1.Graph()
+        with tf_graph.as_default():
+            graph_def = tf_testing.get_workload(
+                
"https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz";,
+                "mobilenet_v2_1.4_224_frozen.pb",
+            )
+            # Call the utility to import the graph definition into default 
graph.
+            graph_def = tf_testing.ProcessGraphDefParam(graph_def)
+        return tf_graph, graph_def
+    except:  # pylint: disable=bare-except
+        print("please install tensorflow package")
+        return None, None
+
+
 def _test_from_torch(runner_cls, device, is_training=False, atol=1e-3, 
rtol=1e-3):
     """Test runner from torch model"""
+
     torch_model = _get_torch_model("resnet50", is_training)
     if torch_model:
         workspace = msc_utils.set_workspace()
+        log_path = workspace.relpath("MSC_LOG", keep_history=False)
+        msc_utils.set_global_logger("info", log_path)
         input_info = [([1, 3, 224, 224], "float32")]
         datas = [np.random.rand(*i[0]).astype(i[1]) for i in input_info]
         torch_datas = [torch.from_numpy(d) for d in datas]
@@ -85,5 +112,49 @@ def test_tvm_runner_gpu():
     _test_from_torch(TVMRunner, "cuda", is_training=True)
 
 
+def test_torch_runner_cpu():
+    """Test runner for torch on cpu"""
+
+    _test_from_torch(TorchRunner, "cpu")
+
+
[email protected]_gpu
+def test_torch_runner_gpu():
+    """Test runner for torch on cuda"""
+
+    _test_from_torch(TorchRunner, "cuda", atol=1e-2, rtol=1e-2)
+
+
+@requires_tensorrt
+def test_tensorrt_runner():
+    """Test runner for tensorrt"""
+
+    _test_from_torch(TensorRTRunner, "cuda", atol=1e-2, rtol=1e-2)
+
+
+def test_tensorflow_runner():
+    """Test runner from tf graph"""
+
+    tf_graph, graph_def = _get_tf_graph()
+    if tf_graph and graph_def:
+        workspace = msc_utils.set_workspace()
+        log_path = workspace.relpath("MSC_LOG", keep_history=False)
+        msc_utils.set_global_logger("info", log_path)
+        data = np.random.uniform(size=(1, 224, 224, 3)).astype("float32")
+        out_name = "MobilenetV2/Predictions/Reshape_1:0"
+        # get golden
+        with tf_v1.Session(graph=tf_graph) as sess:
+            golden = sess.run([out_name], {"input:0": data})
+        # get outputs
+        shape_dict = {"input": data.shape}
+        mod, _ = from_tensorflow(graph_def, shape_dict, [out_name], 
as_msc=False)
+        runner = TensorflowRunner(mod)
+        runner.build()
+        outputs = runner.run([data], ret_type="list")
+        for gol_r, out_r in zip(golden, outputs):
+            tvm.testing.assert_allclose(gol_r, out_r, atol=1e-3, rtol=1e-3)
+        workspace.destory()
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/contrib/test_msc/test_translate_relay.py 
b/tests/python/contrib/test_msc/test_translate_relay.py
index e4c8dbf3b5..aca0f26890 100644
--- a/tests/python/contrib/test_msc/test_translate_relay.py
+++ b/tests/python/contrib/test_msc/test_translate_relay.py
@@ -27,7 +27,7 @@ from torch.nn import Module
 import tvm.testing
 from tvm.relax.frontend.torch import from_fx
 from tvm.relay.frontend import from_pytorch
-from tvm.contrib.msc.core.ir import translate
+from tvm.contrib.msc.core.frontend import translate
 from tvm.contrib.msc.framework.tvm import codegen as tvm_codegen
 
 

Reply via email to