This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new e37165f40f [Unity][MSC][M1.5-1.7] Add Runner and test with torch,
tensorflow && tensorrt (#16072)
e37165f40f is described below
commit e37165f40f51d8cf0b8a9aef22314aed576760b1
Author: Archermmt <[email protected]>
AuthorDate: Thu Nov 9 10:50:13 2023 +0800
[Unity][MSC][M1.5-1.7] Add Runner and test with torch, tensorflow &&
tensorrt (#16072)
* update runners
* format fix
* format fix
---
python/tvm/contrib/msc/core/frontend/translate.py | 3 +-
python/tvm/contrib/msc/core/runtime/runner.py | 34 ++--
python/tvm/contrib/msc/core/transform/transform.py | 18 ++
python/tvm/contrib/msc/core/utils/file.py | 8 +-
python/tvm/contrib/msc/core/utils/info.py | 8 +-
.../msc/framework/tensorflow/frontend/translate.py | 3 +-
.../msc/framework/tensorflow/runtime/__init__.py | 19 ++
.../msc/framework/tensorflow/runtime/runner.py | 217 +++++++++++++++++++++
.../msc/framework/tensorrt/codegen/codegen.py | 4 +-
.../msc/framework/tensorrt/runtime/__init__.py | 19 ++
.../msc/framework/tensorrt/runtime/runner.py | 45 +++++
.../msc/framework/torch/runtime/__init__.py | 19 ++
.../contrib/msc/framework/torch/runtime/runner.py | 197 +++++++++++++++++++
.../contrib/msc/framework/tvm/runtime/runner.py | 8 +-
src/contrib/msc/core/codegen/codegen_utils.cc | 2 +-
src/contrib/msc/core/ir/graph.cc | 3 +
src/contrib/msc/core/ir/graph_builder.cc | 22 ++-
src/contrib/msc/core/printer/prototxt_printer.cc | 6 +-
src/contrib/msc/core/transform/set_byoc_attrs.cc | 92 +++++++++
src/contrib/msc/framework/tensorflow/codegen.cc | 3 +
.../msc/framework/tensorflow/tf_v1_opcode.cc | 46 ++++-
src/contrib/msc/framework/tensorrt/codegen.cc | 7 +-
src/contrib/msc/framework/tensorrt/codegen_utils.h | 2 +-
tests/python/contrib/test_msc/test_runner.py | 71 +++++++
.../contrib/test_msc/test_translate_relay.py | 2 +-
25 files changed, 811 insertions(+), 47 deletions(-)
diff --git a/python/tvm/contrib/msc/core/frontend/translate.py
b/python/tvm/contrib/msc/core/frontend/translate.py
index e1dce2ae28..1e7efe6489 100644
--- a/python/tvm/contrib/msc/core/frontend/translate.py
+++ b/python/tvm/contrib/msc/core/frontend/translate.py
@@ -315,6 +315,7 @@ def byoc_partition(
msc_transform.BindShape(),
msc_transform.FuseTuple(target),
tvm.relax.transform.MergeCompositeFunctions(),
+ msc_transform.SetBYOCAttrs(target),
msc_transform.SetExprName(target=target),
msc_transform.SetExprLayout(trans_config.get("allow_layout_missing", True)),
]
@@ -335,7 +336,7 @@ def byoc_partition(
graphs_info, all_weights = [], _ffi_api.GetRelaxWeights(msc_mod, entry)
for name in func_names:
- build_config.update({"graph_name": name, "byoc_entry": name})
+ build_config.update({"graph_name": msc_mod[name].attrs["byoc_name"],
"byoc_entry": name})
graph = _ffi_api.BuildFromRelax(msc_mod, entry,
msc_utils.dump_dict(build_config))
graphs_info.append((graph, normalize_weights(all_weights, graph)))
return _partition_mod(mod, False), graphs_info
diff --git a/python/tvm/contrib/msc/core/runtime/runner.py
b/python/tvm/contrib/msc/core/runtime/runner.py
index 65e86e4896..3c8212f02d 100644
--- a/python/tvm/contrib/msc/core/runtime/runner.py
+++ b/python/tvm/contrib/msc/core/runtime/runner.py
@@ -94,7 +94,7 @@ class BaseRunner(object):
self._model, self._model_info = None, {}
self._runnable = None
- def build(self, cache_dir: msc_utils.MSCDirectory = None, build_graph:
bool = False) -> object:
+ def build(self, cache_dir: msc_utils.MSCDirectory = None, build_graph:
bool = False) -> Any:
"""Build the runnable object
Parameters
@@ -106,7 +106,7 @@ class BaseRunner(object):
Returns
-------
- runnable: object
+ runnable: Any
The runnable object.
"""
@@ -319,18 +319,18 @@ class BaseRunner(object):
raise NotImplementedError("_save_graphs is not implemented for " +
str(self.__class__))
- def _generate_model(self) -> object:
+ def _generate_model(self) -> Any:
"""Codegen the model according to framework
Returns
-------
- model: object
+ model: Any
The meta model
"""
raise NotImplementedError("_load is not implemented for " +
str(self.__class__))
- def _load_model(self, cache_dir: msc_utils.MSCDirectory, cache_info: dict)
-> object:
+ def _load_model(self, cache_dir: msc_utils.MSCDirectory, cache_info: dict)
-> Any:
"""Load the model from cache
Parameters
@@ -342,7 +342,7 @@ class BaseRunner(object):
Returns
-------
- model: object
+ model: Any
The meta model
"""
@@ -365,12 +365,12 @@ class BaseRunner(object):
# disable save model by default
return {}
- def _to_runnable(self, model: object, device: str, is_training: bool) ->
object:
+ def _to_runnable(self, model: Any, device: str, is_training: bool) -> Any:
"""Build runnable object
Parameters
-------
- model: object
+ model: Any
The meta model.
device: str
The device for place model
@@ -379,13 +379,13 @@ class BaseRunner(object):
Returns
-------
- runnable: object
+ runnable: Any
The runnable
"""
raise NotImplementedError("_to_runnable is not implemented for " +
str(self.__class__))
- def _load_runnable(self, cache_dir: msc_utils.MSCDirectory, cache_info:
dict) -> object:
+ def _load_runnable(self, cache_dir: msc_utils.MSCDirectory, cache_info:
dict) -> Any:
"""Load the runnable from cache
Parameters
@@ -397,7 +397,7 @@ class BaseRunner(object):
Returns
-------
- runnable: object
+ runnable: Any
The runnable
"""
@@ -432,7 +432,7 @@ class BaseRunner(object):
raise NotImplementedError("_inspect_model is not implemented for " +
str(self.__class__))
def _call_runnable(
- self, runnable: object, inputs: Dict[str, np.ndarray], device: str
+ self, runnable: Any, inputs: Dict[str, np.ndarray], device: str
) -> Union[List[np.ndarray], Dict[str, np.ndarray]]:
"""Call the runnable to get outputs
@@ -558,12 +558,12 @@ class ModelRunner(BaseRunner):
f_params.write(tvm.runtime.save_param_dict(self._weights[0]))
return {"main": main_info}
- def _generate_model(self) -> object:
+ def _generate_model(self) -> Any:
"""Codegen the model according to framework
Returns
-------
- model: object
+ model: Any
The runnable model
"""
@@ -719,12 +719,12 @@ class BYOCRunner(BaseRunner):
output_folder=self._load_config.get("output_folder",
msc_utils.get_output_dir()),
)
- def _to_runnable(self, model: object, device: str, is_training: bool) ->
object:
+ def _to_runnable(self, model: Any, device: str, is_training: bool) -> Any:
"""Build runnable object
Parameters
-------
- model: object
+ model: Any
The runnable model on cpu.
device: str
The device for place model
@@ -733,7 +733,7 @@ class BYOCRunner(BaseRunner):
Returns
-------
- runnable: object
+ runnable: Any
The runnable
"""
diff --git a/python/tvm/contrib/msc/core/transform/transform.py
b/python/tvm/contrib/msc/core/transform/transform.py
index 24f7d38426..8bd4ca9521 100644
--- a/python/tvm/contrib/msc/core/transform/transform.py
+++ b/python/tvm/contrib/msc/core/transform/transform.py
@@ -118,3 +118,21 @@ def FuseTuple(target, entry_name: str = "main") ->
tvm.ir.transform.Pass:
"""
return relax_api.FuseTuple(target, entry_name) # type: ignore
+
+
+def SetBYOCAttrs(target, entry_name: str = "main") -> tvm.ir.transform.Pass:
+ """set attributes for byoc
+
+ Parameters
+ ----------
+ target: str
+ The byoc target name
+ entry_name: str
+ The entry name
+
+ Returns
+ -------
+ ret: tvm.ir.transform.Pass
+ """
+
+ return relax_api.SetBYOCAttrs(target, entry_name) # type: ignore
diff --git a/python/tvm/contrib/msc/core/utils/file.py
b/python/tvm/contrib/msc/core/utils/file.py
index 278d9d56b9..88808c61d2 100644
--- a/python/tvm/contrib/msc/core/utils/file.py
+++ b/python/tvm/contrib/msc/core/utils/file.py
@@ -21,7 +21,7 @@ import shutil
import tempfile
import types
from functools import partial
-from typing import List
+from typing import List, Any
from importlib.machinery import SourceFileLoader
from .namespace import MSCMap, MSCKey, MSCFramework
@@ -109,7 +109,7 @@ class MSCDirectory(object):
f.write(contains)
return file_path
- def move_file(self, src_file: str, dst_folder: object, dst_file: str =
None):
+ def move_file(self, src_file: str, dst_folder: Any, dst_file: str = None):
"""Move a file to another folder
Parameters
@@ -133,7 +133,7 @@ class MSCDirectory(object):
os.rename(src_path, dst_path)
return dst_path
- def copy_file(self, src_file: str, dst_folder: object, dst_file: str =
None):
+ def copy_file(self, src_file: str, dst_folder: Any, dst_file: str = None):
"""Copy a file to another folder
Parameters
@@ -157,7 +157,7 @@ class MSCDirectory(object):
shutil.copy2(src_path, dst_path)
return dst_path
- def create_dir(self, name: str, keep_history: bool = True, cleanup: bool =
False) -> object:
+ def create_dir(self, name: str, keep_history: bool = True, cleanup: bool =
False) -> Any:
"""Add a dir under the folder
Parameters
diff --git a/python/tvm/contrib/msc/core/utils/info.py
b/python/tvm/contrib/msc/core/utils/info.py
index 894447b169..6053d8ddc8 100644
--- a/python/tvm/contrib/msc/core/utils/info.py
+++ b/python/tvm/contrib/msc/core/utils/info.py
@@ -19,7 +19,7 @@
import os
import json
import copy
-from typing import List, Tuple, Dict
+from typing import List, Tuple, Dict, Any
from distutils.version import LooseVersion
import numpy as np
@@ -36,13 +36,13 @@ class MSCArray(object):
The data object.
"""
- def __init__(self, data: object):
+ def __init__(self, data: Any):
self._type, self._data = self._analysis(data)
def __str__(self):
return "<{}>{}".format(self._type, self.abstract())
- def _analysis(self, data: object) -> Tuple[str, np.ndarray]:
+ def _analysis(self, data: Any) -> Tuple[str, np.ndarray]:
if isinstance(data, np.ndarray):
return "np", data
if isinstance(data, tvm.runtime.NDArray):
@@ -76,7 +76,7 @@ class MSCArray(object):
return self._data
-def cast_array(data: object):
+def cast_array(data: Any):
"""Cast array like object to np.ndarray
Parameters
diff --git a/python/tvm/contrib/msc/framework/tensorflow/frontend/translate.py
b/python/tvm/contrib/msc/framework/tensorflow/frontend/translate.py
index dc97a315d0..dab19ca81f 100644
--- a/python/tvm/contrib/msc/framework/tensorflow/frontend/translate.py
+++ b/python/tvm/contrib/msc/framework/tensorflow/frontend/translate.py
@@ -25,6 +25,7 @@ from tvm.contrib.msc.core import transform as msc_transform
from tvm.contrib.msc.core.frontend import from_relax
from tvm.contrib.msc.core.codegen import relay_to_relax
from tvm.contrib.msc.framework.tensorflow import tf_v1
+from tvm.contrib.msc.core import utils as msc_utils
def from_tensorflow(
@@ -75,7 +76,7 @@ def from_tensorflow(
relax_mod = relay_to_relax(relay_mod, params, trans_config, build_config,
opt_config)
if not as_msc:
return relax_mod, params
- build_config = build_config or {}
+ build_config = msc_utils.copy_dict(build_config)
build_config["use_var_name"] = True
graph, weights = from_relax(relax_mod, trans_config=trans_config,
build_config=build_config)
return graph, weights
diff --git a/python/tvm/contrib/msc/framework/tensorflow/runtime/__init__.py
b/python/tvm/contrib/msc/framework/tensorflow/runtime/__init__.py
new file mode 100644
index 0000000000..de9dea8e4b
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/tensorflow/runtime/__init__.py
@@ -0,0 +1,19 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.framework.tensorflow.runtime"""
+
+from .runner import *
diff --git a/python/tvm/contrib/msc/framework/tensorflow/runtime/runner.py
b/python/tvm/contrib/msc/framework/tensorflow/runtime/runner.py
new file mode 100644
index 0000000000..e2c2e919ff
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/tensorflow/runtime/runner.py
@@ -0,0 +1,217 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=not-context-manager
+"""tvm.contrib.msc.framework.tensorflow.runtime.runner"""
+
+import time
+from typing import Dict, List, Union, Any
+import numpy as np
+
+from tensorflow.python.client import device_lib
+from tensorflow.python.ops import variables
+
+from tvm.contrib.msc.core.runtime import ModelRunner
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.framework.tensorflow.codegen import to_tensorflow
+from tvm.contrib.msc.framework.tensorflow import tf_v1
+
+
+class WrapSession(tf_v1.Session):
+ """Wrapped session for MSC"""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._inputs, self._outputs = None, None
+
+ def set_bindings(self, inputs: List[Dict[str, str]], outputs:
List[Dict[str, str]]):
+ """Set inputs and outputs for session
+
+ Parameters
+ -------
+ inputs: list
+ The inputs info of the model.
+ outputs: list
+ The outputs info of the model.
+ """
+
+ self._inputs = inputs
+ self._outputs = outputs
+
+ def run(self, fetches, *args, **kwargs): # pylint:
disable=useless-parent-delegation
+ return super().run(fetches, *args, **kwargs)
+
+
+class TensorflowRunner(ModelRunner):
+ """Runner of Tensorflow"""
+
+ def setup(self):
+ """Setup the runner"""
+
+ super().setup()
+ self._tf_graph = None
+ self._tf_outputs = None
+ self._session = None
+
+ def destory(self):
+ """Destory runner"""
+
+ self._session.close()
+ del self._tf_graph
+ del self._tf_outputs
+ del self._session
+ super().destory()
+
+ def _generate_model(self) -> Any:
+ """Codegen the model according to framework
+
+ Returns
+ -------
+ model: Any
+ The runnable model
+ """
+
+ if self._tf_graph:
+ del self._tf_graph
+ self._tf_graph = tf_v1.Graph()
+ with self._tf_graph.as_default():
+ self._tf_outputs = super()._generate_model()
+ return self._tf_graph
+
+ def _to_runnable(self, model: Any, device: str, is_training: bool) -> Any:
+ """Build runnable object
+
+ Parameters
+ -------
+ model: Any
+ The meta model.
+ device: str
+ The device for place model
+ is_training: bool
+ Whether to load model for training
+
+ Returns
+ -------
+ runnable: Any
+ The runnable
+ """
+
+ if self._session:
+ self._session.close()
+ del self._session
+ self._session = WrapSession(graph=self._tf_graph)
+ self._session.set_bindings(self.get_inputs(), self.get_outputs())
+ with self._tf_graph.as_default():
+ self._session.run(variables.global_variables_initializer())
+ return self._session
+
+ def _call_runnable(
+ self, runnable: WrapSession, inputs: Dict[str, np.ndarray], device: str
+ ) -> Union[List[np.ndarray], Dict[str, np.ndarray]]:
+ """Call the runnable to get outputs
+
+ Parameters
+ -------
+ runnable: WrapSession
+ The wrapped session.
+ inputs: dict<str, data>
+ The inputs in dict.
+ device: str
+ The device.
+
+ Returns
+ -------
+ outputs: list<data> or dict<str, data>
+ The outputs in list or dict.
+ """
+
+ feed_dict = {i["name"] + ":0": inputs[i["name"]] for i in
self.get_inputs()}
+ return runnable.run(self._tf_outputs, feed_dict)
+
+ def _device_enabled(self, device: str) -> bool:
+ """Check if the device is enabled
+
+ Returns
+ -------
+ enabled: bool
+ Whether the device is enabled.
+ """
+
+ if device == "cpu":
+ return True
+ if device.startswith("cuda"):
+ device_protos = device_lib.list_local_devices()
+ return any(dev.device_type == "GPU" for dev in device_protos)
+ return False
+
+ @property
+ def codegen_func(self):
+ return to_tensorflow
+
+ @property
+ def framework(self):
+ return MSCFramework.TENSORFLOW
+
+ @classmethod
+ def run_native(
+ cls,
+ model: tf_v1.GraphDef,
+ inputs: Dict[str, np.ndarray],
+ input_names: List[str],
+ output_names: List[str],
+ warm_up: int = 10,
+ repeat: int = 0,
+ ) -> Dict[str, np.ndarray]:
+ """Run the datas and get outputs
+
+ Parameters
+ -------
+ model: tf_v1.GraphDef
+ The graph def.
+ inputs: dict<str, data>
+ The inputs in dict.
+ input_names: list<str>
+ The input names.
+ output_names: list<str>
+ The outut names.
+ warm_up: int
+ The warm_up num for profile.
+ repeat: int
+ The repeat num for profile.
+
+
+ Returns
+ -------
+ outputs: dict<str, np.array>
+ The outputs in dict.
+ """
+
+ feed_dict = {i_name + ":0": inputs[i_name] for i_name in input_names}
+ with tf_v1.Graph().as_default():
+ tf_v1.import_graph_def(model, name="")
+ with tf_v1.Session() as sess:
+ if repeat > 0:
+ for _ in range(warm_up):
+ outputs = sess.run(output_names, feed_dict)
+ start = time.time()
+ for _ in range(repeat):
+ outputs = sess.run(output_names, feed_dict)
+ avg_time = (time.time() - start) * 1000 / repeat
+ else:
+ outputs = sess.run(output_names, feed_dict)
+ avg_time = -1
+ outputs = dict(zip(output_names, outputs))
+ return outputs, avg_time
diff --git a/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py
b/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py
index be233d9465..5539e614bf 100644
--- a/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py
+++ b/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py
@@ -24,8 +24,8 @@ import numpy as np
import tvm
from tvm.contrib.msc.core.ir import MSCGraph
from tvm.contrib.msc.core.codegen import CodeGen
-from tvm.contrib.msc.core import utils as msc_utils
from tvm.contrib.msc.core.utils import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
from tvm.contrib.msc.framework.tensorrt import _ffi_api
from .sources import get_trt_sources
from .utils import write_weight
@@ -62,7 +62,7 @@ def to_sub_tensorrt(
The engine file.
"""
- codegen_config = codegen_config or {}
+ codegen_config = msc_utils.copy_dict(codegen_config)
codegen_config["version"] = msc_utils.get_version(MSCFramework.TENSORRT)
if "tensorrt_root" not in codegen_config:
codegen_config["tensorrt_root"] = _ffi_api.GetTensorRTRoot()
diff --git a/python/tvm/contrib/msc/framework/tensorrt/runtime/__init__.py
b/python/tvm/contrib/msc/framework/tensorrt/runtime/__init__.py
new file mode 100644
index 0000000000..56203292e4
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/tensorrt/runtime/__init__.py
@@ -0,0 +1,19 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.framework.tensorrt.runtime"""
+
+from .runner import *
diff --git a/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py
b/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py
new file mode 100644
index 0000000000..88c45c786b
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py
@@ -0,0 +1,45 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.framework.tensorrt.runtime.runner"""
+
+from tvm.contrib.msc.core.runtime import BYOCRunner
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.framework.tensorrt.frontend import partition_for_tensorrt
+from tvm.contrib.msc.framework.tensorrt.codegen import to_tensorrt
+
+
+class TensorRTRunner(BYOCRunner):
+ """Runner of tensorrt"""
+
+ def setup(self):
+ """Setup the runner"""
+
+ super().setup()
+ if not self._device.startswith("cuda"):
+ self._device = "cuda"
+
+ @property
+ def codegen_func(self):
+ return to_tensorrt
+
+ @property
+ def partition_func(self):
+ return partition_for_tensorrt
+
+ @property
+ def framework(self):
+ return MSCFramework.TENSORRT
diff --git a/python/tvm/contrib/msc/framework/torch/runtime/__init__.py
b/python/tvm/contrib/msc/framework/torch/runtime/__init__.py
new file mode 100644
index 0000000000..83a1830b29
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/torch/runtime/__init__.py
@@ -0,0 +1,19 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.framework.torch.runtime"""
+
+from .runner import *
diff --git a/python/tvm/contrib/msc/framework/torch/runtime/runner.py
b/python/tvm/contrib/msc/framework/torch/runtime/runner.py
new file mode 100644
index 0000000000..a4f65b4fe1
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/torch/runtime/runner.py
@@ -0,0 +1,197 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.framework.torch.runtime.runner"""
+
+import time
+from typing import Dict, List, Union, Tuple, Any
+import numpy as np
+
+import torch
+import tvm
+from tvm.contrib.msc.core.runtime import ModelRunner
+from tvm.contrib.msc.core.ir import MSCGraph
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.framework.torch.codegen import to_torch
+from tvm.contrib.msc.framework.torch.frontend import set_weight_alias
+from tvm.contrib.msc.core import utils as msc_utils
+
+
+class TorchRunner(ModelRunner):
+ """Runner of Torch"""
+
+ def _translate(self) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]:
+ """Translate IRModule to MSCgraphs
+
+ Returns
+ -------
+ graph_list: list<MSCGraph>
+ The translated graphs
+ weights_list: list<dict<str, tvm.nd.array>>
+ The translated weights
+ """
+ graphs, weights = super()._translate()
+ return [set_weight_alias(graphs[0])], weights
+
+ def _to_runnable(self, model: Any, device: str, is_training: bool) -> Any:
+ """Build runnable object
+
+ Parameters
+ -------
+ model: Any
+ The meta model.
+ device: str
+ The device for place model
+ is_training: bool
+ Whether to load model for training
+
+ Returns
+ -------
+ runnable: Any
+ The runnable
+ """
+
+ if device == "cpu":
+ pass
+ elif device.startswith("cuda"):
+ model = model.to(torch.device(device))
+ else:
+ raise NotImplementedError("Unsupported device " + str(device))
+ if is_training:
+ model = model.train()
+ else:
+ model = model.eval()
+ return model
+
+ def _call_runnable(
+ self, runnable: torch.nn.Module, inputs: Dict[str, np.ndarray],
device: str
+ ) -> Union[List[np.ndarray], Dict[str, np.ndarray]]:
+ """Call the runnable to get outputs
+
+ Parameters
+ -------
+ runnable: torch.nn.Module
+ The runnable model.
+ inputs: dict<str, data>
+ The inputs in dict.
+ device: str
+ The device.
+
+ Returns
+ -------
+ outputs: list<torch.Tensor>
+ The outputs in list.
+ """
+
+ model_inputs = self.get_inputs()
+ parameters = list(runnable.parameters())
+ if parameters:
+ in_dev = parameters[0].device
+ elif device == "cpu":
+ in_dev = torch.device(device)
+ elif device.startswith("cuda"):
+ in_dev = torch.device(device)
+ else:
+ raise NotImplementedError("Unsupported device " + str(device))
+ torch_inputs = [torch.from_numpy(inputs[i["name"]]).to(in_dev) for i
in model_inputs]
+ return runnable(*torch_inputs)
+
+ def _device_enabled(self, device: str) -> bool:
+ """Check if the device is enabled
+
+ Returns
+ -------
+ enabled: bool
+ Whether the device is enabled.
+ """
+
+ if device == "cpu":
+ return True
+ if device.startswith("cuda"):
+ return torch.cuda.is_available()
+ return False
+
+ @property
+ def codegen_func(self):
+ return to_torch
+
+ @property
+ def framework(self):
+ return MSCFramework.TORCH
+
+ @classmethod
+ def run_native(
+ cls,
+ model: torch.nn.Module,
+ inputs: Dict[str, np.ndarray],
+ input_names: List[str],
+ output_names: List[str],
+ warm_up: int = 10,
+ repeat: int = 0,
+ ) -> Dict[str, np.ndarray]:
+ """Run the datas and get outputs
+
+ Parameters
+ -------
+ model: torch.nn.Module
+ The runnable model.
+ inputs: dict<str, data>
+ The inputs in dict.
+ input_names: list<str>
+ The input names.
+ output_names: list<str>
+ The outut names.
+ warm_up: int
+ The warm_up num for profile.
+ repeat: int
+ The repeat num for profile.
+
+ Returns
+ -------
+ outputs: dict<str, np.array>
+ The outputs in dict.
+ """
+
+ parameters = list(model.parameters())
+ if parameters:
+ device = parameters[0].device
+ else:
+ device = torch.device("cpu")
+
+ def _run_once():
+ torch_inputs = [torch.from_numpy(inputs[i_name]).to(device) for
i_name in input_names]
+ return model(*torch_inputs)
+
+ if repeat > 0:
+ for _ in range(warm_up):
+ _run_once()
+ start = time.time()
+ for _ in range(repeat):
+ outputs = _run_once()
+ avg_time = (time.time() - start) * 1000 / repeat
+ else:
+ outputs = _run_once()
+ avg_time = -1
+ if isinstance(outputs, torch.Tensor):
+ assert len(output_names) == 1, "Expect 1 outputs, get " +
str(output_names)
+ return {output_names[0]: msc_utils.cast_array(outputs)}, avg_time
+ assert len(output_names) == len(outputs), "Outputs mismatch, {} with
{}".format(
+ output_names, len(outputs)
+ )
+ outputs = {
+ o_name: msc_utils.cast_array(o_data) for o_name, o_data in
zip(output_names, outputs)
+ }
+ return outputs, avg_time
diff --git a/python/tvm/contrib/msc/framework/tvm/runtime/runner.py
b/python/tvm/contrib/msc/framework/tvm/runtime/runner.py
index 90ba8e4cce..c5240ca229 100644
--- a/python/tvm/contrib/msc/framework/tvm/runtime/runner.py
+++ b/python/tvm/contrib/msc/framework/tvm/runtime/runner.py
@@ -16,7 +16,7 @@
# under the License.
"""tvm.contrib.msc.framework.runtime.tvm.runner"""
-from typing import Dict, List, Union
+from typing import Dict, List, Union, Any
import numpy as np
import tvm
@@ -28,12 +28,12 @@ from tvm.contrib.msc.framework.tvm.codegen import to_relax
class TVMRunner(ModelRunner):
"""Runner of Relax"""
- def _to_runnable(self, model: object, device: str, is_training: bool) ->
object:
+ def _to_runnable(self, model: Any, device: str, is_training: bool) -> Any:
"""Build runnable object
Parameters
-------
- model: object
+ model: Any
The meta model.
device: str
The device for place model
@@ -42,7 +42,7 @@ class TVMRunner(ModelRunner):
Returns
-------
- runnable: object
+ runnable: Any
The runnable
"""
diff --git a/src/contrib/msc/core/codegen/codegen_utils.cc
b/src/contrib/msc/core/codegen/codegen_utils.cc
index bdc542994d..0c751a8fbf 100644
--- a/src/contrib/msc/core/codegen/codegen_utils.cc
+++ b/src/contrib/msc/core/codegen/codegen_utils.cc
@@ -51,7 +51,7 @@ const String CodeGenUtils::IdxInput(const MSCJoint& node,
const String& prefix,
const String CodeGenUtils::IdxWeight(const MSCJoint& node, const String& wtype,
const String& suffix) {
- return wtype + std::to_string(node->index) + suffix;
+ return wtype + "_" + std::to_string(node->index) + suffix;
}
const String CodeGenUtils::CommentNode(const MSCJoint& node, const String&
prefix) {
diff --git a/src/contrib/msc/core/ir/graph.cc b/src/contrib/msc/core/ir/graph.cc
index 68f14376b9..abd581c797 100644
--- a/src/contrib/msc/core/ir/graph.cc
+++ b/src/contrib/msc/core/ir/graph.cc
@@ -578,6 +578,9 @@ const String MSCGraphNode::ToPrototxt() const {
for (const auto& pair : node->weights) {
param.Set("param_" + pair.first, pair.second);
}
+ for (const auto& pair : node->attrs) {
+ param.Set(pair.first, pair.second);
+ }
layer.push_back(std::make_pair("layer_param",
PrototxtPrinter::ToDictDoc(param)));
// Append the layer Map
printer.Append(Map<String, ObjectRef>{{"layer",
PrototxtPrinter::ToDictDoc(layer)}});
diff --git a/src/contrib/msc/core/ir/graph_builder.cc
b/src/contrib/msc/core/ir/graph_builder.cc
index f3fb383f8f..dab4ae813e 100644
--- a/src/contrib/msc/core/ir/graph_builder.cc
+++ b/src/contrib/msc/core/ir/graph_builder.cc
@@ -205,7 +205,7 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr,
const Optional<Expr>
String node_name = name.size() > 0 ? name : SpanUtils::GetAttr(expr->span,
"name");
const auto& shared_ref = SpanUtils::GetAttr(expr->span, "shared_ref");
- // Get optype
+ // Get optype and node_name
String optype;
if (expr->IsInstance<relax::VarNode>()) {
if (func_params_.count(expr) &&
func_params_[expr]->IsInstance<relax::ConstantNode>()) {
@@ -227,9 +227,18 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr&
expr, const Optional<Expr>
optype = StringUtils::Replace(op_node->name, "relax.", "");
} else if (const auto* v_node = call_node->op.as<GlobalVarNode>()) {
const auto& func =
Downcast<relax::Function>(ref_module_->Lookup(v_node->name_hint));
- const auto& name_opt =
func->GetAttr<runtime::String>(relax::attr::kComposite);
- ICHECK(name_opt.defined()) << "Unexpected global func without composite";
- optype = name_opt.value();
+ const auto& byoc_name_opt = func->GetAttr<runtime::String>("byoc_name");
+ if (byoc_name_opt.defined()) {
+ node_name = byoc_name_opt.value();
+ }
+ const auto& codegen_opt =
func->GetAttr<runtime::String>(relax::attr::kCodegen);
+ if (codegen_opt.defined()) {
+ optype = codegen_opt.value();
+ } else {
+ const auto& name_opt =
func->GetAttr<runtime::String>(relax::attr::kComposite);
+ ICHECK(name_opt.defined()) << "Unexpected global func without
composite";
+ optype = name_opt.value();
+ }
} else if (call_node->op->IsInstance<relax::VarNode>()) {
ICHECK(target_funcs_.count(call_node->op)) << "Can not find target func:
" << call_node->op;
const auto& func = target_funcs_[call_node->op];
@@ -251,7 +260,10 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr&
expr, const Optional<Expr>
if (const auto* call_node = expr.as<relax::CallNode>()) {
if (const auto* v_node = call_node->op.as<GlobalVarNode>()) {
const auto& func =
Downcast<relax::Function>(ref_module_->Lookup(v_node->name_hint));
- attrs = RelaxFuncAttrGetter().GetAttrs(func);
+ const auto& byoc_name_opt = func->GetAttr<runtime::String>("byoc_name");
+ if (!byoc_name_opt.defined()) {
+ attrs = RelaxFuncAttrGetter().GetAttrs(func);
+ }
} else if (call_node->op->IsInstance<relax::VarNode>()) {
ICHECK(target_funcs_.count(call_node->op)) << "Can not find target func:
" << call_node->op;
attrs = RelaxFuncAttrGetter().GetAttrs(target_funcs_[call_node->op]);
diff --git a/src/contrib/msc/core/printer/prototxt_printer.cc
b/src/contrib/msc/core/printer/prototxt_printer.cc
index e15894a272..7e96c657a7 100644
--- a/src/contrib/msc/core/printer/prototxt_printer.cc
+++ b/src/contrib/msc/core/printer/prototxt_printer.cc
@@ -73,12 +73,12 @@ DictDoc PrototxtPrinter::ToDictDoc(const
std::vector<std::pair<String, ObjectRef
void PrototxtPrinter::Append(const Map<String, ObjectRef>& dict) {
DictDoc doc = ToDictDoc(dict);
- PrintDoc(doc);
+ PrintDoc(doc, false);
}
void PrototxtPrinter::Append(const std::vector<std::pair<String, ObjectRef>>&
dict) {
DictDoc doc = ToDictDoc(dict);
- PrintDoc(doc);
+ PrintDoc(doc, false);
}
void PrototxtPrinter::AppendPair(const String& key, const ObjectRef& value) {
@@ -97,7 +97,7 @@ void PrototxtPrinter::PrintTypedDoc(const DictDoc& doc) {
if (doc->values[i].as<DictDocNode>()) {
output_ << " {";
IncreaseIndent();
- PrintDoc(doc->values[i]);
+ PrintDoc(doc->values[i], false);
DecreaseIndent();
NewLine() << "}";
} else {
diff --git a/src/contrib/msc/core/transform/set_byoc_attrs.cc
b/src/contrib/msc/core/transform/set_byoc_attrs.cc
new file mode 100644
index 0000000000..4fa8ab584e
--- /dev/null
+++ b/src/contrib/msc/core/transform/set_byoc_attrs.cc
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/contrib/msc/core/transform/set_byoc_attrs.cc
+ * \brief Pass for setting BYOC attributes on codegen functions.
+ */
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+
+#include "../../../../relax/transform/utils.h"
+#include "../utils.h"
+
+namespace tvm {
+namespace relax {
+
+using namespace tvm::contrib::msc;
+
+/*!
+ * \brief Set byoc_name attributes on functions marked for the target codegen
+ */
+class ByocNameSetter : public ExprMutator {
+ public:
+ explicit ByocNameSetter(IRModule ctx_module, const String& target, const
String& entry_name)
+ : ExprMutator(ctx_module) {
+ mod_ = ctx_module;
+ target_ = target;
+ entry_name_ = entry_name;
+ }
+
+ IRModule SetAttrs() {
+ GlobalVar main_var;
+ size_t func_cnt = 0;
+ for (const auto& [gv, func] : mod_->functions) {
+ if (gv->name_hint == entry_name_) {
+ main_var = gv;
+ } else {
+ const auto& name_opt = func->GetAttr<runtime::String>(attr::kCodegen);
+ if (name_opt.defined() && name_opt.value() == target_) {
+ const auto& new_func = WithAttr(Downcast<Function>(func),
"byoc_name",
+ target_ + "_" +
std::to_string(func_cnt));
+ builder_->UpdateFunction(gv, new_func);
+ func_cnt += 1;
+ }
+ }
+ }
+ return builder_->GetContextIRModule();
+ }
+
+ private:
+ IRModule mod_;
+ String target_;
+ String entry_name_;
+ Map<Function, Function> new_funcs_;
+};
+
+IRModule SetBYOCAttrs(IRModule mod, const String& target, const String&
entry_name) {
+ return ByocNameSetter(mod, target, entry_name).SetAttrs();
+}
+
+namespace transform {
+
+Pass SetBYOCAttrs(const String& target, const String& entry_name) {
+ runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+ [=](IRModule m, PassContext pc) { return relax::SetBYOCAttrs(m, target,
entry_name); };
+ return CreateModulePass(pass_func, 0, "SetBYOCAttrs", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.SetBYOCAttrs").set_body_typed(SetBYOCAttrs);
+
+} // namespace transform
+} // namespace relax
+} // namespace tvm
diff --git a/src/contrib/msc/framework/tensorflow/codegen.cc
b/src/contrib/msc/framework/tensorflow/codegen.cc
index e4b4feb8ca..30f06b43e7 100644
--- a/src/contrib/msc/framework/tensorflow/codegen.cc
+++ b/src/contrib/msc/framework/tensorflow/codegen.cc
@@ -64,6 +64,9 @@ void TensorflowCodeGen::CodeGenGraph() {
stack_.comment("Define the weights");
for (const auto& n : graph()->node_names) {
const auto& node = graph()->FindNode(n);
+ if (node->optype == "nn.batch_norm") {
+ continue;
+ }
for (const auto& pair : node->weights) {
stack_.func_call("get_variable", IdxWeightBase(node, pair.first))
.call_arg(DocUtils::ToStrDoc(pair.second->name))
diff --git a/src/contrib/msc/framework/tensorflow/tf_v1_opcode.cc
b/src/contrib/msc/framework/tensorflow/tf_v1_opcode.cc
index d17a326424..56dba21ac8 100644
--- a/src/contrib/msc/framework/tensorflow/tf_v1_opcode.cc
+++ b/src/contrib/msc/framework/tensorflow/tf_v1_opcode.cc
@@ -157,6 +157,33 @@ class TFV1AxisCodeGen : public TFV1OpCode {
String attr_name_;
};
+class TFV1BatchnormCodeGen : public TFV1OpCode {
+ TFV1_OP_CODEGEN_METHODS(TFV1BatchnormCodeGen)
+
+ protected:
+ void CodeGenBuild() final {
+ stack_.op_call()
+ .op_input_arg()
+ .op_arg<bool>("scale")
+ .op_arg<bool>("center")
+ .op_arg<float>("momentum")
+ .op_arg<float>("epsilon")
+ .call_arg("tf_v1.constant_initializer(weights[\"" +
node()->WeightAt("gamma")->name +
+ "\"].asnumpy())",
+ "gamma_initializer")
+ .call_arg("tf_v1.constant_initializer(weights[\"" +
node()->WeightAt("beta")->name +
+ "\"].asnumpy())",
+ "beta_initializer")
+ .call_arg("tf_v1.constant_initializer(weights[\"" +
node()->WeightAt("mean")->name +
+ "\"].asnumpy())",
+ "moving_mean_initializer")
+ .call_arg("tf_v1.constant_initializer(weights[\"" +
node()->WeightAt("var")->name +
+ "\"].asnumpy())",
+ "moving_variance_initializer")
+ .op_name_arg();
+ }
+};
+
class TFV1BroadcastToCodeGen : public TFV1OpCode {
TFV1_OP_CODEGEN_METHODS(TFV1BroadcastToCodeGen)
@@ -166,6 +193,19 @@ class TFV1BroadcastToCodeGen : public TFV1OpCode {
}
};
+class TFV1ClipCodeGen : public TFV1OpCode {
+ TFV1_OP_CODEGEN_METHODS(TFV1ClipCodeGen)
+
+ protected:
+ void CodeGenBuild() final {
+ stack_.op_call()
+ .op_input_arg()
+ .op_arg<float>("min", "clip_value_min")
+ .op_arg<float>("max", "clip_value_max")
+ .op_name_arg();
+ }
+};
+
class TFV1ConcatCodeGen : public TFV1OpCode {
TFV1_OP_CODEGEN_METHODS(TFV1ConcatCodeGen)
@@ -542,6 +582,7 @@ const std::shared_ptr<std::unordered_map<String,
std::shared_ptr<TFV1OpCode>>> G
map->emplace("argmin",
std::make_shared<TFV1ArgMaxMinCodeGen>("tf_v1.argmin"));
map->emplace("astype", std::make_shared<TFV1AstypeCodeGen>("tf_v1.cast"));
map->emplace("broadcast_to",
std::make_shared<TFV1BroadcastToCodeGen>("tf_v1.broadcast_to"));
+ map->emplace("clip",
std::make_shared<TFV1ClipCodeGen>("tf_v1.clip_by_value"));
map->emplace("concat",
std::make_shared<TFV1ConcatCodeGen>("ops.array_ops.concat_v2"));
map->emplace("concatenate",
std::make_shared<TFV1ConcatCodeGen>("ops.array_ops.concat_v2"));
map->emplace("einsum", std::make_shared<TFV1EinsumCodeGen>("tf_v1.einsum"));
@@ -555,6 +596,8 @@ const std::shared_ptr<std::unordered_map<String,
std::shared_ptr<TFV1OpCode>>> G
// nn ops
map->emplace("nn.avg_pool2d",
std::make_shared<TFV1Pool2dCodeGen>("ops.nn_ops.pool"));
+ map->emplace("nn.batch_norm",
+
std::make_shared<TFV1BatchnormCodeGen>("tf_v1.layers.batch_normalization"));
map->emplace("nn.conv2d",
std::make_shared<TFV1ConvCodeGen>("ops.nn_ops.conv2d", false));
map->emplace("nn.max_pool2d",
std::make_shared<TFV1Pool2dCodeGen>("ops.nn_ops.pool"));
map->emplace("nn.pad", std::make_shared<TFV1PadCodeGen>("tf_v1.pad"));
@@ -569,7 +612,8 @@ const std::shared_ptr<std::unordered_map<String,
std::shared_ptr<TFV1OpCode>>> G
map->emplace("tuple", std::make_shared<TFV1TupleCodeGen>("tuple"));
// msc ops
- map->emplace("msc.conv2d",
std::make_shared<TFV1ConvCodeGen>("ops.nn_ops.conv2d", true));
+ map->emplace("msc.conv2d",
std::make_shared<TFV1ConvCodeGen>("ops.nn_ops.conv2d", false));
+ map->emplace("msc.conv2d_bias",
std::make_shared<TFV1ConvCodeGen>("ops.nn_ops.conv2d", true));
return map;
}
diff --git a/src/contrib/msc/framework/tensorrt/codegen.cc
b/src/contrib/msc/framework/tensorrt/codegen.cc
index b8b2335da1..a59697377a 100644
--- a/src/contrib/msc/framework/tensorrt/codegen.cc
+++ b/src/contrib/msc/framework/tensorrt/codegen.cc
@@ -494,9 +494,12 @@ Array<runtime::Module> MSCTensorRTCompiler(Array<Function>
functions,
Array<runtime::Module> compiled_functions;
for (const auto& func : functions) {
VLOG(1) << "MSC.TensorRT partition:" << std::endl << func;
+ const auto& byoc_name_opt = func->GetAttr<runtime::String>("byoc_name");
+ ICHECK(byoc_name_opt.defined()) << "Can not find byoc_name from attrs";
+ const auto& byoc_name = byoc_name_opt.value();
std::string func_name = GetExtSymbol(func);
- ICHECK(target_option.count(func_name)) << "Can not find target option for
" << func_name;
- const auto& options = Downcast<String>(target_option[func_name]);
+ ICHECK(target_option.count(byoc_name)) << "Can not find target option for
" << byoc_name;
+ const auto& options = Downcast<String>(target_option[byoc_name]);
MSCJSONSerializer serializer(constant_names, options);
serializer.serialize(func);
std::string graph_json = serializer.GetJSON();
diff --git a/src/contrib/msc/framework/tensorrt/codegen_utils.h
b/src/contrib/msc/framework/tensorrt/codegen_utils.h
index 8249444d9d..d598396f6f 100644
--- a/src/contrib/msc/framework/tensorrt/codegen_utils.h
+++ b/src/contrib/msc/framework/tensorrt/codegen_utils.h
@@ -46,7 +46,7 @@ class TensorRTCodeGenHelper : public BaseCodeGenHelper {
return "*" + IdxNodeBase(pair.first, prefix, suffix);
}
if (pair.first->optype == "tuple" || pair.first->optype == "get_item") {
- return IdxNodeBase(pair.first, prefix, suffix);
+ return "*" + IdxNodeBase(pair.first, prefix, suffix);
}
return "*" + IdxOutputBase(pair.first, prefix, pair.second, suffix);
}
diff --git a/tests/python/contrib/test_msc/test_runner.py
b/tests/python/contrib/test_msc/test_runner.py
index 4653a2225b..3f2d0d0c90 100644
--- a/tests/python/contrib/test_msc/test_runner.py
+++ b/tests/python/contrib/test_msc/test_runner.py
@@ -22,10 +22,16 @@ import numpy as np
import torch
from torch import fx
+from tvm.contrib.msc.framework.tensorflow import tf_v1
import tvm.testing
+import tvm.relay.testing.tf as tf_testing
from tvm.relax.frontend.torch import from_fx
from tvm.contrib.msc.framework.tvm.runtime import TVMRunner
+from tvm.contrib.msc.framework.torch.runtime import TorchRunner
+from tvm.contrib.msc.framework.tensorrt.runtime import TensorRTRunner
+from tvm.contrib.msc.framework.tensorflow.frontend import from_tensorflow
+from tvm.contrib.msc.framework.tensorflow.runtime import TensorflowRunner
from tvm.contrib.msc.core import utils as msc_utils
requires_tensorrt = pytest.mark.skipif(
@@ -51,11 +57,32 @@ def _get_torch_model(name, is_training=False):
return None
+def _get_tf_graph():
+ """Get tensorflow graphdef"""
+
+ try:
+ tf_graph = tf_v1.Graph()
+ with tf_graph.as_default():
+ graph_def = tf_testing.get_workload(
+
"https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz",
+ "mobilenet_v2_1.4_224_frozen.pb",
+ )
+ # Call the utility to import the graph definition into default
graph.
+ graph_def = tf_testing.ProcessGraphDefParam(graph_def)
+ return tf_graph, graph_def
+ except: # pylint: disable=bare-except
+ print("please install tensorflow package")
+ return None, None
+
+
def _test_from_torch(runner_cls, device, is_training=False, atol=1e-3,
rtol=1e-3):
"""Test runner from torch model"""
+
torch_model = _get_torch_model("resnet50", is_training)
if torch_model:
workspace = msc_utils.set_workspace()
+ log_path = workspace.relpath("MSC_LOG", keep_history=False)
+ msc_utils.set_global_logger("info", log_path)
input_info = [([1, 3, 224, 224], "float32")]
datas = [np.random.rand(*i[0]).astype(i[1]) for i in input_info]
torch_datas = [torch.from_numpy(d) for d in datas]
@@ -85,5 +112,49 @@ def test_tvm_runner_gpu():
_test_from_torch(TVMRunner, "cuda", is_training=True)
+def test_torch_runner_cpu():
+ """Test runner for torch on cpu"""
+
+ _test_from_torch(TorchRunner, "cpu")
+
+
[email protected]_gpu
+def test_torch_runner_gpu():
+ """Test runner for torch on cuda"""
+
+ _test_from_torch(TorchRunner, "cuda", atol=1e-2, rtol=1e-2)
+
+
+@requires_tensorrt
+def test_tensorrt_runner():
+ """Test runner for tensorrt"""
+
+ _test_from_torch(TensorRTRunner, "cuda", atol=1e-2, rtol=1e-2)
+
+
+def test_tensorflow_runner():
+ """Test runner from tf graph"""
+
+ tf_graph, graph_def = _get_tf_graph()
+ if tf_graph and graph_def:
+ workspace = msc_utils.set_workspace()
+ log_path = workspace.relpath("MSC_LOG", keep_history=False)
+ msc_utils.set_global_logger("info", log_path)
+ data = np.random.uniform(size=(1, 224, 224, 3)).astype("float32")
+ out_name = "MobilenetV2/Predictions/Reshape_1:0"
+ # get golden
+ with tf_v1.Session(graph=tf_graph) as sess:
+ golden = sess.run([out_name], {"input:0": data})
+ # get outputs
+ shape_dict = {"input": data.shape}
+ mod, _ = from_tensorflow(graph_def, shape_dict, [out_name],
as_msc=False)
+ runner = TensorflowRunner(mod)
+ runner.build()
+ outputs = runner.run([data], ret_type="list")
+ for gol_r, out_r in zip(golden, outputs):
+ tvm.testing.assert_allclose(gol_r, out_r, atol=1e-3, rtol=1e-3)
+ workspace.destory()
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/contrib/test_msc/test_translate_relay.py
b/tests/python/contrib/test_msc/test_translate_relay.py
index e4c8dbf3b5..aca0f26890 100644
--- a/tests/python/contrib/test_msc/test_translate_relay.py
+++ b/tests/python/contrib/test_msc/test_translate_relay.py
@@ -27,7 +27,7 @@ from torch.nn import Module
import tvm.testing
from tvm.relax.frontend.torch import from_fx
from tvm.relay.frontend import from_pytorch
-from tvm.contrib.msc.core.ir import translate
+from tvm.contrib.msc.core.frontend import translate
from tvm.contrib.msc.framework.tvm import codegen as tvm_codegen