This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 8edfee8574 [Unity][MSC][M2.4] Add quantizer for quantize model (#16228)
8edfee8574 is described below
commit 8edfee85742da697957e9e33e95bce28727362ac
Author: Archermmt <[email protected]>
AuthorDate: Fri Dec 15 15:59:56 2023 +0800
[Unity][MSC][M2.4] Add quantizer for quantize model (#16228)
* add quantizer
* add quantizer test
---
python/tvm/contrib/msc/core/runtime/runner.py | 8 +
python/tvm/contrib/msc/core/tools/__init__.py | 1 +
.../tvm/tools => core/tools/quantize}/__init__.py | 6 +-
.../tvm/contrib/msc/core/tools/quantize/method.py | 472 +++++++++++++++++++++
.../contrib/msc/core/tools/quantize/quantizer.py | 249 +++++++++++
python/tvm/contrib/msc/core/utils/file.py | 5 +
python/tvm/contrib/msc/core/utils/info.py | 4 +
.../msc/framework/tensorflow/tools/__init__.py | 1 +
.../tools/quantize}/__init__.py | 5 +-
.../tensorflow/tools/quantize/quantizer.py | 55 +++
.../msc/framework/tensorrt/codegen/sources.py | 172 +++++++-
.../msc/framework/tensorrt/runtime/runner.py | 53 +++
.../msc/framework/tensorrt/tools/__init__.py | 1 +
.../tools => tensorrt/tools/quantize}/__init__.py | 6 +-
.../framework/tensorrt/tools/quantize/method.py | 149 +++++++
.../framework/tensorrt/tools/quantize/quantizer.py | 366 ++++++++++++++++
.../contrib/msc/framework/torch/tools/__init__.py | 1 +
.../tools => torch/tools/quantize}/__init__.py | 6 +-
.../msc/framework/torch/tools/quantize/method.py | 237 +++++++++++
.../framework/torch/tools/quantize/quantizer.py | 55 +++
.../contrib/msc/framework/tvm/tools/__init__.py | 1 +
.../tools => tvm/tools/quantize}/__init__.py | 6 +-
.../msc/framework/tvm/tools/quantize/method.py | 204 +++++++++
.../msc/framework/tvm/tools/quantize/quantizer.py | 167 ++++++++
python/tvm/contrib/msc/pipeline/manager.py | 4 +
tests/python/contrib/test_msc/test_tools.py | 44 +-
26 files changed, 2259 insertions(+), 19 deletions(-)
diff --git a/python/tvm/contrib/msc/core/runtime/runner.py
b/python/tvm/contrib/msc/core/runtime/runner.py
index dcf24225fe..5228b06b10 100644
--- a/python/tvm/contrib/msc/core/runtime/runner.py
+++ b/python/tvm/contrib/msc/core/runtime/runner.py
@@ -414,6 +414,14 @@ class BaseRunner(object):
self.run(inputs, ret_type="native")
break
plan = pruner.finalize()
+ elif tool_type == ToolType.QUANTIZER:
+ quantizer = self.get_tool(ToolType.QUANTIZER)
+ while not quantizer.calibrated:
+ assert data_loader, "data_loader should be given to plan prune"
+ for inputs in data_loader():
+ self.run(inputs, ret_type="native")
+ quantizer.calibrate()
+ plan = quantizer.finalize()
else:
plan = self.get_tool(tool_type).finalize()
assert plan, "Failed to create plan for {}".format(tool_type)
diff --git a/python/tvm/contrib/msc/core/tools/__init__.py
b/python/tvm/contrib/msc/core/tools/__init__.py
index 0524e4c823..e97771cf6c 100644
--- a/python/tvm/contrib/msc/core/tools/__init__.py
+++ b/python/tvm/contrib/msc/core/tools/__init__.py
@@ -19,4 +19,5 @@
from .tool import *
from .execute import *
from .prune import *
+from .quantize import *
from .track import *
diff --git a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py
b/python/tvm/contrib/msc/core/tools/quantize/__init__.py
similarity index 89%
copy from python/tvm/contrib/msc/framework/tvm/tools/__init__.py
copy to python/tvm/contrib/msc/core/tools/quantize/__init__.py
index 226ae3102d..1aad17c055 100644
--- a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py
+++ b/python/tvm/contrib/msc/core/tools/quantize/__init__.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""tvm.contrib.msc.framework.tvm.tools"""
+"""tvm.contrib.msc.core.tools.quantize"""
-from .prune import *
-from .track import *
+from .quantizer import *
+from .method import *
diff --git a/python/tvm/contrib/msc/core/tools/quantize/method.py
b/python/tvm/contrib/msc/core/tools/quantize/method.py
new file mode 100644
index 0000000000..9701858267
--- /dev/null
+++ b/python/tvm/contrib/msc/core/tools/quantize/method.py
@@ -0,0 +1,472 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument
+"""tvm.contrib.msc.core.tools.quantize.method"""
+
+from typing import Union, Any
+import numpy as np
+
+from tvm.contrib.msc.core.tools.tool import ToolType, BaseTool
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
+
+
+class QuantizeMethod(object):
+ """Default quantize method"""
+
+ @classmethod
+ def amplify_data(
+ cls, data: np.array, scale: float, min_val: float, max_val: float,
rounding: str = "round"
+ ) -> np.ndarray:
+ """Amplify the data
+
+ Parameters
+ ----------
+ data: np.ndarray
+ The source data.
+ scale: float
+ The scale factor
+ min_val: float
+ The min.
+ max_val: float
+ The max.
+ rounding: str
+ The round method
+
+ Returns
+ -------
+ data: np.ndarray
+ The processed data.
+ """
+
+ if rounding == "null":
+ return np.clip(data * scale, min_val, max_val)
+ if rounding == "floor":
+ return np.clip(np.floor(data * scale), min_val, max_val)
+ if rounding == "ceil":
+ return np.clip(np.ceil(data * scale), min_val, max_val)
+ if rounding == "round":
+ return np.clip(np.round(data * scale), min_val, max_val)
+ if rounding == "trunc":
+ return np.clip(np.trunc(data * scale), min_val, max_val)
+ if rounding == "logic_round":
+ data = np.clip(data * scale, min_val, max_val)
+ negative_ceil = np.where(
+ np.logical_and(data < 0, (data - np.floor(data)) == 0.5),
np.ceil(data), 0
+ )
+ data = np.where(np.logical_and(data < 0, (data - np.floor(data))
== 0.5), 0, data)
+ data = np.where((data - np.floor(data)) >= 0.5, np.ceil(data),
data)
+ data = np.where((data - np.floor(data)) < 0.5, np.floor(data),
data)
+ return data + negative_ceil
+ raise TypeError("Unexpected rounding " + str(rounding))
+
+ @classmethod
+ def get_scale_tensor(
+ cls,
+ data: Any,
+ scale: float,
+ axis: int = -1,
+ epsilon: float = 1.0 / (1 << 24),
+ expand_dims: bool = True,
+ ) -> Union[float, np.ndarray]:
+ """Get the scale tensor
+
+ Parameters
+ ----------
+ quantizer: BaseQuantizer
+ The quantizer
+ data: array_like
+ The source data.
+ name: str
+ The name of the tensor.
+ consumer: str
+ The name of the consumer.
+ scale: float
+ The scale factor
+ axis: int
+ The axis.
+ epsilon: float
+ The epsilon for get scale.
+ expand_dims: bool
+ Whether to expand dims
+
+ Returns
+ -------
+ scale_tensor: np.ndarray
+ The processed tensor.
+ """
+
+ data = msc_utils.cast_array(data)
+ if isinstance(scale, list):
+ scale_tensor = np.array(scale).astype(data.dtype)
+ if expand_dims:
+ scale_shape = [s if idx == axis else 1 for idx, s in
enumerate(data.shape)]
+ scale_tensor = scale_tensor.reshape(scale_shape)
+ if scale_tensor.min() <= epsilon:
+ scale_mask = scale_tensor <= epsilon
+ scale_tensor[scale_mask] = 0
+ elif scale <= epsilon:
+ scale_tensor = 0
+ else:
+ scale_tensor = scale
+ return scale_tensor
+
+ @classmethod
+ def gather_maxmin(
+ cls,
+ quantizer: BaseTool,
+ data: np.ndarray,
+ name: str,
+ consumer: str,
+ plan: dict,
+ nbits: int = 8,
+ ) -> dict:
+ """Gather the data by max/min
+
+ Parameters
+ ----------
+ quantizer: BaseQuantizer
+ The quantizer
+ data: np.ndarray
+ The source data.
+ name: str
+ The name of the tensor.
+ consumer: str
+ The name of the consumer.
+ plan: dict
+ The pre-calibrated plan.
+ nbits: int
+ The number bits for quantize.
+
+ Returns
+ -------
+ plan: dict
+ The plan of the tensor.
+ """
+
+ abs_max_list = plan.get("abs_max_list", [])
+ abs_max_list.append(float(np.abs(data).max()))
+ max_list = plan.get("max_list", [])
+ max_list.append(float(data.max()))
+ min_list = plan.get("min_list", [])
+ min_list.append(float(data.min()))
+ return {
+ "abs_max_list": abs_max_list,
+ "max_list": max_list,
+ "min_list": min_list,
+ "calibrated": False,
+ }
+
+ @classmethod
+ def gather_kl_divergence(
+ cls,
+ quantizer: BaseTool,
+ data: np.ndarray,
+ name: str,
+ consumer: str,
+ plan: dict,
+ nbits: int = 8,
+ bins: int = 4096,
+ ) -> dict:
+ """Gather the data by kl_divergence
+
+ Parameters
+ ----------
+ quantizer: BaseQuantizer
+ The quantizer
+ data: np.ndarray
+ The source data.
+ name: str
+ The name of the tensor.
+ consumer: str
+ The name of the consumer.
+ plan: dict
+ The pre-calibrated plan.
+ nbits: int
+ The number bits for quantize.
+ bins: int
+ The number bins.
+
+ Returns
+ -------
+ plan: dict
+ The plan of the tensor.
+ """
+
+ if not plan or "abs_max" not in plan:
+ return cls.gather_maxmin(quantizer, name, data, plan, nbits)
+ hist, edge = np.histogram(data, bins=bins, range=[-plan["abs_max"],
plan["abs_max"]])
+ hist_list = plan.get("hist_list", [])
+ return {"hist_list": hist_list + [hist], "edge": edge, **plan}
+
+ @classmethod
+ def gather_max_per_channel(
+ cls,
+ quantizer: BaseTool,
+ data: np.ndarray,
+ name: str,
+ consumer: str,
+ plan: dict,
+ nbits: int = 8,
+ channel: str = "O",
+ auto_unsign: bool = False,
+ ) -> dict:
+ """Gather the data by max_per_channel
+
+ Parameters
+ ----------
+ quantizer: BaseQuantizer
+ The quantizer
+ data: np.ndarray
+ The source data.
+ name: str
+ The name of the tensor.
+ consumer: str
+ The name of the consumer.
+ plan: dict
+ The pre-calibrated plan.
+ nbits: int
+ The number bits for quantize.
+ channel: str
+ The channel reference.
+ auto_unsign: bool
+ Whether to use auto unsign.
+
+ Returns
+ -------
+ plan: dict
+ The plan of the tensor.
+ """
+
+ weight = quantizer.find_tensor(name)
+ axis = weight.layout_of(channel)
+ channel_datas = np.split(data, data.shape[axis], axis)
+ channel_max = [float(np.abs(d).max()) for d in channel_datas]
+ sign = data.min() < 0 if auto_unsign else True
+ valid_range = 2 ** (nbits - int(sign)) - 1
+ scale = [valid_range / m for m in channel_max]
+ return {"scale": scale, "sign": sign, "axis": axis, "calibrated": True}
+
+ @classmethod
+ def calibrate_maxmin(
+ cls,
+ quantizer: BaseTool,
+ name: str,
+ consumer: str,
+ plan: dict,
+ nbits: int = 8,
+ auto_unsign: bool = False,
+ ) -> dict:
+ """Calibrate the data by kl_divergence
+
+ Parameters
+ ----------
+ quantizer: BaseQuantizer
+ The quantizer
+ name: str
+ The name of the tensor.
+ consumer: str
+ The name of the consumer.
+ plan: dict
+ The pre-calibrated plan.
+ nbits: int
+ The number bits for quantize.
+ auto_unsign: bool
+ Whether to use auto unsign.
+
+ Returns
+ -------
+ plan: dict
+ The plan of the tensor.
+ """
+
+ sign = plan["min"] < 0 if auto_unsign else True
+ valid_range = 2 ** (nbits - int(sign)) - 1
+ abs_max = float(np.array(plan["abs_max_list"]).max())
+ return {"scale": valid_range / abs_max, "sign": sign, "calibrated":
True}
+
+ @classmethod
+ def calibrate_kl_divergence(
+ cls,
+ quantizer: BaseTool,
+ name: str,
+ consumer: str,
+ plan: dict,
+ nbits: int = 8,
+ bins: int = 4096,
+ auto_unsign: bool = False,
+ ) -> dict:
+ """Calibrate the data by kl_divergence
+
+ Parameters
+ ----------
+ quantizer: BaseQuantizer
+ The quantizer
+ name: str
+ The name of the tensor.
+ consumer: str
+ The name of the consumer.
+ plan: dict
+ The pre-calibrated plan.
+ nbits: int
+ The number bits for quantize.
+ bins: int
+ The number bins.
+ auto_unsign: bool
+ Whether to use auto unsign.
+
+ Returns
+ -------
+ plan: dict
+ The plan of the tensor.
+ """
+
+ # pylint: disable=import-outside-toplevel
+ import ctypes
+ from tvm.relay import quantize as _quantize
+
+ if plan and "abs_max_list" in plan:
+ return {
+ "abs_max": float(np.array(plan["abs_max_list"]).max()),
+ "max": float(np.array(plan["max_list"]).max()),
+ "min": float(np.array(plan["min_list"]).min()),
+ "calibrated": False,
+ }
+
+ def get_pointer(arr, ctypes_type):
+ ptr = arr.ctypes.data_as(ctypes.POINTER(ctypes_type))
+ return ctypes.cast(ptr, ctypes.c_void_p)
+
+ sign = plan["min"] < 0 if auto_unsign else True
+ hist = np.array(plan["hist_list"]).sum(axis=0)
+ hist_ptr = get_pointer(hist.astype(np.int64), ctypes.c_int64)
+ edge_ptr = get_pointer(plan["edge"].astype(np.float32), ctypes.c_float)
+ valid_range = 2 ** (nbits - int(sign)) - 1
+ scale = _quantize._quantize.FindScaleByKLMinimization(hist_ptr,
edge_ptr, bins, valid_range)
+ return {"scale": valid_range / scale, "sign": sign, "calibrated": True}
+
+ @classmethod
+ def quantize_normal(
+ cls,
+ quantizer: BaseTool,
+ data: np.ndarray,
+ name: str,
+ consumer: str,
+ scale: float,
+ nbits: int = 8,
+ axis: int = -1,
+ sign: bool = True,
+ rounding: str = "round",
+ epsilon: float = 1.0 / (1 << 24),
+ ) -> np.ndarray:
+ """Calibrate the data by kl_divergence
+
+ Parameters
+ ----------
+ quantizer: BaseQuantizer
+ The quantizer
+ data: np.ndarray
+ The source data.
+ name: str
+ The name of the tensor.
+ consumer: str
+ The name of the consumer.
+ scale: float
+ The scale factor
+ nbits: int
+ The number bits for quantize.
+ axis: int
+ The axis.
+ sign: bool
+ Whether to use sign.
+ rounding str
+ The rounding method.
+ epsilon: float
+ The epsilon for get scale.
+
+ Returns
+ -------
+ data: array like
+ The processed tensor.
+ """
+
+ valid_range = 2 ** (nbits - int(sign)) - 1
+ min_val = -valid_range if sign else 0
+ scale_tensor = quantizer._get_tensor_cache(name, consumer,
"scale_tensor")
+ if scale_tensor is None:
+ scale_tensor = cls.get_scale_tensor(data, scale, axis, epsilon)
+ quantizer._save_tensor_cache(name, consumer, "scale_tensor",
scale_tensor)
+ data = cls.amplify_data(data, scale_tensor, min_val, valid_range,
rounding)
+ return data / scale
+
+ @classmethod
+ def dequantize_normal(
+ cls,
+ quantizer: BaseTool,
+ data: np.ndarray,
+ name: str,
+ consumer: str,
+ scale: float = -1.0,
+ nbits: int = 8,
+ axis: int = -1,
+ sign: bool = True,
+ rounding: str = "round",
+ epsilon: float = 1.0 / (1 << 24),
+ ) -> np.ndarray:
+ """Calibrate the data by kl_divergence
+
+ Parameters
+ ----------
+ quantizer: BaseQuantizer
+ The quantizer
+ data: np.ndarray
+ The source data.
+ name: str
+ The name of the tensor.
+ consumer: str
+ The name of the consumer.
+ scale: float
+ The scale factor
+ nbits: int
+ The number bits for quantize.
+ axis: int
+ The axis.
+ sign: bool
+ Whether to use sign.
+ rounding str
+ The rounding method.
+ epsilon: float
+ The epsilon for get scale.
+
+ Returns
+ -------
+ data: array like
+ The processed tensor.
+ """
+
+ return data
+
+ @classmethod
+ def framework(cls):
+ return MSCFramework.MSC
+
+ @classmethod
+ def tool_type(cls):
+ return ToolType.QUANTIZER
+
+
+msc_utils.register_tool_method(QuantizeMethod)
diff --git a/python/tvm/contrib/msc/core/tools/quantize/quantizer.py
b/python/tvm/contrib/msc/core/tools/quantize/quantizer.py
new file mode 100644
index 0000000000..bee8e6fa42
--- /dev/null
+++ b/python/tvm/contrib/msc/core/tools/quantize/quantizer.py
@@ -0,0 +1,249 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.core.tools.quantize.quantizer"""
+
+from typing import List, Dict, Any
+
+from tvm.contrib.msc.core.tools.tool import ToolType, BaseTool, ToolStrategy
+from tvm.contrib.msc.core import utils as msc_utils
+
+
+class QuantizeStage:
+ """Quantizer stage names: GATHER before collecting data, CALIBRATE while calibrating"""
+
+ # set in setup() and whenever calibrate() needs another gathering round
+ GATHER = "gather"
+ # set at the start of calibrate()
+ CALIBRATE = "calibrate"
+
+
+class BaseQuantizer(BaseTool):
+ """Base quantizer for all"""
+
+ def setup(self) -> dict:
+ """Setup the tool
+
+ Returns
+ -------
+ info: dict
+ The setup info.
+ """
+
+ if self._plan:
+ self._calibrated = True
+ self.change_stage(msc_utils.MSCStage.QUANTIZE)
+ else:
+ self._calibrated = False
+ self._calibrate_plan = {}
+ self.change_stage(QuantizeStage.GATHER)
+ return super().setup()
+
+ def calibrate(self) -> dict:
+ """Calibrate the datas
+
+ Returns
+ -------
+ plan: dict
+ The calibrated plan.
+ """
+
+ new_plan = {}
+ self.change_stage(QuantizeStage.CALIBRATE)
+ for tensor_id, plan in self._calibrate_plan.items():
+ if plan.get("calibrated", False):
+ new_plan[tensor_id] = plan
+ continue
+ name, consumer = self.from_tensor_id(tensor_id)
+ strategy = self._get_tensor_strategy(name, consumer)
+ new_plan[tensor_id] = strategy(self, name, consumer, plan)
+ if any(not plan.get("calibrated", False) for plan in
new_plan.values()):
+ self._calibrate_plan = new_plan
+ self.change_stage(QuantizeStage.GATHER)
+ else:
+ self._calibrated = True
+ for name, plan in new_plan.items():
+ self._plan[name] = {k: v for k, v in plan.items() if k not in
("calibrated")}
+ self.change_stage(msc_utils.MSCStage.QUANTIZE)
+ self._forward_cnt = 0
+ return new_plan
+
+ def _parse_strategys(self, strategy_list: dict) -> Dict[str, ToolStrategy]:
+ """Parse the strategy to get valid strategy
+
+ Parameters
+ -------
+ strategy_list: dict
+ The given strategy
+
+ Returns
+ -------
+ strategys: dict<str, ToolStrategy>
+ The parsed strategy.
+ """
+
+ def _update_stages(strategy):
+ if "stages" not in strategy:
+ strategy["stages"] = [msc_utils.MSCStage.QUANTIZE]
+ return strategy
+
+ return super()._parse_strategys([_update_stages(s) for s in
strategy_list])
+
+ def _check_tensor(self, name: str, consumer: str) -> bool:
+ """Check if the tensor should be processed
+
+ Parameters
+ -------
+ name: str
+ The name of the tensor.
+ consumer: str
+ The name of the consumer.
+
+ Returns
+ -------
+ vaild: bool
+ Whether to process the tensor.
+ """
+
+ strategys = self._get_tensor_strategys(name, consumer)
+ if not strategys:
+ return False
+ if any(s.get_config().get("nbits", 8) == -1 for s in strategys):
+ return False
+ return True
+
+ def _process_tensor(
+ self, tensor: Any, name: str, consumer: str, scope: str, strategys:
List[ToolStrategy]
+ ) -> Any:
+ """Process tensor
+
+ Parameters
+ -------
+ tensor: Any
+ Tensor in framework
+ name: str
+ The name of the tensor.
+ consumer: str
+ The name of the consumer.
+ scope: str
+ The scope mark teacher| student| null.
+ strategys: list<ToolStrategy>
+ The strategys for the tensor.
+
+ Returns
+ -------
+ tensor: Any
+ The processed tensor.
+ """
+
+ if not self._calibrated:
+ return self._gather_tensor(tensor, name, consumer, strategys)
+ return self._quantize_tensor(tensor, name, consumer, strategys)
+
+ def _gather_tensor(
+ self, tensor: Any, name: str, consumer: str, strategys:
List[ToolStrategy]
+ ) -> Any:
+ """Gather tensor datas
+
+ Parameters
+ -------
+ tensor: Any
+ Tensor in framework
+ name: str
+ The name of the tensor.
+ consumer: str
+ The name of the consumer.
+ strategys: list<ToolStrategy>
+ The strategys for the tensor.
+
+ Returns
+ -------
+ tensor: Any
+ The processed tensor.
+ """
+
+ assert len(strategys) == 1, "gather should only has 1 strategy, get "
+ str(strategys)
+ tensor_id = self.to_tensor_id(name, consumer)
+ plan = self._calibrate_plan.get(tensor_id, {})
+ if plan.get("calibrated", False):
+ return tensor
+ self._calibrate_plan[tensor_id] = strategys[0](self, tensor, name,
consumer, plan)
+ return tensor
+
+ def _quantize_tensor(
+ self, tensor: Any, name: str, consumer: str, strategys:
List[ToolStrategy]
+ ) -> Any:
+ """Quantize tensor
+
+ Parameters
+ -------
+ tensor: Any
+ Tensor in framework
+ name: str
+ The name of the tensor.
+ consumer: str
+ The name of the consumer.
+ strategys: list<ToolStrategy>
+ The strategys for the tensor.
+
+ Returns
+ -------
+ tensor: Any
+ The processed tensor.
+ """
+
+ tensor_id = self.to_tensor_id(name, consumer)
+ for strategy in strategys:
+ tensor = strategy(self, tensor, name, consumer,
**self._plan[tensor_id])
+ return tensor
+
+ def create_tasks(self, **kwargs) -> List[dict]:
+ """Create tasks for gym
+
+ Parameters
+ ----------
+ kwargs: dict
+ The kwargs for create tasks.
+
+ Returns
+ -------
+ tasks: list<dict>
+ The tasks.
+ """
+
+ tasks, recorded = [], set()
+ for tensor_id, plan in self._plan.items():
+ name, _ = self.from_tensor_id(tensor_id)
+ if self.is_weight(name) and not kwargs.get("quantize_weights",
False):
+ continue
+ if name not in recorded:
+ tasks.append({"name": tensor_id, **plan})
+ if self._cache_processed:
+ recorded.add(name)
+ return tasks
+
+ @property
+ def calibrated(self):
+ return self._calibrated
+
+ @classmethod
+ def tool_type(cls):
+ return ToolType.QUANTIZER
+
+
+class DefaultQuantizer(BaseQuantizer):
+ """Quantizer of tool style "default", using the BaseQuantizer behavior as is"""
+
+ @classmethod
+ def tool_style(cls):
+ return "default"
+
+
+# Register the class so msc_utils.get_registered_tool_cls can look it up
+msc_utils.register_tool_cls(DefaultQuantizer)
diff --git a/python/tvm/contrib/msc/core/utils/file.py
b/python/tvm/contrib/msc/core/utils/file.py
index 146cfaf504..446efd4724 100644
--- a/python/tvm/contrib/msc/core/utils/file.py
+++ b/python/tvm/contrib/msc/core/utils/file.py
@@ -72,6 +72,8 @@ class MSCDirectory(object):
return "{}(Cleanup: {}): {} Files".format(self._path, self._cleanup,
len(self.listdir()))
def __enter__(self):
+ if not os.path.isdir(self._path):
+ os.mkdir(self._path)
os.chdir(self._path)
return self
@@ -105,6 +107,9 @@ class MSCDirectory(object):
"""
file_path = self.relpath(name)
+ base_dir = os.path.dirname(name)
+ if base_dir and not os.path.isdir(base_dir):
+ os.makedirs(base_dir)
with open(file_path, "w") as f:
f.write(contains)
return file_path
diff --git a/python/tvm/contrib/msc/core/utils/info.py
b/python/tvm/contrib/msc/core/utils/info.py
index 440789f856..d1b5cd1a26 100644
--- a/python/tvm/contrib/msc/core/utils/info.py
+++ b/python/tvm/contrib/msc/core/utils/info.py
@@ -234,6 +234,10 @@ def load_dict(str_dict: str, flavor: str = "json") -> dict:
dict_obj = json.load(f)
elif isinstance(str_dict, str):
dict_obj = json.loads(str_dict)
+ elif isinstance(str_dict, dict):
+ dict_obj = copy_dict(str_dict)
+ else:
+ raise Exception("Unexpected str_dict {}({})".format(str_dict,
type(str_dict)))
assert flavor == "json", "Unexpected flavor for load_dict: " + str(flavor)
return dict_obj
diff --git a/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py
b/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py
index d25cfd4e67..e5ebe50956 100644
--- a/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py
+++ b/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py
@@ -17,4 +17,5 @@
"""tvm.contrib.msc.framework.tensorflow.tools"""
from .prune import *
+from .quantize import *
from .track import *
diff --git a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py
b/python/tvm/contrib/msc/framework/tensorflow/tools/quantize/__init__.py
similarity index 90%
copy from python/tvm/contrib/msc/framework/tvm/tools/__init__.py
copy to python/tvm/contrib/msc/framework/tensorflow/tools/quantize/__init__.py
index 226ae3102d..ed458ef838 100644
--- a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py
+++ b/python/tvm/contrib/msc/framework/tensorflow/tools/quantize/__init__.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""tvm.contrib.msc.framework.tvm.tools"""
+"""tvm.contrib.msc.framework.tensorflow.tools.quantize"""
-from .prune import *
-from .track import *
+from .quantizer import *
diff --git
a/python/tvm/contrib/msc/framework/tensorflow/tools/quantize/quantizer.py
b/python/tvm/contrib/msc/framework/tensorflow/tools/quantize/quantizer.py
new file mode 100644
index 0000000000..dd6f2aac38
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/tensorflow/tools/quantize/quantizer.py
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.framework.tensorflow.tools.quantize.quantizer"""
+
+from tvm.contrib.msc.core.tools.tool import ToolType
+from tvm.contrib.msc.core.tools.quantize import BaseQuantizer
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
+
+
+class TensorflowQuantizerFactory(object):
+ """Quantizer factory for tensorflow"""
+
+ def create(self, base_cls: BaseQuantizer) -> BaseQuantizer:
+ """Create adaptive quantizer
+
+ The returned class subclasses base_cls and only overrides framework()
+ to report MSCFramework.TENSORFLOW.
+
+ Parameters
+ ----------
+ base_cls: BaseQuantizer
+ The base quantizer class
+
+ Returns
+ -------
+ quantizer_cls: BaseQuantizer
+ The quantizer class.
+ """
+
+ class Quantizer(base_cls):
+ """Adaptive quantizer for tensorflow"""
+
+ @classmethod
+ def framework(cls):
+ return MSCFramework.TENSORFLOW
+
+ return Quantizer
+
+
+# Wrap every quantizer registered for MSC (all tool styles) into a tensorflow
+# variant and register it as well.
+factory = TensorflowQuantizerFactory()
+tools = msc_utils.get_registered_tool_cls(MSCFramework.MSC,
ToolType.QUANTIZER, tool_style="all")
+for tool in tools.values():
+ msc_utils.register_tool_cls(factory.create(tool))
diff --git a/python/tvm/contrib/msc/framework/tensorrt/codegen/sources.py
b/python/tvm/contrib/msc/framework/tensorrt/codegen/sources.py
index b6497e9258..a5df42f78b 100644
--- a/python/tvm/contrib/msc/framework/tensorrt/codegen/sources.py
+++ b/python/tvm/contrib/msc/framework/tensorrt/codegen/sources.py
@@ -302,6 +302,171 @@ bool TRTUtils::DeserializeEngineFromFile(const
std::string& file,
"""
+def get_trt_quantize_h_code():
+ """Create trt_quantize header file codes
+
+ Returns
+ -------
+ source: str
+ The trt_quantize header source.
+ """
+
+ return """#ifndef TVM_CONTRIB_MSC_UTILS_TRT_QUANTIZE_H_
+#define TVM_CONTRIB_MSC_UTILS_TRT_QUANTIZE_H_
+
+#include <cassert>
+#include <fstream>
+#include <iterator>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "NvInfer.h"
+#include "base.h"
+#include "trt_common.h"
+
+namespace tvm {
+namespace contrib {
+namespace msc {
+
+using namespace nvinfer1;
+
+class CalibrateHelper {
+ public:
+ CalibrateHelper(const std::string& range_file, const std::string& folder,
int max_size = -1);
+
+ ~CalibrateHelper() {
+ for (const auto& buffer : cpu_buffers_) {
+ free(buffer);
+ }
+ for (const auto& buffer : gpu_buffers_) {
+ CHECK(cudaFree(buffer));
+ }
+ }
+
+ bool GetBatch(void* bindings[], const char* names[], int nbBindings);
+
+ const void* ReadCache(size_t& length);
+
+ void WriteCache(const void* cache, size_t length);
+
+ private:
+ std::unique_ptr<DatasetReader> reader_;
+ std::string range_file_;
+ std::vector<char> cache_;
+ std::vector<void*> cpu_buffers_;
+ std::vector<void*> gpu_buffers_;
+};
+
+#define CALIBRATE_MEMBERS(Calibrator)
\\
+ public:
\\
+ Calibrator(const std::string& range_file, const std::string& folder, int
max_size = -1) { \\
+ helper_.reset(new CalibrateHelper(range_file, folder, max_size));
\\
+ }
\\
+
\\
+ virtual ~Calibrator() {}
\\
+
\\
+ int getBatchSize() const noexcept override { return 1; }
\\
+
\\
+ bool getBatch(void* bindings[], const char* names[], int nbBindings)
noexcept override { \\
+ return helper_->GetBatch(bindings, names, nbBindings);
\\
+ }
\\
+
\\
+ const void* readCalibrationCache(size_t& length) noexcept override {
\\
+ return helper_->ReadCache(length);
\\
+ }
\\
+
\\
+ void writeCalibrationCache(const void* cache, size_t length) noexcept
override { \\
+ return helper_->WriteCache(cache, length);
\\
+ }
\\
+
\\
+ private:
\\
+ std::unique_ptr<CalibrateHelper> helper_;
+
+class MSCInt8EntropyCalibrator : public IInt8EntropyCalibrator {
+ CALIBRATE_MEMBERS(MSCInt8EntropyCalibrator)
+};
+
+class MSCInt8EntropyCalibrator2 : public IInt8EntropyCalibrator2 {
+ CALIBRATE_MEMBERS(MSCInt8EntropyCalibrator2)
+};
+
+} // namespace msc
+} // namespace contrib
+} // namespace tvm
+
+#endif // TVM_CONTRIB_MSC_UTILS_TRT_QUANTIZE_H_
+"""
+
+
+def get_trt_quantize_cc_code():
+ """Create trt_quantize cc file codes
+
+ Returns
+ -------
+ source: str
+ The trt_quantize cc source.
+ """
+
+ return """#include "trt_quantize.h"
+
+namespace tvm {
+namespace contrib {
+namespace msc {
+
+using namespace nvinfer1;
+
+CalibrateHelper::CalibrateHelper(const std::string& range_file, const
std::string& folder,
+ int max_size) {
+ range_file_ = range_file;
+ reader_.reset(new DatasetReader(folder, max_size));
+ const auto& tensor_names = reader_->GetTensorNames();
+ cpu_buffers_.resize(tensor_names.size());
+ gpu_buffers_.resize(tensor_names.size());
+ for (size_t i = 0; i < tensor_names.size(); i++) {
+ size_t tensor_size = reader_->GetTensorSize(tensor_names[i]);
+ cpu_buffers_[i] = malloc(tensor_size);
+ CHECK(cudaMalloc(&gpu_buffers_[i], tensor_size));
+ }
+}
+
+bool CalibrateHelper::GetBatch(void* bindings[], const char* names[], int
nbBindings) {
+ if (!reader_->ReadNext(cpu_buffers_.data())) {
+ return false;
+ }
+ for (size_t i = 0; i < nbBindings; i++) {
+ CHECK(cudaMemcpy(gpu_buffers_[i], cpu_buffers_[i],
reader_->GetTensorSize(names[i]),
+ cudaMemcpyHostToDevice));
+ bindings[i] = gpu_buffers_[i];
+ }
+ return true;
+}
+
+const void* CalibrateHelper::ReadCache(size_t& length) {
+ cache_.clear();
+ std::ifstream in_file(range_file_, std::ifstream::binary);
+ if (!in_file.is_open()) {
+ return nullptr;
+ }
+ in_file >> std::noskipws;
+ std::copy(std::istream_iterator<char>(in_file),
std::istream_iterator<char>(),
+ std::back_inserter(cache_));
+ length = cache_.size();
+ return length > 0 ? &cache_[0] : nullptr;
+}
+
+void CalibrateHelper::WriteCache(const void* cache, size_t length) {
+ std::ofstream output(range_file_, std::ios::binary);
+ output.write(reinterpret_cast<const char*>(cache), length);
+}
+
+} // namespace msc
+} // namespace contrib
+} // namespace tvm
+"""
+
+
def get_trt_sources() -> Dict[str, str]:
"""Create trt sources for cpp codegen
@@ -313,6 +478,11 @@ def get_trt_sources() -> Dict[str, str]:
sources = get_base_sources()
sources.update(
- {"trt_common.h": get_trt_common_h_code(), "trt_common.cc":
get_trt_common_cc_code()}
+ {
+ "trt_common.h": get_trt_common_h_code(),
+ "trt_common.cc": get_trt_common_cc_code(),
+ "trt_quantize.h": get_trt_quantize_h_code(),
+ "trt_quantize.cc": get_trt_quantize_cc_code(),
+ }
)
return sources
diff --git a/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py
b/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py
index 15a42b2cf9..c66f8d1450 100644
--- a/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py
+++ b/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py
@@ -17,9 +17,14 @@
# pylint: disable=unused-import
"""tvm.contrib.msc.framework.tensorrt.runtime.runner"""
+from typing import Any, List, Dict
+
import tvm
+from tvm.contrib.msc.core.ir import MSCGraph
from tvm.contrib.msc.core.runtime import BYOCRunner
+from tvm.contrib.msc.core.tools import ToolType
from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
from tvm.contrib.msc.framework.tensorrt.frontend import (
partition_for_tensorrt,
transform_for_tensorrt,
@@ -44,6 +49,54 @@ class TensorRTRunner(BYOCRunner):
self._device = "cuda"
return super().setup()
+    def apply_tool(self, tool_type: str, data_loader: Any = None) -> dict:
+        """Execute tool and get plan
+
+        For the quantizer, every batch of ``data_loader`` is run through the
+        model first, the model is regenerated (letting the tools update the
+        codegen config) and the quantizer is calibrated; then the generic
+        handling of the base runner is applied.
+
+        Parameters
+        ----------
+        tool_type: str
+            The tool type, should be in ToolType
+        data_loader:
+            The data loader, called to get an iterable of inputs.
+            Required when tool_type is ToolType.QUANTIZER.
+
+        Returns
+        -------
+        plan: dict
+            The plan returned by the base runner for the tool.
+        """
+
+        assert tool_type in self._tools, "Can not find tool " + str(tool_type)
+        if tool_type == ToolType.QUANTIZER:
+            quantizer = self.get_tool(ToolType.QUANTIZER)
+            assert data_loader, "data_loader should be given to calibrate quantizer"
+            # forward every batch so the tools can observe the inputs
+            for inputs in data_loader():
+                self.run(inputs)
+            # regenerate the model, tools update the codegen config on the way
+            self._generate_model()
+            quantizer.calibrate()
+            assert quantizer.calibrated, "Failed to calibrate the tensorrt quantizer"
+        return super().apply_tool(tool_type, data_loader)
+
+    def _generate_model(
+        self, graphs: List[MSCGraph] = None, weights: List[Dict[str, tvm.nd.array]] = None
+    ) -> Any:
+        """Codegen the model according to framework
+
+        The codegen config is first expanded to one config per graph, then
+        every tool may update it through ``config_generate`` before the base
+        runner generates the model.
+
+        Parameters
+        ----------
+        graphs: list<MSCgraph>
+            The msc graphs.
+        weights: list<dict<str, tvm.nd.array>>
+            The weights
+
+        Returns
+        -------
+        model: Any
+            The meta model
+        """
+
+        codegen = self._generate_config.get("codegen")
+        if not isinstance(codegen, (list, tuple)):
+            # Build an independent copy per graph: ``[cfg] * n`` would alias a
+            # single dict n times, so per-graph updates made by the tools (e.g.
+            # a graph specific range_file) would leak into every graph.
+            self._generate_config["codegen"] = [
+                msc_utils.copy_dict(codegen) for _ in range(len(self._graphs))
+            ]
+        for tool in self.get_tools():
+            self._generate_config = tool.config_generate(self._generate_config)
+
+        return super()._generate_model(graphs, weights)
+
@classmethod
def target_transform(cls, mod: tvm.IRModule):
"""Transform the mod by target.
diff --git a/python/tvm/contrib/msc/framework/tensorrt/tools/__init__.py
b/python/tvm/contrib/msc/framework/tensorrt/tools/__init__.py
index ecc82bc40f..c010a42004 100644
--- a/python/tvm/contrib/msc/framework/tensorrt/tools/__init__.py
+++ b/python/tvm/contrib/msc/framework/tensorrt/tools/__init__.py
@@ -17,4 +17,5 @@
"""tvm.contrib.msc.framework.tensorrt.tools"""
from .prune import *
+from .quantize import *
from .track import *
diff --git a/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py
b/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/__init__.py
similarity index 88%
copy from python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py
copy to python/tvm/contrib/msc/framework/tensorrt/tools/quantize/__init__.py
index d25cfd4e67..e47b3324c9 100644
--- a/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py
+++ b/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/__init__.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""tvm.contrib.msc.framework.tensorflow.tools"""
+"""tvm.contrib.msc.framework.tensorrt.tools.quantize"""
-from .prune import *
-from .track import *
+from .quantizer import *
+from .method import *
diff --git a/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/method.py
b/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/method.py
new file mode 100644
index 0000000000..0feb836d13
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/method.py
@@ -0,0 +1,149 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument
+"""tvm.contrib.msc.framework.tensorrt.tools.quantize.method"""
+
+from typing import Dict
+
+from tvm.contrib.msc.core.tools.quantize import QuantizeMethod, BaseQuantizer
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
+
+
+class TensorRTQuantizeMethod(QuantizeMethod):
+    """Default quantize method for tensorrt
+
+    The methods do not touch data: they append C++ statements to the tensor
+    context that set the layer precision and the tensor dynamic range.
+    """
+
+    @classmethod
+    def quantize_normal(
+        cls,
+        quantizer: BaseQuantizer,
+        tensor_ctx: Dict[str, str],
+        name: str,
+        consumer: str,
+        scale: float,
+        nbits: int = 8,
+        axis: int = -1,
+        sign: bool = True,
+        rounding: str = "round",
+        epsilon: float = 1.0 / (1 << 24),
+    ) -> Dict[str, str]:
+        """Quantize the tensor by setting precision and dynamic range
+
+        Weights are returned untouched. For other tensors two statements are
+        appended to ``tensor_ctx["processed"]``: ``setPrecision`` on the
+        producer and ``setDynamicRange(-scale, scale)`` on the tensor.
+
+        Parameters
+        ----------
+        quantizer: BaseQuantizer
+            The quantizer
+        tensor_ctx: dict<str, str>
+            Tensor describe items, read keys: "producer", "tensor", "processed".
+        name: str
+            The name of the tensor.
+        consumer: str
+            The name of the consumer, unused here.
+        scale: float
+            The bound of the dynamic range [-scale, scale].
+        nbits: int
+            The number bits for quantize, 8 selects int8 precision.
+        axis: int
+            The axis, unused here.
+        sign: bool
+            Whether to use sign, unused here.
+        rounding str
+            The rounding method, unused here.
+        epsilon: float
+            The epsilon for get scale, unused here.
+
+        Returns
+        -------
+        tensor_ctx: dict<str, str>
+            Tensor describe items.
+
+        Raises
+        ------
+        TypeError
+            If nbits is not 8 and the dtype is neither float16 nor float32.
+        """
+
+        if quantizer.is_weight(name):
+            return tensor_ctx
+        dtype = quantizer.find_tensor(name).dtype_name
+        if nbits == 8:
+            precision = "DataType::kINT8"
+        elif dtype == "float16":
+            precision = "DataType::kHALF"
+        elif dtype == "float32":
+            precision = "DataType::kFLOAT"
+        else:
+            raise TypeError("nbits {} with dtype {} is not supported".format(nbits, dtype))
+        tensor_ctx["processed"].extend(
+            [
+                "{}->setPrecision({})".format(tensor_ctx["producer"], precision),
+                "{0}->setDynamicRange(-{1}, {1})".format(tensor_ctx["tensor"], scale),
+            ]
+        )
+        return tensor_ctx
+
+    @classmethod
+    def dequantize_normal(
+        cls,
+        quantizer: BaseQuantizer,
+        tensor_ctx: Dict[str, str],
+        name: str,
+        consumer: str,
+        scale: float,
+        nbits: int = 8,
+        axis: int = -1,
+        sign: bool = True,
+        rounding: str = "round",
+        epsilon: float = 1.0 / (1 << 24),
+    ) -> Dict[str, str]:
+        """Dequantize the tensor, same statements as quantize_normal
+
+        Tensorrt only needs the precision and dynamic range of the tensor,
+        so dequantize delegates to quantize_normal with the same arguments.
+
+        Parameters
+        ----------
+        quantizer: BaseQuantizer
+            The quantizer
+        tensor_ctx: dict<str, str>
+            Tensor describe items.
+        name: str
+            The name of the tensor.
+        consumer: str
+            The name of the consumer.
+        scale: float
+            The bound of the dynamic range [-scale, scale].
+        nbits: int
+            The number bits for quantize.
+        axis: int
+            The axis.
+        sign: bool
+            Whether to use sign.
+        rounding str
+            The rounding method.
+        epsilon: float
+            The epsilon for get scale.
+
+        Returns
+        -------
+        tensor_ctx: dict<str, str>
+            Tensor describe items.
+        """
+
+        return cls.quantize_normal(
+            quantizer, tensor_ctx, name, consumer, scale, nbits, axis, sign, rounding, epsilon
+        )
+
+    @classmethod
+    def framework(cls):
+        return MSCFramework.TENSORRT
+
+
+msc_utils.register_tool_method(TensorRTQuantizeMethod)
diff --git
a/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/quantizer.py
b/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/quantizer.py
new file mode 100644
index 0000000000..f971186196
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/quantizer.py
@@ -0,0 +1,366 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.framework.tensorrt.tools.quantize.quantizer"""
+
+import os
+import struct
+from typing import List, Dict, Any, Tuple
+
+import tvm
+from tvm.contrib.msc.core.ir import MSCGraph
+from tvm.contrib.msc.core.tools.tool import ToolType, ToolStrategy
+from tvm.contrib.msc.core.tools.quantize import BaseQuantizer, QuantizeStage
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
+
+
+class TensorRTQuantizerFactory(object):
+    """Quantizer factory for tensorrt"""
+
+    def create(self, base_cls: BaseQuantizer) -> BaseQuantizer:
+        """Create adaptive quantizer
+
+        Parameters
+        ----------
+        base_cls: BaseQuantizer
+            The base quantizer class
+
+        Returns
+        -------
+        quantizer_cls: BaseQuantizer
+            The quantizer class.
+        """
+
+        class Quantizer(base_cls):
+            """Adaptive quantizer for tensorrt
+
+            Two modes, selected in setup:
+            - range mode (``_use_range``): scales go through one range file
+              per graph and an int8 calibrator is set in the generated code.
+            - tensor mode: the plan scales are applied tensor by tensor.
+            """
+
+            def setup(self) -> dict:
+                """Setup the tool
+
+                Range mode is used when there is no plan, or when every entry
+                of the plan is marked with ``use_range``.
+
+                Returns
+                -------
+                info: dict
+                    The setup info.
+                """
+
+                if self._plan:
+                    self._use_range = all(
+                        info.get("use_range", False) for info in self._plan.values()
+                    )
+                else:
+                    self._use_range = True
+                return super().setup()
+
+            def _reset(
+                self, graphs: List[MSCGraph], weights: List[Dict[str, tvm.nd.array]]
+            ) -> Tuple[List[MSCGraph], List[Dict[str, tvm.nd.array]]]:
+                """Reset the tool
+
+                Resolve the per-graph range files and calibrate folders, then:
+                - calibrated, range mode: write missing range files from plan.
+                - calibrated, tensor mode: reset the quantized tensors.
+                - gather stage: create one IODataSaver per graph.
+                - otherwise: check the calibrate datasets exist.
+
+                Parameters
+                ----------
+                graphs: list<MSCgraph>
+                    The msc graphs.
+                weights: list<dict<str, tvm.nd.array>>
+                    The weights
+
+                Returns
+                -------
+                graphs: list<MSCgraph>
+                    The msc graphs.
+                weights: list<dict<str, tvm.nd.array>>
+                    The weights
+                """
+
+                config_folder = msc_utils.get_config_dir()
+                self._range_files = [config_folder.relpath(g.name + ".range") for g in graphs]
+                calibrate_root = msc_utils.get_dataset_dir().create_dir("Calibrate")
+                self._calibrate_folders = [calibrate_root.relpath(g.name) for g in graphs]
+                if self._calibrated:
+                    if self._use_range:
+                        for r_file, graph in zip(self._range_files, graphs):
+                            if not os.path.isfile(r_file):
+                                self._plan_to_range(graph, r_file)
+                            self._logger.debug(
+                                "G[%s](%s) use range file: %s",
+                                graph.name,
+                                self._stage,
+                                r_file,
+                            )
+                    else:
+                        self._quantized_tensors = set()
+                elif self._stage == QuantizeStage.GATHER:
+                    self._calibrate_savers = []
+                    for folder, graph in zip(self._calibrate_folders, graphs):
+                        saver_options = {"input_names": [i.name for i in graph.get_inputs()]}
+                        saver = msc_utils.IODataSaver(folder, saver_options)
+                        self._calibrate_savers.append(saver)
+                        self._logger.debug(
+                            "G[%s](%s) create calibrate saver: %s",
+                            graph.name,
+                            self._stage,
+                            saver,
+                        )
+                else:
+                    assert all(
+                        msc_utils.is_io_dataset(f) for f in self._calibrate_folders
+                    ), "Some IODataset missing: " + str(self._calibrate_folders)
+                return super()._reset(graphs, weights)
+
+            def _execute_after_build(self, codegen_context: dict) -> dict:
+                """Execute after model build
+
+                In range mode, append C++ statements that create an
+                MSCInt8EntropyCalibrator2 (range file + calibrate folder) and
+                set it as int8 calibrator. Once calibrated, the generated code
+                also returns -1 when the range file does not exist.
+
+                Parameters
+                ----------
+                codegen_context: dict
+                    The context.
+
+                Returns
+                ----------
+                codegen_context: dict
+                    The processed context.
+                """
+
+                if self._stage == QuantizeStage.GATHER and self._forward_cnt == 0:
+                    return codegen_context
+                if not self._use_range:
+                    return codegen_context
+                processed = ["// Set int8 calibrator"]
+                range_file = self.get_graph().name + ".range"
+                version = [int(v) for v in codegen_context["version"].split(".")]
+                # tensorrt >= 6 sets the calibrator on the config, older on the builder
+                if msc_utils.compare_version(version, [6, 0, 0]) >= 0:
+                    configer = codegen_context["config"]
+                else:
+                    configer = codegen_context["builder"]
+                # check the range file if calibrated
+                if self._calibrated:
+                    processed.extend(
+                        [
+                            'if (!FileUtils::FileExist("{}")) {{'.format(range_file),
+                            ' logger.log(ILogger::Severity::kERROR, "{} not exist!");'.format(
+                                range_file
+                            ),
+                            " return -1;",
+                            "}",
+                        ]
+                    )
+                processed.extend(
+                    [
+                        'MSCInt8EntropyCalibrator2 calibrator("{}", "{}");'.format(
+                            range_file, self._calibrate_folders[self._graph_id]
+                        ),
+                        "{}->setInt8Calibrator(&calibrator);".format(configer),
+                    ]
+                )
+                codegen_context["processed"].extend(processed)
+                return codegen_context
+
+            def _execute_before_forward(self, step_context: dict) -> dict:
+                """Execute before model forward
+
+                In gather stage the input datas are saved as a calibrate batch.
+
+                Parameters
+                ----------
+                step_context: dict
+                    The context.
+
+                Returns
+                ----------
+                step_context: dict
+                    The processed context.
+                """
+
+                if self._stage == QuantizeStage.GATHER:
+                    saver = self._calibrate_savers[self._graph_id]
+                    saver.save_batch(
+                        {name: data.asnumpy() for name, data in step_context["datas"].items()}
+                    )
+                    for name, data in step_context["datas"].items():
+                        self.debug_tensor(data, name, "any", "ctx_gathered")
+                # propagate the processed context instead of returning None
+                return super()._execute_before_forward(step_context)
+
+            def _quantize_tensor(
+                self,
+                tensor_ctx: Dict[str, str],
+                name: str,
+                consumer: str,
+                strategys: List[ToolStrategy],
+            ) -> Dict[str, str]:
+                """Quantize tensor
+
+                In range mode the tensor is left untouched (the range file is
+                used instead); in tensor mode each tensor is quantized once.
+
+                Parameters
+                -------
+                tensor_ctx: dict<str, str>
+                    Tensor describe items.
+                name: str
+                    The name of the tensor.
+                consumer: str
+                    The name of the consumer.
+                strategys: list<ToolStrategy>
+                    The strategys for the tensor.
+
+                Returns
+                -------
+                tensor_ctx: dict<str, str>
+                    Tensor items with processed.
+                """
+
+                if not self._use_range and name not in self._quantized_tensors:
+                    self._quantized_tensors.add(name)
+                    return super()._quantize_tensor(tensor_ctx, name, consumer, strategys)
+                return tensor_ctx
+
+            def calibrate(self) -> dict:
+                """Calibrate the datas
+
+                Load the range file of every graph into the plan, mark the
+                tool as calibrated and switch to the quantize stage.
+
+                Returns
+                -------
+                plan: dict
+                    The calibrated plan.
+                """
+
+                for r_file, graph in zip(self._range_files, self._graphs):
+                    self._range_to_plan(graph, r_file)
+                self._calibrated, self._forward_cnt = True, 0
+                self.change_stage("quantize")
+                return self._plan
+
+            def config_generate(self, generate_config: Dict[str, Any]) -> Dict[str, Any]:
+                """Update the generate configs
+
+                Calibrated range mode: use existing range files with int8.
+                Gather stage with data: finalize the savers, point every codegen
+                config to its dataset and range file, and move to calibrate.
+
+                Parameters
+                ----------
+                generate_config: dict<str, Any>
+                    The generate_config.
+
+                Returns
+                -------
+                generate_config: dict<str, Any>
+                    The updated generate_config.
+                """
+
+                if self._calibrated:
+                    if self._use_range:
+                        for config, r_file in zip(generate_config["codegen"], self._range_files):
+                            if os.path.isfile(r_file):
+                                config.update({"range_file": r_file, "precision": "int8"})
+                elif self._stage == QuantizeStage.GATHER and self._forward_cnt > 0:
+                    for config, saver, r_file in zip(
+                        generate_config["codegen"], self._calibrate_savers, self._range_files
+                    ):
+                        saver.finalize()
+                        self._logger.debug(
+                            "%ssave %d datas to %s",
+                            self.msg_mark(in_forward=False),
+                            self._forward_cnt,
+                            saver.folder,
+                        )
+                        config.update(
+                            {"dataset": saver.folder, "range_file": r_file, "precision": "int8"}
+                        )
+                    self.change_stage(QuantizeStage.CALIBRATE)
+                return generate_config
+
+            def _plan_to_range(self, graph: MSCGraph, range_file: str, title="MSCCalibrate"):
+                """Extract plan config to range_file
+
+                The file holds ``title`` then one ``name: hex`` line per
+                tensor of the graph, hex being the float32 bits of scale / 127.
+
+                Parameters
+                ----------
+                graph: MSCGraph
+                    The graph.
+                range_file: str
+                    The output range_file path.
+                title: str
+                    The title of the range file.
+                """
+
+                def _scale_to_hex(scale):
+                    # zero-pad so the hex always encodes exactly 4 bytes
+                    bits = struct.unpack("<I", struct.pack("<f", scale / 127))[0]
+                    return "{:08x}".format(bits)
+
+                recorded = set()
+                with open(range_file, "w") as f:
+                    f.write(title + "\n")
+                    for name, info in self._plan.items():
+                        t_name, _ = self.from_tensor_id(name)
+                        if not graph.find_tensor(t_name):
+                            continue
+                        if t_name not in recorded:
+                            f.write("{}: {}\n".format(t_name, _scale_to_hex(info["scale"])))
+                            recorded.add(t_name)
+                self._logger.debug(
+                    "Graph[%s](%s) extract %d plan to range %s",
+                    graph.name,
+                    self._stage,
+                    len(recorded),
+                    range_file,
+                )
+
+            def _range_to_plan(self, graph: MSCGraph, range_file: str):
+                """Extract scale in range_file to plan
+
+                Each ``name: hex`` line becomes a plan entry (scale = float32
+                value * 127, use_range=True) for every consumer of the tensor,
+                or for "exit" when the tensor has no consumer.
+
+                Parameters
+                ----------
+                graph: MSCGraph
+                    The graph.
+                range_file: str
+                    The input range_file path.
+                """
+
+                range_num = 0
+                with open(range_file, "r") as f:
+                    # the first line is the title
+                    f.readline()
+                    for line in f:
+                        if not line.strip():
+                            continue
+                        # split on the last separator, tensor names may contain ": "
+                        name, scale = line.rsplit(": ", 1)
+                        # parse as integer so "0" or hex without leading zeros works
+                        bits = int(scale.strip(), 16)
+                        value = struct.unpack("!f", struct.pack("!I", bits))[0] * 127
+                        range_num += 1
+                        consumers = graph.find_consumers(name)
+                        if consumers:
+                            for c in consumers:
+                                self._plan[self.to_tensor_id(name, c.name)] = {
+                                    "scale": value,
+                                    "use_range": True,
+                                }
+                        else:
+                            self._plan[self.to_tensor_id(name, "exit")] = {
+                                "scale": value,
+                                "use_range": True,
+                            }
+                self._logger.debug(
+                    "Graph[%s](%s) extract %d range to plan from %s",
+                    graph.name,
+                    self._stage,
+                    range_num,
+                    range_file,
+                )
+
+            @classmethod
+            def framework(cls):
+                return MSCFramework.TENSORRT
+
+        return Quantizer
+
+
+factory = TensorRTQuantizerFactory()
+tools = msc_utils.get_registered_tool_cls(MSCFramework.MSC, ToolType.QUANTIZER, tool_style="all")
+for tool in tools.values():
+    msc_utils.register_tool_cls(factory.create(tool))
diff --git a/python/tvm/contrib/msc/framework/torch/tools/__init__.py
b/python/tvm/contrib/msc/framework/torch/tools/__init__.py
index dda1e13822..f8fee73d69 100644
--- a/python/tvm/contrib/msc/framework/torch/tools/__init__.py
+++ b/python/tvm/contrib/msc/framework/torch/tools/__init__.py
@@ -17,4 +17,5 @@
"""tvm.contrib.msc.framework.torch.tools"""
from .prune import *
+from .quantize import *
from .track import *
diff --git a/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py
b/python/tvm/contrib/msc/framework/torch/tools/quantize/__init__.py
similarity index 88%
copy from python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py
copy to python/tvm/contrib/msc/framework/torch/tools/quantize/__init__.py
index d25cfd4e67..9391900edf 100644
--- a/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py
+++ b/python/tvm/contrib/msc/framework/torch/tools/quantize/__init__.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""tvm.contrib.msc.framework.tensorflow.tools"""
+"""tvm.contrib.msc.framework.torch.tools.quantize"""
-from .prune import *
-from .track import *
+from .quantizer import *
+from .method import *
diff --git a/python/tvm/contrib/msc/framework/torch/tools/quantize/method.py
b/python/tvm/contrib/msc/framework/torch/tools/quantize/method.py
new file mode 100644
index 0000000000..6f82a796e1
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/torch/tools/quantize/method.py
@@ -0,0 +1,237 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument
+"""tvm.contrib.msc.framework.torch.tools.quantize.method"""
+
+import numpy as np
+import torch
+from tvm.contrib.msc.core.tools.quantize import QuantizeMethod, BaseQuantizer
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
+
+
+class TorchQuantizeMethod(QuantizeMethod):
+    """Default quantize method for torch"""
+
+    @classmethod
+    def amplify_data(
+        cls,
+        data: torch.Tensor,
+        scale: float,
+        min_val: float,
+        max_val: float,
+        rounding: str = "round",
+    ) -> torch.Tensor:
+        """Amplify the data
+
+        Multiply data by scale, round it and clamp it into [min_val, max_val].
+
+        Parameters
+        ----------
+        data: torch.Tensor
+            The source data.
+        scale: float
+            The scale factor, a float or a tensor broadcastable to data.
+        min_val: float
+            The min.
+        max_val: float
+            The max.
+        rounding: str
+            The round method: "null" (no rounding), "floor", "ceil", "round",
+            "trunc" or "logic_round" (clamp first, halves rounded up).
+
+        Returns
+        -------
+        data: torch.Tensor
+            The processed data.
+
+        Raises
+        ------
+        TypeError
+            If rounding is not one of the supported methods.
+        """
+
+        if rounding == "null":
+            return torch.clamp(data * scale, min_val, max_val)
+        if rounding == "floor":
+            return torch.clamp(torch.floor(data * scale), min_val, max_val)
+        if rounding == "ceil":
+            return torch.clamp(torch.ceil(data * scale), min_val, max_val)
+        if rounding == "round":
+            return torch.clamp(torch.round(data * scale), min_val, max_val)
+        if rounding == "trunc":
+            return torch.clamp(torch.trunc(data * scale), min_val, max_val)
+        if rounding == "logic_round":
+            data = torch.clamp(data * scale, min_val, max_val)
+            negative_ceil = torch.where(
+                torch.logical_and(data < 0, (data - torch.floor(data)) == 0.5), torch.ceil(data), 0
+            )
+            data = torch.where(
+                torch.logical_and(data < 0, (data - torch.floor(data)) == 0.5), 0, data
+            )
+            data = torch.where((data - torch.floor(data)) >= 0.5, torch.ceil(data), data)
+            data = torch.where((data - torch.floor(data)) < 0.5, torch.floor(data), data)
+            return data + negative_ceil
+        raise TypeError("Unexpected rounding " + str(rounding))
+
+    @classmethod
+    def gather_maxmin(
+        cls,
+        quantizer: BaseQuantizer,
+        data: torch.Tensor,
+        name: str,
+        consumer: str,
+        plan: dict,
+        nbits: int = 8,
+    ) -> dict:
+        """Gather the data by max/min
+
+        Append the abs max, max and min of this batch to the lists already
+        in plan; the result stays not calibrated.
+
+        Parameters
+        ----------
+        quantizer: BaseQuantizer
+            The quantizer
+        data: torch.Tensor
+            The source data.
+        name: str
+            The name of the tensor.
+        consumer: str
+            The name of the consumer.
+        plan: dict
+            The pre-calibrated plan.
+        nbits: int
+            The number bits for quantize.
+
+        Returns
+        -------
+        plan: dict
+            The plan of the tensor.
+        """
+
+        abs_max_list = plan.get("abs_max_list", [])
+        abs_max_list.append(float(torch.abs(data).max()))
+        max_list = plan.get("max_list", [])
+        max_list.append(float(data.max()))
+        min_list = plan.get("min_list", [])
+        min_list.append(float(data.min()))
+        return {
+            "abs_max_list": abs_max_list,
+            "max_list": max_list,
+            "min_list": min_list,
+            "calibrated": False,
+        }
+
+    @classmethod
+    def gather_max_per_channel(
+        cls,
+        quantizer: BaseQuantizer,
+        data: torch.Tensor,
+        name: str,
+        consumer: str,
+        plan: dict,
+        nbits: int = 8,
+        channel: str = "O",
+        auto_unsign: bool = False,
+    ) -> dict:
+        """Gather the data by max_per_channel
+
+        Compute one scale per slice along the channel axis:
+        valid_range / max(abs(slice)).
+
+        Parameters
+        ----------
+        quantizer: BaseQuantizer
+            The quantizer
+        data: torch.Tensor
+            The source data.
+        name: str
+            The name of the tensor.
+        consumer: str
+            The name of the consumer.
+        plan: dict
+            The pre-calibrated plan.
+        nbits: int
+            The number bits for quantize.
+        channel: str
+            The channel reference.
+        auto_unsign: bool
+            Whether to use auto unsign.
+
+        Returns
+        -------
+        plan: dict
+            The plan of the tensor.
+        """
+
+        weight = quantizer.find_tensor(name)
+        axis = weight.layout_of(channel)
+        channel_max = [torch.abs(d).max() for d in torch.chunk(data, data.shape[axis], dim=axis)]
+        # store a python bool in the plan rather than a 0-d tensor
+        sign = bool(data.min() < 0) if auto_unsign else True
+        valid_range = 2 ** (nbits - int(sign)) - 1
+        scale = [valid_range / float(m) for m in channel_max]
+        return {"scale": scale, "sign": sign, "axis": axis, "calibrated": True}
+
+    @classmethod
+    def quantize_normal(
+        cls,
+        quantizer: BaseQuantizer,
+        data: torch.Tensor,
+        name: str,
+        consumer: str,
+        scale: float,
+        nbits: int = 8,
+        axis: int = -1,
+        sign: bool = True,
+        rounding: str = "round",
+        epsilon: float = 1.0 / (1 << 24),
+    ) -> torch.Tensor:
+        """Fake-quantize the data
+
+        Amplify the data by the scale tensor, round and clamp it into the
+        valid integer range, then divide it back by the scale tensor. The
+        scale tensor is cached on the quantizer per (name, consumer).
+
+        Parameters
+        ----------
+        quantizer: BaseQuantizer
+            The quantizer
+        data: torch.Tensor
+            The source data.
+        name: str
+            The name of the tensor.
+        consumer: str
+            The name of the consumer.
+        scale: float
+            The scale factor
+        nbits: int
+            The number bits for quantize.
+        axis: int
+            The axis.
+        sign: bool
+            Whether to use sign.
+        rounding str
+            The rounding method.
+        epsilon: float
+            The epsilon for get scale.
+
+        Returns
+        -------
+        data: torch.Tensor
+            The processed tensor.
+        """
+
+        valid_range = 2 ** (nbits - int(sign)) - 1
+        min_val = -valid_range if sign else 0
+        scale_tensor = quantizer._get_tensor_cache(name, consumer, "scale_tensor")
+        if scale_tensor is None:
+            scale_tensor = cls.get_scale_tensor(data, scale, axis, epsilon)
+            if isinstance(scale_tensor, np.ndarray):
+                scale_tensor = torch.from_numpy(scale_tensor).to(data.device)
+            quantizer._save_tensor_cache(name, consumer, "scale_tensor", scale_tensor)
+        data = cls.amplify_data(data, scale_tensor, min_val, valid_range, rounding)
+        return data / scale_tensor
+
+    @classmethod
+    def framework(cls):
+        return MSCFramework.TORCH
+
+
+msc_utils.register_tool_method(TorchQuantizeMethod)
diff --git a/python/tvm/contrib/msc/framework/torch/tools/quantize/quantizer.py
b/python/tvm/contrib/msc/framework/torch/tools/quantize/quantizer.py
new file mode 100644
index 0000000000..0e5c599b87
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/torch/tools/quantize/quantizer.py
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.framework.torch.tools.quantize.quantizer"""
+
+from tvm.contrib.msc.core.tools.tool import ToolType
+from tvm.contrib.msc.core.tools.quantize import BaseQuantizer
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
+
+
+class TorchQuantizerFactory(object):
+    """Factory that binds the registered MSC quantizers to torch"""
+
+    def create(self, base_cls: BaseQuantizer) -> BaseQuantizer:
+        """Derive a torch flavoured quantizer from base_cls
+
+        Parameters
+        ----------
+        base_cls: BaseQuantizer
+            The MSC quantizer class to derive from.
+
+        Returns
+        -------
+        quantizer_cls: BaseQuantizer
+            A subclass of base_cls whose framework is torch.
+        """
+
+        class Quantizer(base_cls):
+            """Adaptive quantizer for torch"""
+
+            @classmethod
+            def framework(cls):
+                return MSCFramework.TORCH
+
+        return Quantizer
+
+
+factory = TorchQuantizerFactory()
+tools = msc_utils.get_registered_tool_cls(MSCFramework.MSC, ToolType.QUANTIZER, tool_style="all")
+for tool in tools.values():
+    msc_utils.register_tool_cls(factory.create(tool))
diff --git a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py
b/python/tvm/contrib/msc/framework/tvm/tools/__init__.py
index 226ae3102d..ddfd41f3c8 100644
--- a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py
+++ b/python/tvm/contrib/msc/framework/tvm/tools/__init__.py
@@ -17,4 +17,5 @@
"""tvm.contrib.msc.framework.tvm.tools"""
from .prune import *
+from .quantize import *
from .track import *
diff --git a/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py
b/python/tvm/contrib/msc/framework/tvm/tools/quantize/__init__.py
similarity index 88%
copy from python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py
copy to python/tvm/contrib/msc/framework/tvm/tools/quantize/__init__.py
index d25cfd4e67..0026724989 100644
--- a/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py
+++ b/python/tvm/contrib/msc/framework/tvm/tools/quantize/__init__.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""tvm.contrib.msc.framework.tensorflow.tools"""
+"""tvm.contrib.msc.framework.tvm.tools.quantize"""
-from .prune import *
-from .track import *
+from .quantizer import *
+from .method import *
diff --git a/python/tvm/contrib/msc/framework/tvm/tools/quantize/method.py
b/python/tvm/contrib/msc/framework/tvm/tools/quantize/method.py
new file mode 100644
index 0000000000..9966e9c1af
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/tvm/tools/quantize/method.py
@@ -0,0 +1,204 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument
+"""tvm.contrib.msc.framework.tvm.tools.quantize.method"""
+
+from typing import Tuple
+import numpy as np
+
+import tvm
+from tvm.relax import op as relax_op
+from tvm.contrib.msc.core.tools.quantize import QuantizeMethod, BaseQuantizer
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
+from tvm.contrib.msc.core import _ffi_api
+
+
+class TVMQuantizeMethod(QuantizeMethod):
+    """Default quantize method for tvm"""
+
+    @classmethod
+    def get_quantize_cache(
+        cls,
+        quantizer: BaseQuantizer,
+        data: tvm.relax.Var,
+        name: str,
+        consumer: str,
+        scale: float,
+        axis: int = -1,
+        epsilon: float = 1.0 / (1 << 24),
+    ) -> Tuple[tvm.relax.Constant, tvm.relax.Constant]:
+        """Get the scale and zero point constants of a tensor
+
+        On a cache miss they are built (scale in the tensor dtype, int8 zeros
+        as zero point), wrapped into named relax constants and cached on the
+        quantizer per (name, consumer).
+
+        Parameters
+        ----------
+        quantizer: BaseQuantizer
+            The quantizer
+        data: tvm.relax.Var
+            The source data.
+        name: str
+            The name of the tensor.
+        consumer: str
+            The name of the consumer.
+        scale: float
+            The scale factor
+        axis: int
+            The axis.
+        epsilon: float
+            The epsilon for get scale.
+
+        Returns
+        -------
+        scale_tensor: tvm.relax.Constant
+            The scale_tensor.
+        zero_point: tvm.relax.Constant
+            The zero_point.
+        """
+
+        name_prefix = name if quantizer._cache_processed else quantizer.to_tensor_id(name, consumer)
+        scale_tensor = quantizer._get_tensor_cache(name, consumer, "scale_tensor")
+        zero_point = quantizer._get_tensor_cache(name, consumer, "zero_point")
+        if scale_tensor is None:
+            scale_tensor = cls.get_scale_tensor(data, scale, axis, epsilon, expand_dims=False)
+            if isinstance(scale_tensor, float):
+                scale_tensor = np.array(scale_tensor)
+            scale_tensor = scale_tensor.astype(quantizer.find_tensor(name).dtype_name)
+            zero_point = np.zeros_like(scale_tensor).astype("int8")
+            scale_span = _ffi_api.SpanCreateWithAttr("name", name_prefix + "_scale")
+            scale_tensor = tvm.relax.Constant(tvm.nd.array(scale_tensor), span=scale_span)
+            zp_span = _ffi_api.SpanCreateWithAttr("name", name_prefix + "_zero_point")
+            zero_point = tvm.relax.Constant(tvm.nd.array(zero_point), span=zp_span)
+            quantizer._save_tensor_cache(name, consumer, "scale_tensor", scale_tensor)
+            quantizer._save_tensor_cache(name, consumer, "zero_point", zero_point)
+        return scale_tensor, zero_point
+
+    @classmethod
+    def quantize_normal(
+        cls,
+        quantizer: BaseQuantizer,
+        data: tvm.relax.Var,
+        name: str,
+        consumer: str,
+        scale: float,
+        nbits: int = 8,
+        axis: int = -1,
+        sign: bool = True,
+        rounding: str = "round",
+        epsilon: float = 1.0 / (1 << 24),
+    ) -> tvm.relax.Var:
+        """Quantize the data with a relax quantize op
+
+        Emit ``relax.op.quantize`` to int8 with the cached scale and zero
+        point. sign and rounding are not used here.
+
+        Parameters
+        ----------
+        quantizer: BaseQuantizer
+            The quantizer
+        data: tvm.relax.Var
+            The source data.
+        name: str
+            The name of the tensor.
+        consumer: str
+            The name of the consumer.
+        scale: float
+            The scale factor
+        nbits: int
+            The number bits for quantize, only 8 is supported.
+        axis: int
+            The axis.
+        sign: bool
+            Whether to use sign.
+        rounding str
+            The rounding method.
+        epsilon: float
+            The epsilon for get scale.
+
+        Returns
+        -------
+        data: tvm.relax.Var
+            The processed tensor.
+
+        Raises
+        ------
+        TypeError
+            If nbits is not 8.
+        """
+
+        if nbits == 8:
+            dtype = "int8"
+        else:
+            raise TypeError("Unexpected nbits " + str(nbits))
+        name_prefix = name if quantizer._cache_processed else quantizer.to_tensor_id(name, consumer)
+        scale_tensor, zero_point = cls.get_quantize_cache(
+            quantizer, data, name, consumer, scale, axis, epsilon
+        )
+        expr = relax_op.quantize(data, scale_tensor, zero_point, axis, dtype)
+        return quantizer._block_builder.emit(expr, name_hint=name_prefix + "_quantize")
+
+    @classmethod
+    def dequantize_normal(
+        cls,
+        quantizer: BaseQuantizer,
+        data: tvm.relax.Var,
+        name: str,
+        consumer: str,
+        scale: float = -1.0,
+        nbits: int = 8,
+        axis: int = -1,
+        sign: bool = True,
+        rounding: str = "round",
+        epsilon: float = 1.0 / (1 << 24),
+    ) -> tvm.relax.Var:
+        """Dequantize the data with a relax dequantize op
+
+        Emit ``relax.op.dequantize`` back to the dtype of the tensor with
+        the cached scale and zero point.
+
+        Parameters
+        ----------
+        quantizer: BaseQuantizer
+            The quantizer
+        data: tvm.relax.Var
+            The source data.
+        name: str
+            The name of the tensor.
+        consumer: str
+            The name of the consumer.
+        scale: float
+            The scale factor, only used when no scale is cached yet.
+        nbits: int
+            The number bits for quantize.
+        axis: int
+            The axis.
+        sign: bool
+            Whether to use sign.
+        rounding str
+            The rounding method.
+        epsilon: float
+            The epsilon for get scale.
+
+        Returns
+        -------
+        data: tvm.relax.Var
+            The processed tensor.
+        """
+
+        name_prefix = name if quantizer._cache_processed else quantizer.to_tensor_id(name, consumer)
+        scale_tensor, zero_point = cls.get_quantize_cache(
+            quantizer, data, name, consumer, scale, axis, epsilon
+        )
+        expr = relax_op.dequantize(
+            data, scale_tensor, zero_point, axis, quantizer.find_tensor(name).dtype
+        )
+        return quantizer._block_builder.emit(expr, name_hint=name_prefix + "_dequantize")
+
+    @classmethod
+    def framework(cls):
+        return MSCFramework.TVM
+
+
+msc_utils.register_tool_method(TVMQuantizeMethod)
diff --git a/python/tvm/contrib/msc/framework/tvm/tools/quantize/quantizer.py
b/python/tvm/contrib/msc/framework/tvm/tools/quantize/quantizer.py
new file mode 100644
index 0000000000..d4680b9088
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/tvm/tools/quantize/quantizer.py
@@ -0,0 +1,167 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument
+"""tvm.contrib.msc.framework.tvm.tools.quantize.quantizer"""
+
+from typing import List, Union
+
+import tvm
+from tvm.contrib.msc.core.tools.tool import ToolType, ToolStrategy
+from tvm.contrib.msc.core.tools.quantize import BaseQuantizer
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
+
+
+class TVMQuantizerFactory(object):
+    """Quantizer factory for tvm"""
+
+    def create(self, base_cls: BaseQuantizer) -> BaseQuantizer:
+        """Create adaptive quantizer
+
+        Parameters
+        ----------
+        base_cls: BaseQuantizer
+            The base quantizer class
+
+        Returns
+        -------
+        quantizer_cls: BaseQuantizer
+            The quantizer class.
+        """
+
+        class Quantizer(base_cls):
+            """Adaptive quantizer for tvm"""
+
+            def _execute_before_build(self, block_builder: tvm.relax.BlockBuilder):
+                """Execute before model build
+
+                Parameters
+                ----------
+                block_builder: tvm.relax.BlockBuilder
+                    The block builder.
+                """
+
+                self._block_builder = block_builder
+                # Reset the record of tensors to expose for gathering in this build.
+                self._gather_tensors, self._gather_names = {}, []
+                super()._execute_before_build(block_builder)
+
+            def _execute_after_build(
+                self, output: Union[tvm.relax.Var, List[tvm.relax.DataflowVar]]
+            ) -> List[tvm.relax.Var]:
+                """Execute after model build
+
+                Before calibration, every tensor recorded in _gather_tensors is
+                appended to the model outputs so its runtime value can be
+                gathered in _execute_after_forward.
+
+                Parameters
+                ----------
+                output: var or list<var>
+                    The output var of the model.
+
+                Returns
+                -------
+                outputs: list<var>
+                    The modified outputs var.
+                """
+
+                if self._calibrated:
+                    return super()._execute_after_build(output)
+                # Fix a deterministic order; _execute_after_forward zips the extra
+                # outputs with this same list to map them back to tensor names.
+                self._gather_names = sorted(self._gather_tensors)
+                gather_tensors = [self._gather_tensors[n]["tensor"] for n in self._gather_names]
+                if isinstance(output, tvm.relax.Var):
+                    return super()._execute_after_build([output] + gather_tensors)
+                return super()._execute_after_build(output + gather_tensors)
+
+            def _execute_after_forward(
+                self, outputs: List[tvm.runtime.NDArray]
+            ) -> Union[tvm.runtime.NDArray, List[tvm.runtime.NDArray]]:
+                """Execute after model forward
+
+                Before calibration, the trailing len(_gather_names) outputs are
+                the gathered tensors appended in _execute_after_build; they are
+                passed to _gather_tensor once per consumer and stripped before
+                the remaining outputs are returned.
+
+                Parameters
+                ----------
+                outputs: list<tvm.runtime.NDArray>
+                    The output datas, including any gathered tensors.
+
+                Returns
+                -------
+                output: tvm.runtime.NDArray or list<tvm.runtime.NDArray>
+                    The base _execute_after_forward result on the model outputs;
+                    a single output is unwrapped from the list.
+                """
+
+                if self._calibrated:
+                    return super()._execute_after_forward(outputs)
+                output_num = len(outputs) - len(self._gather_names)
+                for data, name in zip(outputs[output_num:], self._gather_names):
+                    info = self._gather_tensors[name]
+                    for consumer in info["consumers"]:
+                        strategys = self._get_tensor_strategys(name, consumer)
+                        self._gather_tensor(data, name, consumer, strategys)
+                if output_num == 1:
+                    return super()._execute_after_forward(outputs[0])
+                return super()._execute_after_forward(outputs[:output_num])
+
+            def _process_tensor(
+                self,
+                tensor: tvm.relax.DataflowVar,
+                name: str,
+                consumer: str,
+                scope: str,
+                strategys: List[ToolStrategy],
+            ) -> tvm.relax.DataflowVar:
+                """Process tensor
+
+                Before calibration, weights are gathered from their data right
+                away, and other tensors are recorded (with every consumer) so
+                _execute_after_build can expose them as extra outputs. After
+                calibration the tensor is quantized.
+
+                Parameters
+                ----------
+                tensor: tvm.relax.DataflowVar
+                    Tensor in framework
+                name: str
+                    The name of the tensor.
+                consumer: str
+                    The name of the consumer.
+                scope: str
+                    The scope mark teacher| student| null.
+                strategys: list<ToolStrategy>
+                    The strategys for the tensor.
+
+                Returns
+                -------
+                tensor: tvm.relax.DataflowVar
+                    The processed tensor.
+                """
+
+                if not self._calibrated:
+                    if self.is_weight(name):
+                        # NOTE(review): this returns whatever _gather_tensor returns
+                        # for the weight *data*, not `tensor`; confirm the caller
+                        # accepts that in place of the relax var.
+                        return self._gather_tensor(self.get_data(name), name, consumer, strategys)
+                    if name not in self._gather_tensors:
+                        self._gather_tensors[name] = {
+                            "consumers": [consumer],
+                            "tensor": tensor,
+                        }
+                        self._gather_names.append(name)
+                    else:
+                        self._gather_tensors[name]["consumers"].append(consumer)
+                    return tensor
+                return self._quantize_tensor(tensor, name, consumer, strategys)
+
+            @classmethod
+            def framework(cls):
+                return MSCFramework.TVM
+
+        return Quantizer
+
+
+# Wrap every quantizer class registered for MSCFramework.MSC (tool_style="all")
+# into a TVM-specific subclass and register the result.
+factory = TVMQuantizerFactory()
+tools = msc_utils.get_registered_tool_cls(
+    MSCFramework.MSC, ToolType.QUANTIZER, tool_style="all"
+)
+for tool in tools.values():
+    msc_utils.register_tool_cls(factory.create(tool))
diff --git a/python/tvm/contrib/msc/pipeline/manager.py b/python/tvm/contrib/msc/pipeline/manager.py
index bbd6d452ad..8a37ef951f 100644
--- a/python/tvm/contrib/msc/pipeline/manager.py
+++ b/python/tvm/contrib/msc/pipeline/manager.py
@@ -378,6 +378,10 @@ class BaseManager(object):
if _tool_enabled(ToolType.PRUNER):
self._apply_tool(ToolType.PRUNER, stage_config)
+ # run quantize
+ if _tool_enabled(ToolType.QUANTIZER):
+ self._apply_tool(ToolType.QUANTIZER, stage_config)
+
# optimize and get the runner
msc_utils.time_stamp(MSCStage.OPTIMIZE)
return self._create_runner(
diff --git a/tests/python/contrib/test_msc/test_tools.py b/tests/python/contrib/test_msc/test_tools.py
index f396b81ea4..7e981d348b 100644
--- a/tests/python/contrib/test_msc/test_tools.py
+++ b/tests/python/contrib/test_msc/test_tools.py
@@ -77,7 +77,42 @@ def get_tool_config(tool_type):
"strategys": [{"method": "per_channel", "density": 0.8}],
}
elif tool_type == ToolType.QUANTIZER:
- raise NotImplementedError("Quantizer is not supported")
+ # pylint: disable=import-outside-toplevel
+ from tvm.contrib.msc.core.tools.quantize import QuantizeStage
+
+ config = {
+ "plan_file": "msc_quantizer.json",
+ "strategys": [
+ {
+ "method": "gather_maxmin",
+ "op_types": ["nn.conv2d", "msc.linear"],
+ "tensor_types": ["input", "output"],
+ "stages": [QuantizeStage.GATHER],
+ },
+ {
+ "method": "gather_max_per_channel",
+ "op_types": ["nn.conv2d", "msc.linear"],
+ "tensor_types": ["weight"],
+ "stages": [QuantizeStage.GATHER],
+ },
+ {
+ "method": "calibrate_maxmin",
+ "op_types": ["nn.conv2d", "msc.linear"],
+ "tensor_types": ["input", "output"],
+ "stages": [QuantizeStage.CALIBRATE],
+ },
+ {
+ "method": "quantize_normal",
+ "op_types": ["nn.conv2d", "msc.linear"],
+ "tensor_types": ["input", "weight"],
+ },
+ {
+ "method": "dequantize_normal",
+ "op_types": ["nn.conv2d", "msc.linear"],
+ "tensor_types": ["output"],
+ },
+ ],
+ }
elif tool_type == ToolType.TRACKER:
config = {
"plan_file": "msc_tracker.json",
@@ -183,7 +218,7 @@ def get_model_info(compile_type):
raise TypeError("Unexpected compile_type " + str(compile_type))
[email protected]("tool_type", [ToolType.PRUNER, ToolType.TRACKER])
[email protected]("tool_type", [ToolType.PRUNER, ToolType.QUANTIZER,
ToolType.TRACKER])
def test_tvm_tool(tool_type):
"""Test tools for tvm"""
@@ -194,7 +229,10 @@ def test_tvm_tool(tool_type):
@requires_tensorrt
[email protected]("tool_type", [ToolType.PRUNER, ToolType.TRACKER])
[email protected](
+ "tool_type",
+ [ToolType.PRUNER, ToolType.QUANTIZER, ToolType.TRACKER],
+)
def test_tensorrt_tool(tool_type):
"""Test tools for tensorrt"""