This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 8edfee8574 [Unity][MSC][M2.4] Add quantizer for quantize model (#16228)
8edfee8574 is described below

commit 8edfee85742da697957e9e33e95bce28727362ac
Author: Archermmt <[email protected]>
AuthorDate: Fri Dec 15 15:59:56 2023 +0800

    [Unity][MSC][M2.4] Add quantizer for quantize model (#16228)
    
    * add quantizer
    
    * add quantizer test
---
 python/tvm/contrib/msc/core/runtime/runner.py      |   8 +
 python/tvm/contrib/msc/core/tools/__init__.py      |   1 +
 .../tvm/tools => core/tools/quantize}/__init__.py  |   6 +-
 .../tvm/contrib/msc/core/tools/quantize/method.py  | 472 +++++++++++++++++++++
 .../contrib/msc/core/tools/quantize/quantizer.py   | 249 +++++++++++
 python/tvm/contrib/msc/core/utils/file.py          |   5 +
 python/tvm/contrib/msc/core/utils/info.py          |   4 +
 .../msc/framework/tensorflow/tools/__init__.py     |   1 +
 .../tools/quantize}/__init__.py                    |   5 +-
 .../tensorflow/tools/quantize/quantizer.py         |  55 +++
 .../msc/framework/tensorrt/codegen/sources.py      | 172 +++++++-
 .../msc/framework/tensorrt/runtime/runner.py       |  53 +++
 .../msc/framework/tensorrt/tools/__init__.py       |   1 +
 .../tools => tensorrt/tools/quantize}/__init__.py  |   6 +-
 .../framework/tensorrt/tools/quantize/method.py    | 149 +++++++
 .../framework/tensorrt/tools/quantize/quantizer.py | 366 ++++++++++++++++
 .../contrib/msc/framework/torch/tools/__init__.py  |   1 +
 .../tools => torch/tools/quantize}/__init__.py     |   6 +-
 .../msc/framework/torch/tools/quantize/method.py   | 237 +++++++++++
 .../framework/torch/tools/quantize/quantizer.py    |  55 +++
 .../contrib/msc/framework/tvm/tools/__init__.py    |   1 +
 .../tools => tvm/tools/quantize}/__init__.py       |   6 +-
 .../msc/framework/tvm/tools/quantize/method.py     | 204 +++++++++
 .../msc/framework/tvm/tools/quantize/quantizer.py  | 167 ++++++++
 python/tvm/contrib/msc/pipeline/manager.py         |   4 +
 tests/python/contrib/test_msc/test_tools.py        |  44 +-
 26 files changed, 2259 insertions(+), 19 deletions(-)

diff --git a/python/tvm/contrib/msc/core/runtime/runner.py 
b/python/tvm/contrib/msc/core/runtime/runner.py
index dcf24225fe..5228b06b10 100644
--- a/python/tvm/contrib/msc/core/runtime/runner.py
+++ b/python/tvm/contrib/msc/core/runtime/runner.py
@@ -414,6 +414,14 @@ class BaseRunner(object):
                     self.run(inputs, ret_type="native")
                     break
             plan = pruner.finalize()
+        elif tool_type == ToolType.QUANTIZER:
+            quantizer = self.get_tool(ToolType.QUANTIZER)
+            while not quantizer.calibrated:
+                assert data_loader, "data_loader should be given to plan prune"
+                for inputs in data_loader():
+                    self.run(inputs, ret_type="native")
+                quantizer.calibrate()
+            plan = quantizer.finalize()
         else:
             plan = self.get_tool(tool_type).finalize()
         assert plan, "Failed to create plan for {}".format(tool_type)
diff --git a/python/tvm/contrib/msc/core/tools/__init__.py 
b/python/tvm/contrib/msc/core/tools/__init__.py
index 0524e4c823..e97771cf6c 100644
--- a/python/tvm/contrib/msc/core/tools/__init__.py
+++ b/python/tvm/contrib/msc/core/tools/__init__.py
@@ -19,4 +19,5 @@
 from .tool import *
 from .execute import *
 from .prune import *
+from .quantize import *
 from .track import *
diff --git a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py 
b/python/tvm/contrib/msc/core/tools/quantize/__init__.py
similarity index 89%
copy from python/tvm/contrib/msc/framework/tvm/tools/__init__.py
copy to python/tvm/contrib/msc/core/tools/quantize/__init__.py
index 226ae3102d..1aad17c055 100644
--- a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py
+++ b/python/tvm/contrib/msc/core/tools/quantize/__init__.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""tvm.contrib.msc.framework.tvm.tools"""
+"""tvm.contrib.msc.core.tools.quantize"""
 
-from .prune import *
-from .track import *
+from .quantizer import *
+from .method import *
diff --git a/python/tvm/contrib/msc/core/tools/quantize/method.py 
b/python/tvm/contrib/msc/core/tools/quantize/method.py
new file mode 100644
index 0000000000..9701858267
--- /dev/null
+++ b/python/tvm/contrib/msc/core/tools/quantize/method.py
@@ -0,0 +1,472 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument
+"""tvm.contrib.msc.core.tools.quantize.method"""
+
+from typing import Union, Any
+import numpy as np
+
+from tvm.contrib.msc.core.tools.tool import ToolType, BaseTool
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
+
+
class QuantizeMethod(object):
    """Default quantize methods for the MSC framework, implemented with numpy.

    ``gather_*`` and ``calibrate_*`` methods return the plan dict of one tensor;
    a plan holding ``"calibrated": True`` is final. ``quantize_*`` and
    ``dequantize_*`` methods are called with the final plan unpacked as keyword
    arguments.
    """

    @classmethod
    def amplify_data(
        cls,
        data: np.ndarray,
        scale: float,
        min_val: float,
        max_val: float,
        rounding: str = "round",
    ) -> np.ndarray:
        """Scale, round and clip the data

        Parameters
        ----------
        data: np.ndarray
            The source data.
        scale: float or np.ndarray
            The scale factor, a scalar or a tensor broadcastable to ``data``.
        min_val: float
            The lower clip bound.
        max_val: float
            The upper clip bound.
        rounding: str
            The round method: "null" (no rounding), "floor", "ceil",
            "round" (numpy round, half to even), "trunc" or "logic_round"
            (values exactly halfway round toward +inf).

        Returns
        -------
        data: np.ndarray
            The processed data.

        Raises
        ------
        TypeError
            If ``rounding`` is not supported.
        """

        if rounding == "null":
            return np.clip(data * scale, min_val, max_val)
        if rounding == "floor":
            return np.clip(np.floor(data * scale), min_val, max_val)
        if rounding == "ceil":
            return np.clip(np.ceil(data * scale), min_val, max_val)
        if rounding == "round":
            return np.clip(np.round(data * scale), min_val, max_val)
        if rounding == "trunc":
            return np.clip(np.trunc(data * scale), min_val, max_val)
        if rounding == "logic_round":
            data = np.clip(data * scale, min_val, max_val)
            # negative values exactly halfway (e.g. -2.5) are rounded up (to -2)
            # separately, then added back after the generic half-up rounding
            negative_ceil = np.where(
                np.logical_and(data < 0, (data - np.floor(data)) == 0.5), np.ceil(data), 0
            )
            data = np.where(np.logical_and(data < 0, (data - np.floor(data)) == 0.5), 0, data)
            data = np.where((data - np.floor(data)) >= 0.5, np.ceil(data), data)
            data = np.where((data - np.floor(data)) < 0.5, np.floor(data), data)
            return data + negative_ceil
        raise TypeError("Unexpected rounding " + str(rounding))

    @classmethod
    def get_scale_tensor(
        cls,
        data: Any,
        scale: Union[float, list],
        axis: int = -1,
        epsilon: float = 1.0 / (1 << 24),
        expand_dims: bool = True,
    ) -> Union[float, np.ndarray]:
        """Get the scale tensor

        Parameters
        ----------
        data: array_like
            The source data, only used for its dtype and shape.
        scale: float or list<float>
            The scale factor, or one scale per channel along ``axis``.
        axis: int
            The channel axis for per-channel scales, negative values count
            from the last dimension.
        epsilon: float
            Scales at or below epsilon are replaced by 0.
        expand_dims: bool
            Whether to reshape per-channel scales so they broadcast along ``axis``.

        Returns
        -------
        scale_tensor: float or np.ndarray
            The scalar scale, or the per-channel scale tensor.
        """

        data = msc_utils.cast_array(data)
        if not isinstance(scale, list):
            return 0 if scale <= epsilon else scale
        scale_tensor = np.array(scale).astype(data.dtype)
        if expand_dims:
            # normalize the axis so the default axis=-1 selects the last dimension
            ref_axis = axis + data.ndim if axis < 0 else axis
            scale_shape = [
                scale_tensor.size if idx == ref_axis else 1 for idx in range(data.ndim)
            ]
            scale_tensor = scale_tensor.reshape(scale_shape)
        scale_tensor[scale_tensor <= epsilon] = 0
        return scale_tensor

    @classmethod
    def gather_maxmin(
        cls,
        quantizer: BaseTool,
        data: np.ndarray,
        name: str,
        consumer: str,
        plan: dict,
        nbits: int = 8,
    ) -> dict:
        """Gather the data by max/min

        Appends the abs-max, max and min of ``data`` to the lists held in the plan.

        Parameters
        ----------
        quantizer: BaseQuantizer
            The quantizer
        data: np.ndarray
            The source data.
        name: str
            The name of the tensor.
        consumer: str
            The name of the consumer.
        plan: dict
            The pre-calibrated plan.
        nbits: int
            The number bits for quantize.

        Returns
        -------
        plan: dict
            The plan of the tensor, not calibrated yet.
        """

        abs_max_list = plan.get("abs_max_list", [])
        abs_max_list.append(float(np.abs(data).max()))
        max_list = plan.get("max_list", [])
        max_list.append(float(data.max()))
        min_list = plan.get("min_list", [])
        min_list.append(float(data.min()))
        return {
            "abs_max_list": abs_max_list,
            "max_list": max_list,
            "min_list": min_list,
            "calibrated": False,
        }

    @classmethod
    def gather_kl_divergence(
        cls,
        quantizer: BaseTool,
        data: np.ndarray,
        name: str,
        consumer: str,
        plan: dict,
        nbits: int = 8,
        bins: int = 4096,
    ) -> dict:
        """Gather the data by kl_divergence

        Until the plan holds "abs_max", max/min statistics are gathered; after
        that, histograms over [-abs_max, abs_max] are accumulated.

        Parameters
        ----------
        quantizer: BaseQuantizer
            The quantizer
        data: np.ndarray
            The source data.
        name: str
            The name of the tensor.
        consumer: str
            The name of the consumer.
        plan: dict
            The pre-calibrated plan.
        nbits: int
            The number bits for quantize.
        bins: int
            The number bins.

        Returns
        -------
        plan: dict
            The plan of the tensor.
        """

        if not plan or "abs_max" not in plan:
            return cls.gather_maxmin(quantizer, data, name, consumer, plan, nbits)
        hist, edge = np.histogram(data, bins=bins, range=[-plan["abs_max"], plan["abs_max"]])
        hist_list = plan.get("hist_list", [])
        # spread ``plan`` first so the accumulated histograms are not overridden
        return {**plan, "hist_list": hist_list + [hist], "edge": edge}

    @classmethod
    def gather_max_per_channel(
        cls,
        quantizer: BaseTool,
        data: np.ndarray,
        name: str,
        consumer: str,
        plan: dict,
        nbits: int = 8,
        channel: str = "O",
        auto_unsign: bool = False,
    ) -> dict:
        """Gather the data by max_per_channel

        Computes one scale per channel in a single pass, so the returned plan
        is already calibrated.

        Parameters
        ----------
        quantizer: BaseQuantizer
            The quantizer
        data: np.ndarray
            The source data.
        name: str
            The name of the tensor.
        consumer: str
            The name of the consumer.
        plan: dict
            The pre-calibrated plan.
        nbits: int
            The number bits for quantize.
        channel: str
            The layout char of the channel axis, resolved via the tensor layout.
        auto_unsign: bool
            Whether to use auto unsign.

        Returns
        -------
        plan: dict
            The plan of the tensor.
        """

        weight = quantizer.find_tensor(name)
        axis = weight.layout_of(channel)
        channel_datas = np.split(data, data.shape[axis], axis)
        channel_max = [float(np.abs(d).max()) for d in channel_datas]
        sign = data.min() < 0 if auto_unsign else True
        valid_range = 2 ** (nbits - int(sign)) - 1
        # all-zero channels get scale 0 instead of dividing by zero
        scale = [valid_range / m if m > 0 else 0.0 for m in channel_max]
        return {"scale": scale, "sign": sign, "axis": axis, "calibrated": True}

    @classmethod
    def calibrate_maxmin(
        cls,
        quantizer: BaseTool,
        name: str,
        consumer: str,
        plan: dict,
        nbits: int = 8,
        auto_unsign: bool = False,
    ) -> dict:
        """Calibrate the data by max/min

        Parameters
        ----------
        quantizer: BaseQuantizer
            The quantizer
        name: str
            The name of the tensor.
        consumer: str
            The name of the consumer.
        plan: dict
            The pre-calibrated plan, as produced by gather_maxmin.
        nbits: int
            The number bits for quantize.
        auto_unsign: bool
            Whether to use auto unsign.

        Returns
        -------
        plan: dict
            The calibrated plan of the tensor.
        """

        if auto_unsign:
            # gather_maxmin records "min_list"; accept a reduced "min" as well
            min_val = plan["min"] if "min" in plan else float(np.array(plan["min_list"]).min())
            sign = min_val < 0
        else:
            sign = True
        valid_range = 2 ** (nbits - int(sign)) - 1
        abs_max = float(np.array(plan["abs_max_list"]).max())
        # an all-zero tensor gets scale 0 instead of dividing by zero
        scale = valid_range / abs_max if abs_max > 0 else 0.0
        return {"scale": scale, "sign": sign, "calibrated": True}

    @classmethod
    def calibrate_kl_divergence(
        cls,
        quantizer: BaseTool,
        name: str,
        consumer: str,
        plan: dict,
        nbits: int = 8,
        bins: int = 4096,
        auto_unsign: bool = False,
    ) -> dict:
        """Calibrate the data by kl_divergence

        The first call reduces the max/min lists and returns an uncalibrated
        plan, which triggers another gather pass for histograms. The second
        call finds the scale by KL minimization over the summed histogram.

        Parameters
        ----------
        quantizer: BaseQuantizer
            The quantizer
        name: str
            The name of the tensor.
        consumer: str
            The name of the consumer.
        plan: dict
            The pre-calibrated plan.
        nbits: int
            The number bits for quantize.
        bins: int
            The number bins.
        auto_unsign: bool
            Whether to use auto unsign.

        Returns
        -------
        plan: dict
            The plan of the tensor.
        """

        # pylint: disable=import-outside-toplevel
        import ctypes
        from tvm.relay import quantize as _quantize

        if plan and "abs_max_list" in plan:
            return {
                "abs_max": float(np.array(plan["abs_max_list"]).max()),
                "max": float(np.array(plan["max_list"]).max()),
                "min": float(np.array(plan["min_list"]).min()),
                "calibrated": False,
            }

        def get_pointer(arr, ctypes_type):
            ptr = arr.ctypes.data_as(ctypes.POINTER(ctypes_type))
            return ctypes.cast(ptr, ctypes.c_void_p)

        sign = plan["min"] < 0 if auto_unsign else True
        hist = np.array(plan["hist_list"]).sum(axis=0)
        hist_ptr = get_pointer(hist.astype(np.int64), ctypes.c_int64)
        edge_ptr = get_pointer(plan["edge"].astype(np.float32), ctypes.c_float)
        valid_range = 2 ** (nbits - int(sign)) - 1
        # the returned value is treated as the clipping threshold
        scale = _quantize._quantize.FindScaleByKLMinimization(hist_ptr, edge_ptr, bins, valid_range)
        return {"scale": valid_range / scale, "sign": sign, "calibrated": True}

    @classmethod
    def quantize_normal(
        cls,
        quantizer: BaseTool,
        data: np.ndarray,
        name: str,
        consumer: str,
        scale: float,
        nbits: int = 8,
        axis: int = -1,
        sign: bool = True,
        rounding: str = "round",
        epsilon: float = 1.0 / (1 << 24),
    ) -> np.ndarray:
        """Fake-quantize the data: amplify, round and clip, then scale back

        Parameters
        ----------
        quantizer: BaseQuantizer
            The quantizer
        data: np.ndarray
            The source data.
        name: str
            The name of the tensor.
        consumer: str
            The name of the consumer.
        scale: float or list<float>
            The scale factor, or per-channel scales along ``axis``.
        nbits: int
            The number bits for quantize.
        axis: int
            The axis.
        sign: bool
            Whether to use sign.
        rounding: str
            The rounding method.
        epsilon: float
            The epsilon for get scale.

        Returns
        -------
        data: array like
            The processed tensor.
        """

        valid_range = 2 ** (nbits - int(sign)) - 1
        min_val = -valid_range if sign else 0
        scale_tensor = quantizer._get_tensor_cache(name, consumer, "scale_tensor")
        if scale_tensor is None:
            scale_tensor = cls.get_scale_tensor(data, scale, axis, epsilon)
            quantizer._save_tensor_cache(name, consumer, "scale_tensor", scale_tensor)
        data = cls.amplify_data(data, scale_tensor, min_val, valid_range, rounding)
        if isinstance(scale_tensor, np.ndarray):
            # per-channel: divide by the broadcast tensor, not the raw list; zeroed
            # scales already produced zeros, so divide those entries by 1
            divisor = scale_tensor.copy()
            divisor[divisor == 0] = 1
            return data / divisor
        if scale_tensor == 0:
            # scale at or below epsilon: the amplified data is already zero
            return data
        return data / scale

    @classmethod
    def dequantize_normal(
        cls,
        quantizer: BaseTool,
        data: np.ndarray,
        name: str,
        consumer: str,
        scale: float = -1.0,
        nbits: int = 8,
        axis: int = -1,
        sign: bool = True,
        rounding: str = "round",
        epsilon: float = 1.0 / (1 << 24),
    ) -> np.ndarray:
        """Dequantize the data, a no-op for the default method

        Parameters
        ----------
        quantizer: BaseQuantizer
            The quantizer
        data: np.ndarray
            The source data.
        name: str
            The name of the tensor.
        consumer: str
            The name of the consumer.
        scale: float
            The scale factor
        nbits: int
            The number bits for quantize.
        axis: int
            The axis.
        sign: bool
            Whether to use sign.
        rounding: str
            The rounding method.
        epsilon: float
            The epsilon for get scale.

        Returns
        -------
        data: array like
            The source data, unchanged.
        """

        return data

    @classmethod
    def framework(cls):
        """The framework these methods are registered for."""
        return MSCFramework.MSC

    @classmethod
    def tool_type(cls):
        """The tool type these methods are registered for."""
        return ToolType.QUANTIZER


msc_utils.register_tool_method(QuantizeMethod)
diff --git a/python/tvm/contrib/msc/core/tools/quantize/quantizer.py 
b/python/tvm/contrib/msc/core/tools/quantize/quantizer.py
new file mode 100644
index 0000000000..bee8e6fa42
--- /dev/null
+++ b/python/tvm/contrib/msc/core/tools/quantize/quantizer.py
@@ -0,0 +1,249 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.core.tools.quantize.quantizer"""
+
+from typing import List, Dict, Any
+
+from tvm.contrib.msc.core.tools.tool import ToolType, BaseTool, ToolStrategy
+from tvm.contrib.msc.core import utils as msc_utils
+
+
class QuantizeStage:
    """Stage names a quantizer switches between before it is calibrated."""

    CALIBRATE = "calibrate"
    GATHER = "gather"
+
+
class BaseQuantizer(BaseTool):
    """Base quantizer for all

    Without a loaded plan the quantizer alternates between the gather and
    calibrate stages until every tensor plan is calibrated, then switches to
    the quantize stage and applies the calibrated plan to tensors.
    """

    def setup(self) -> dict:
        """Setup the tool

        Returns
        -------
        info: dict
            The setup info.
        """

        # always defined, so calibrate() does not depend on which branch ran
        self._calibrate_plan = {}
        if self._plan:
            self._calibrated = True
            self.change_stage(msc_utils.MSCStage.QUANTIZE)
        else:
            self._calibrated = False
            self.change_stage(QuantizeStage.GATHER)
        return super().setup()

    def calibrate(self) -> dict:
        """Calibrate the datas

        Runs the calibrate strategy on every uncalibrated tensor plan. If any
        plan is still uncalibrated afterwards, the quantizer goes back to the
        gather stage; otherwise the plans are stored and the quantize stage starts.

        Returns
        -------
        plan: dict
            The calibrated plan.
        """

        new_plan = {}
        self.change_stage(QuantizeStage.CALIBRATE)
        for tensor_id, plan in self._calibrate_plan.items():
            if plan.get("calibrated", False):
                new_plan[tensor_id] = plan
                continue
            name, consumer = self.from_tensor_id(tensor_id)
            strategy = self._get_tensor_strategy(name, consumer)
            new_plan[tensor_id] = strategy(self, name, consumer, plan)
        if any(not plan.get("calibrated", False) for plan in new_plan.values()):
            self._calibrate_plan = new_plan
            self.change_stage(QuantizeStage.GATHER)
        else:
            self._calibrated = True
            for tensor_id, plan in new_plan.items():
                # drop only the bookkeeping flag; the rest become quantize kwargs
                self._plan[tensor_id] = {k: v for k, v in plan.items() if k != "calibrated"}
            self.change_stage(msc_utils.MSCStage.QUANTIZE)
        self._forward_cnt = 0
        return new_plan

    def _parse_strategys(self, strategy_list: dict) -> Dict[str, ToolStrategy]:
        """Parse the strategy to get valid strategy

        Strategies without "stages" default to the quantize stage.

        Parameters
        ----------
        strategy_list: dict
            The given strategy

        Returns
        -------
        strategys: dict<str, ToolStrategy>
            The parsed strategy.
        """

        def _update_stages(strategy):
            if "stages" not in strategy:
                strategy["stages"] = [msc_utils.MSCStage.QUANTIZE]
            return strategy

        return super()._parse_strategys([_update_stages(s) for s in strategy_list])

    def _check_tensor(self, name: str, consumer: str) -> bool:
        """Check if the tensor should be processed

        Parameters
        ----------
        name: str
            The name of the tensor.
        consumer: str
            The name of the consumer.

        Returns
        -------
        valid: bool
            Whether to process the tensor; False when it has no strategy or
            any strategy sets nbits to -1.
        """

        strategys = self._get_tensor_strategys(name, consumer)
        if not strategys:
            return False
        if any(s.get_config().get("nbits", 8) == -1 for s in strategys):
            return False
        return True

    def _process_tensor(
        self, tensor: Any, name: str, consumer: str, scope: str, strategys: List[ToolStrategy]
    ) -> Any:
        """Process tensor

        Parameters
        ----------
        tensor: Any
            Tensor in framework
        name: str
            The name of the tensor.
        consumer: str
            The name of the consumer.
        scope: str
            The scope mark teacher| student| null.
        strategys: list<ToolStrategy>
            The strategys for the tensor.

        Returns
        -------
        tensor: Any
            The processed tensor.
        """

        if not self._calibrated:
            return self._gather_tensor(tensor, name, consumer, strategys)
        return self._quantize_tensor(tensor, name, consumer, strategys)

    def _gather_tensor(
        self, tensor: Any, name: str, consumer: str, strategys: List[ToolStrategy]
    ) -> Any:
        """Gather tensor datas

        Parameters
        ----------
        tensor: Any
            Tensor in framework
        name: str
            The name of the tensor.
        consumer: str
            The name of the consumer.
        strategys: list<ToolStrategy>
            The strategys for the tensor.

        Returns
        -------
        tensor: Any
            The tensor, unchanged.
        """

        assert len(strategys) == 1, "gather should only has 1 strategy, get " + str(strategys)
        tensor_id = self.to_tensor_id(name, consumer)
        plan = self._calibrate_plan.get(tensor_id, {})
        if plan.get("calibrated", False):
            return tensor
        self._calibrate_plan[tensor_id] = strategys[0](self, tensor, name, consumer, plan)
        return tensor

    def _quantize_tensor(
        self, tensor: Any, name: str, consumer: str, strategys: List[ToolStrategy]
    ) -> Any:
        """Quantize tensor

        Applies every strategy in order with the tensor's plan as kwargs.

        Parameters
        ----------
        tensor: Any
            Tensor in framework
        name: str
            The name of the tensor.
        consumer: str
            The name of the consumer.
        strategys: list<ToolStrategy>
            The strategys for the tensor.

        Returns
        -------
        tensor: Any
            The processed tensor.
        """

        tensor_id = self.to_tensor_id(name, consumer)
        for strategy in strategys:
            tensor = strategy(self, tensor, name, consumer, **self._plan[tensor_id])
        return tensor

    def create_tasks(self, **kwargs) -> List[dict]:
        """Create tasks for gym

        Parameters
        ----------
        kwargs: dict
           The kwargs for create tasks; weights are skipped unless
           "quantize_weights" is true.

        Returns
        -------
        tasks: list<dict>
            The tasks.
        """

        tasks, recorded = [], set()
        for tensor_id, plan in self._plan.items():
            name, _ = self.from_tensor_id(tensor_id)
            if self.is_weight(name) and not kwargs.get("quantize_weights", False):
                continue
            if name not in recorded:
                tasks.append({"name": tensor_id, **plan})
                if self._cache_processed:
                    recorded.add(name)
        return tasks

    @property
    def calibrated(self):
        """Whether all tensor plans are calibrated."""
        return self._calibrated

    @classmethod
    def tool_type(cls):
        return ToolType.QUANTIZER
+
+
class DefaultQuantizer(BaseQuantizer):
    """Quantizer registered under the "default" tool style."""

    @classmethod
    def tool_style(cls):
        style_name = "default"
        return style_name


msc_utils.register_tool_cls(DefaultQuantizer)
diff --git a/python/tvm/contrib/msc/core/utils/file.py 
b/python/tvm/contrib/msc/core/utils/file.py
index 146cfaf504..446efd4724 100644
--- a/python/tvm/contrib/msc/core/utils/file.py
+++ b/python/tvm/contrib/msc/core/utils/file.py
@@ -72,6 +72,8 @@ class MSCDirectory(object):
         return "{}(Cleanup: {}): {} Files".format(self._path, self._cleanup, 
len(self.listdir()))
 
     def __enter__(self):
+        if not os.path.isdir(self._path):
+            os.mkdir(self._path)
         os.chdir(self._path)
         return self
 
@@ -105,6 +107,9 @@ class MSCDirectory(object):
         """
 
         file_path = self.relpath(name)
+        base_dir = os.path.dirname(name)
+        if base_dir and not os.path.isdir(base_dir):
+            os.makedirs(base_dir)
         with open(file_path, "w") as f:
             f.write(contains)
         return file_path
diff --git a/python/tvm/contrib/msc/core/utils/info.py 
b/python/tvm/contrib/msc/core/utils/info.py
index 440789f856..d1b5cd1a26 100644
--- a/python/tvm/contrib/msc/core/utils/info.py
+++ b/python/tvm/contrib/msc/core/utils/info.py
@@ -234,6 +234,10 @@ def load_dict(str_dict: str, flavor: str = "json") -> dict:
             dict_obj = json.load(f)
     elif isinstance(str_dict, str):
         dict_obj = json.loads(str_dict)
+    elif isinstance(str_dict, dict):
+        dict_obj = copy_dict(str_dict)
+    else:
+        raise Exception("Unexpected str_dict {}({})".format(str_dict, 
type(str_dict)))
     assert flavor == "json", "Unexpected flavor for load_dict: " + str(flavor)
     return dict_obj
 
diff --git a/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py 
b/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py
index d25cfd4e67..e5ebe50956 100644
--- a/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py
+++ b/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py
@@ -17,4 +17,5 @@
 """tvm.contrib.msc.framework.tensorflow.tools"""
 
 from .prune import *
+from .quantize import *
 from .track import *
diff --git a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py 
b/python/tvm/contrib/msc/framework/tensorflow/tools/quantize/__init__.py
similarity index 90%
copy from python/tvm/contrib/msc/framework/tvm/tools/__init__.py
copy to python/tvm/contrib/msc/framework/tensorflow/tools/quantize/__init__.py
index 226ae3102d..ed458ef838 100644
--- a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py
+++ b/python/tvm/contrib/msc/framework/tensorflow/tools/quantize/__init__.py
@@ -14,7 +14,6 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""tvm.contrib.msc.framework.tvm.tools"""
+"""tvm.contrib.msc.framework.tensorflow.tools.quantize"""
 
-from .prune import *
-from .track import *
+from .quantizer import *
diff --git 
a/python/tvm/contrib/msc/framework/tensorflow/tools/quantize/quantizer.py 
b/python/tvm/contrib/msc/framework/tensorflow/tools/quantize/quantizer.py
new file mode 100644
index 0000000000..dd6f2aac38
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/tensorflow/tools/quantize/quantizer.py
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.framework.tensorflow.tools.quantize.quantizer"""
+
+from tvm.contrib.msc.core.tools.tool import ToolType
+from tvm.contrib.msc.core.tools.quantize import BaseQuantizer
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
+
+
class TensorflowQuantizerFactory(object):
    """Derive tensorflow quantizer classes from the MSC quantizer classes."""

    def create(self, base_cls: BaseQuantizer) -> BaseQuantizer:
        """Create adaptive quantizer

        Parameters
        ----------
        base_cls: BaseQuantizer
            The base quantizer class

        Returns
        -------
        quantizer_cls: BaseQuantizer
            A subclass of ``base_cls`` whose framework is tensorflow.
        """

        class Quantizer(base_cls):
            """Adaptive quantizer for tensorflow"""

            @classmethod
            def framework(cls):
                target = MSCFramework.TENSORFLOW
                return target

        adapted_cls = Quantizer
        return adapted_cls
+
+
# Wrap every quantizer class registered for the MSC framework (all tool styles)
# into a tensorflow subclass and register that subclass as well.
factory = TensorflowQuantizerFactory()
tools = msc_utils.get_registered_tool_cls(MSCFramework.MSC, ToolType.QUANTIZER, tool_style="all")
for tool in tools.values():
    msc_utils.register_tool_cls(factory.create(tool))
diff --git a/python/tvm/contrib/msc/framework/tensorrt/codegen/sources.py 
b/python/tvm/contrib/msc/framework/tensorrt/codegen/sources.py
index b6497e9258..a5df42f78b 100644
--- a/python/tvm/contrib/msc/framework/tensorrt/codegen/sources.py
+++ b/python/tvm/contrib/msc/framework/tensorrt/codegen/sources.py
@@ -302,6 +302,171 @@ bool TRTUtils::DeserializeEngineFromFile(const 
std::string& file,
 """
 
 
+def get_trt_quantize_h_code():
+    """Create trt_quantize header file codes
+
+    Returns
+    -------
+    source: str
+        The trt_quantize header source.
+    """
+
+    return """#ifndef TVM_CONTRIB_MSC_UTILS_TRT_QUANTIZE_H_
+#define TVM_CONTRIB_MSC_UTILS_TRT_QUANTIZE_H_
+
+#include <cassert>
+#include <fstream>
+#include <iterator>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "NvInfer.h"
+#include "base.h"
+#include "trt_common.h"
+
+namespace tvm {
+namespace contrib {
+namespace msc {
+
+using namespace nvinfer1;
+
+class CalibrateHelper {
+ public:
+  CalibrateHelper(const std::string& range_file, const std::string& folder, 
int max_size = -1);
+
+  ~CalibrateHelper() {
+    for (const auto& buffer : cpu_buffers_) {
+      free(buffer);
+    }
+    for (const auto& buffer : gpu_buffers_) {
+      CHECK(cudaFree(buffer));
+    }
+  }
+
+  bool GetBatch(void* bindings[], const char* names[], int nbBindings);
+
+  const void* ReadCache(size_t& length);
+
+  void WriteCache(const void* cache, size_t length);
+
+ private:
+  std::unique_ptr<DatasetReader> reader_;
+  std::string range_file_;
+  std::vector<char> cache_;
+  std::vector<void*> cpu_buffers_;
+  std::vector<void*> gpu_buffers_;
+};
+
+#define CALIBRATE_MEMBERS(Calibrator)                                          
             \\
+ public:                                                                       
             \\
+  Calibrator(const std::string& range_file, const std::string& folder, int 
max_size = -1) { \\
+    helper_.reset(new CalibrateHelper(range_file, folder, max_size));          
             \\
+  }                                                                            
             \\
+                                                                               
             \\
+  virtual ~Calibrator() {}                                                     
             \\
+                                                                               
             \\
+  int getBatchSize() const noexcept override { return 1; }                     
             \\
+                                                                               
             \\
+  bool getBatch(void* bindings[], const char* names[], int nbBindings) 
noexcept override {  \\
+    return helper_->GetBatch(bindings, names, nbBindings);                     
             \\
+  }                                                                            
             \\
+                                                                               
             \\
+  const void* readCalibrationCache(size_t& length) noexcept override {         
             \\
+    return helper_->ReadCache(length);                                         
             \\
+  }                                                                            
             \\
+                                                                               
             \\
+  void writeCalibrationCache(const void* cache, size_t length) noexcept 
override {          \\
+    return helper_->WriteCache(cache, length);                                 
             \\
+  }                                                                            
             \\
+                                                                               
             \\
+ private:                                                                      
             \\
+  std::unique_ptr<CalibrateHelper> helper_;
+
+class MSCInt8EntropyCalibrator : public IInt8EntropyCalibrator {
+  CALIBRATE_MEMBERS(MSCInt8EntropyCalibrator)
+};
+
+class MSCInt8EntropyCalibrator2 : public IInt8EntropyCalibrator2 {
+  CALIBRATE_MEMBERS(MSCInt8EntropyCalibrator2)
+};
+
+}  // namespace msc
+}  // namespace contrib
+}  // namespace tvm
+
+#endif  // TVM_CONTRIB_MSC_UTILS_TRT_QUANTIZE_H_
+"""
+
+
+def get_trt_quantize_cc_code():
+    """Create trt_quantize cc file codes
+
+    Returns
+    -------
+    source: str
+        The trt_quantize cc source.
+    """
+
+    return """#include "trt_quantize.h"
+
+namespace tvm {
+namespace contrib {
+namespace msc {
+
+using namespace nvinfer1;
+
+CalibrateHelper::CalibrateHelper(const std::string& range_file, const 
std::string& folder,
+                                 int max_size) {
+  range_file_ = range_file;
+  reader_.reset(new DatasetReader(folder, max_size));
+  const auto& tensor_names = reader_->GetTensorNames();
+  cpu_buffers_.resize(tensor_names.size());
+  gpu_buffers_.resize(tensor_names.size());
+  for (size_t i = 0; i < tensor_names.size(); i++) {
+    size_t tensor_size = reader_->GetTensorSize(tensor_names[i]);
+    cpu_buffers_[i] = malloc(tensor_size);
+    CHECK(cudaMalloc(&gpu_buffers_[i], tensor_size));
+  }
+}
+
+bool CalibrateHelper::GetBatch(void* bindings[], const char* names[], int 
nbBindings) {
+  if (!reader_->ReadNext(cpu_buffers_.data())) {
+    return false;
+  }
+  for (size_t i = 0; i < nbBindings; i++) {
+    CHECK(cudaMemcpy(gpu_buffers_[i], cpu_buffers_[i], 
reader_->GetTensorSize(names[i]),
+                     cudaMemcpyHostToDevice));
+    bindings[i] = gpu_buffers_[i];
+  }
+  return true;
+}
+
+const void* CalibrateHelper::ReadCache(size_t& length) {
+  cache_.clear();
+  std::ifstream in_file(range_file_, std::ifstream::binary);
+  if (!in_file.is_open()) {
+    return nullptr;
+  }
+  in_file >> std::noskipws;
+  std::copy(std::istream_iterator<char>(in_file), 
std::istream_iterator<char>(),
+            std::back_inserter(cache_));
+  length = cache_.size();
+  return length > 0 ? &cache_[0] : nullptr;
+}
+
+void CalibrateHelper::WriteCache(const void* cache, size_t length) {
+  std::ofstream output(range_file_, std::ios::binary);
+  output.write(reinterpret_cast<const char*>(cache), length);
+}
+
+}  // namespace msc
+}  // namespace contrib
+}  // namespace tvm
+"""
+
+
 def get_trt_sources() -> Dict[str, str]:
     """Create trt sources for cpp codegen
 
@@ -313,6 +478,11 @@ def get_trt_sources() -> Dict[str, str]:
 
     sources = get_base_sources()
     sources.update(
-        {"trt_common.h": get_trt_common_h_code(), "trt_common.cc": 
get_trt_common_cc_code()}
+        {
+            "trt_common.h": get_trt_common_h_code(),
+            "trt_common.cc": get_trt_common_cc_code(),
+            "trt_quantize.h": get_trt_quantize_h_code(),
+            "trt_quantize.cc": get_trt_quantize_cc_code(),
+        }
     )
     return sources
diff --git a/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py 
b/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py
index 15a42b2cf9..c66f8d1450 100644
--- a/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py
+++ b/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py
@@ -17,9 +17,14 @@
 # pylint: disable=unused-import
 """tvm.contrib.msc.framework.tensorrt.runtime.runner"""
 
+from typing import Any, List, Dict
+
 import tvm
+from tvm.contrib.msc.core.ir import MSCGraph
 from tvm.contrib.msc.core.runtime import BYOCRunner
+from tvm.contrib.msc.core.tools import ToolType
 from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
 from tvm.contrib.msc.framework.tensorrt.frontend import (
     partition_for_tensorrt,
     transform_for_tensorrt,
@@ -44,6 +49,54 @@ class TensorRTRunner(BYOCRunner):
             self._device = "cuda"
         return super().setup()
 
+    def apply_tool(self, tool_type: str, data_loader: Any = None) -> dict:
+        """Execute tool and get plan
+
+        Parameters
+        -------
+        tool_type: str
+            The tool type, should be in ToolType
+        data_loader:
+            The data loader
+        """
+
+        assert tool_type in self._tools, "Can not find tool " + str(tool_type)
+        if tool_type == ToolType.QUANTIZER:
+            quantizer = self.get_tool(ToolType.QUANTIZER)
+            assert data_loader, "data_loader should be given to plan prune"
+            for inputs in data_loader():
+                self.run(inputs)
+            self._generate_model()
+            quantizer.calibrate()
+            assert quantizer.calibrated, "Failed to calibrate the tenosrrt 
quantizer"
+        return super().apply_tool(tool_type, data_loader)
+
+    def _generate_model(
+        self, graphs: List[MSCGraph] = None, weights: List[Dict[str, 
tvm.nd.array]] = None
+    ) -> Any:
+        """Codegen the model according to framework
+
+        Parameters
+        -------
+        graphs: list<MSCgraph>
+            The msc graphs.
+        weights: list<dict<str, tvm.nd.array>>
+            The weights
+
+        Returns
+        -------
+        model: Any
+            The meta model
+        """
+
+        codegen = self._generate_config.get("codegen")
+        if not isinstance(codegen, (list, tuple)):
+            self._generate_config["codegen"] = [msc_utils.copy_dict(codegen)] 
* len(self._graphs)
+        for tool in self.get_tools():
+            self._generate_config = tool.config_generate(self._generate_config)
+
+        return super()._generate_model(graphs, weights)
+
     @classmethod
     def target_transform(cls, mod: tvm.IRModule):
         """Transform the mod by target.
diff --git a/python/tvm/contrib/msc/framework/tensorrt/tools/__init__.py 
b/python/tvm/contrib/msc/framework/tensorrt/tools/__init__.py
index ecc82bc40f..c010a42004 100644
--- a/python/tvm/contrib/msc/framework/tensorrt/tools/__init__.py
+++ b/python/tvm/contrib/msc/framework/tensorrt/tools/__init__.py
@@ -17,4 +17,5 @@
 """tvm.contrib.msc.framework.tensorrt.tools"""
 
 from .prune import *
+from .quantize import *
 from .track import *
diff --git a/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py 
b/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/__init__.py
similarity index 88%
copy from python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py
copy to python/tvm/contrib/msc/framework/tensorrt/tools/quantize/__init__.py
index d25cfd4e67..e47b3324c9 100644
--- a/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py
+++ b/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/__init__.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""tvm.contrib.msc.framework.tensorflow.tools"""
+"""tvm.contrib.msc.framework.tensorrt.tools.quantize"""
 
-from .prune import *
-from .track import *
+from .quantizer import *
+from .method import *
diff --git a/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/method.py 
b/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/method.py
new file mode 100644
index 0000000000..0feb836d13
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/method.py
@@ -0,0 +1,149 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument
+"""tvm.contrib.msc.framework.tensorrt.tools.quantize.method"""
+
+from typing import Dict
+
+from tvm.contrib.msc.core.tools.quantize import QuantizeMethod, BaseQuantizer
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
+
+
+class TensorRTQuantizeMethod(QuantizeMethod):
+    """Default quantize method for tensorrt"""
+
+    @classmethod
+    def quantize_normal(
+        cls,
+        quantizer: BaseQuantizer,
+        tensor_ctx: Dict[str, str],
+        name: str,
+        consumer: str,
+        scale: float,
+        nbits: int = 8,
+        axis: int = -1,
+        sign: bool = True,
+        rounding: str = "round",
+        epsilon: float = 1.0 / (1 << 24),
+    ) -> Dict[str, str]:
+        """Calibrate the data by kl_divergence
+
+        Parameters
+        ----------
+        quantizer: BaseQuantizer
+            The quantizer
+        tensor_ctx: dict<str, str>
+            Tensor describe items.
+        name: str
+            The name of the tensor.
+        consumer: str
+            The name of the consumer.
+        scale: float
+            The scale factor
+        nbits: int
+            The number bits for quantize.
+        axis: int
+            The axis.
+        sign: bool
+            Whether to use sign.
+        rounding str
+            The rounding method.
+        epsilon: float
+            The epsilon for get scale.
+
+        Returns
+        -------
+        tensor_ctx: dict<str, str>
+            Tensor describe items.
+        """
+
+        if quantizer.is_weight(name):
+            return tensor_ctx
+        dtype = quantizer.find_tensor(name).dtype_name
+        precision = "DataType::k"
+        if nbits == 8:
+            precision += "INT8"
+        elif dtype == "float16":
+            precision += "HALF"
+        elif dtype == "float32":
+            precision += "FLOAT"
+        else:
+            raise TypeError("nbits {} is not supported".format(nbits))
+        tensor_ctx["processed"].extend(
+            [
+                "{}->setPrecision({})".format(tensor_ctx["producer"], 
precision),
+                "{0}->setDynamicRange(-{1}, {1})".format(tensor_ctx["tensor"], 
scale),
+            ]
+        )
+        return tensor_ctx
+
+    @classmethod
+    def dequantize_normal(
+        cls,
+        quantizer: BaseQuantizer,
+        tensor_ctx: Dict[str, str],
+        name: str,
+        consumer: str,
+        scale: float,
+        nbits: int = 8,
+        axis: int = -1,
+        sign: bool = True,
+        rounding: str = "round",
+        epsilon: float = 1.0 / (1 << 24),
+    ) -> Dict[str, str]:
+        """Calibrate the data by kl_divergence
+
+        Parameters
+        ----------
+        quantizer: BaseQuantizer
+            The quantizer
+        tensor_ctx: dict<str, str>
+            Tensor describe items.
+        name: str
+            The name of the tensor.
+        consumer: str
+            The name of the consumer.
+        scale: float
+            The scale factor
+        nbits: int
+            The number bits for quantize.
+        axis: int
+            The axis.
+        sign: bool
+            Whether to use sign.
+        rounding str
+            The rounding method.
+        epsilon: float
+            The epsilon for get scale.
+
+        Returns
+        -------
+        tensor_ctx: dict<str, str>
+            Tensor describe items.
+        """
+
+        return cls.quantize_normal(
+            quantizer, tensor_ctx, name, consumer, scale, nbits, axis, sign, 
rounding, epsilon
+        )
+
+    @classmethod
+    def framework(cls):
+        return MSCFramework.TENSORRT
+
+
+msc_utils.register_tool_method(TensorRTQuantizeMethod)
diff --git 
a/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/quantizer.py 
b/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/quantizer.py
new file mode 100644
index 0000000000..f971186196
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/quantizer.py
@@ -0,0 +1,366 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.framework.tensorrt.tools.quantize.quantizer"""
+
+import os
+import struct
+from typing import List, Dict, Any, Tuple
+
+import tvm
+from tvm.contrib.msc.core.ir import MSCGraph
+from tvm.contrib.msc.core.tools.tool import ToolType, ToolStrategy
+from tvm.contrib.msc.core.tools.quantize import BaseQuantizer, QuantizeStage
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
+
+
+class TensorRTQuantizerFactory(object):
+    """Quantizer factory for tensorrt"""
+
+    def create(self, base_cls: BaseQuantizer) -> BaseQuantizer:
+        """Create adaptive quantizer
+
+        Parameters
+        ----------
+        base_cls: BaseQuantizer
+            The base quantizer class
+
+        Returns
+        -------
+        quantizer_cls: BaseQuantizer
+            The quantizer class.
+        """
+
+        class Quantizer(base_cls):
+            """Adaptive quantizer for tensorrt"""
+
+            def setup(self) -> dict:
+                """Setup the tool
+
+                Returns
+                -------
+                info: dict
+                    The setup info.
+                """
+
+                if self._plan:
+                    self._use_range = all(
+                        info.get("use_range", False) for info in 
self._plan.values()
+                    )
+                else:
+                    self._use_range = True
+                return super().setup()
+
+            def _reset(
+                self, graphs: List[MSCGraph], weights: List[Dict[str, 
tvm.nd.array]]
+            ) -> Tuple[List[MSCGraph], List[Dict[str, tvm.nd.array]]]:
+                """Reset the tool
+
+                Parameters
+                ----------
+                graphs: list<MSCgraph>
+                    The msc graphs.
+                weights: list<dict<str, tvm.nd.array>>
+                    The weights
+
+                Returns
+                -------
+                graphs: list<MSCgraph>
+                    The msc graphs.
+                weights: list<dict<str, tvm.nd.array>>
+                    The weights
+                """
+
+                config_folder = msc_utils.get_config_dir()
+                self._range_files = [config_folder.relpath(g.name + ".range") 
for g in graphs]
+                calibrate_root = 
msc_utils.get_dataset_dir().create_dir("Calibrate")
+                self._calibrate_folders = [calibrate_root.relpath(g.name) for 
g in graphs]
+                if self._calibrated:
+                    if self._use_range:
+                        for r_file, graph in zip(self._range_files, graphs):
+                            if not os.path.isfile(r_file):
+                                self._plan_to_range(graph, r_file)
+                            self._logger.debug(
+                                "G[%s](%s) use range file: %s",
+                                graph.name,
+                                self._stage,
+                                r_file,
+                            )
+                    else:
+                        self._quantized_tensors = set()
+                elif self._stage == QuantizeStage.GATHER:
+                    self._calibrate_savers = []
+                    for folder, graph in zip(self._calibrate_folders, graphs):
+                        saver_options = {"input_names": [i.name for i in 
graph.get_inputs()]}
+                        saver = msc_utils.IODataSaver(folder, saver_options)
+                        self._calibrate_savers.append(saver)
+                        self._logger.debug(
+                            "G[%s](%s) create calibrate saver: %s",
+                            graph.name,
+                            self._stage,
+                            saver,
+                        )
+                else:
+                    assert all(
+                        msc_utils.is_io_dataset(f) for f in 
self._calibrate_folders
+                    ), "Some IODataset missing: " + 
str(self._calibrate_folders)
+                return super()._reset(graphs, weights)
+
+            def _execute_after_build(self, codegen_context: dict) -> dict:
+                """Execute after model build
+
+                Parameters
+                ----------
+                codegen_context: dict
+                    The context.
+
+                Returns
+                ----------
+                codegen_context: dict
+                    The processed context.
+                """
+
+                if self._stage == QuantizeStage.GATHER and self._forward_cnt 
== 0:
+                    return codegen_context
+                if not self._use_range:
+                    return codegen_context
+                processed = ["// Set int8 calibrator"]
+                range_file = self.get_graph().name + ".range"
+                version = [int(v) for v in 
codegen_context["version"].split(".")]
+                if msc_utils.compare_version(version, [6, 0, 0]) >= 0:
+                    configer = codegen_context["config"]
+                else:
+                    configer = codegen_context["builder"]
+                # check the range file if calibrated
+                if self._calibrated:
+                    processed.extend(
+                        [
+                            'if (!FileUtils::FileExist("{}")) 
{{'.format(range_file),
+                            '  logger.log(ILogger::Severity::kERROR, "{} not 
exist!");'.format(
+                                range_file
+                            ),
+                            "  return -1;",
+                            "}",
+                        ]
+                    )
+                processed.extend(
+                    [
+                        'MSCInt8EntropyCalibrator2 calibrator("{}", 
"{}");'.format(
+                            range_file, self._calibrate_folders[self._graph_id]
+                        ),
+                        "{}->setInt8Calibrator(&calibrator);".format(configer),
+                    ]
+                )
+                codegen_context["processed"].extend(processed)
+                return codegen_context
+
+            def _execute_before_forward(self, step_context: dict) -> dict:
+                """Execute before model forward
+
+                Parameters
+                ----------
+                step_context: dict
+                    The context.
+
+                Returns
+                ----------
+                step_context: dict
+                    The processed context.
+                """
+
+                if self._stage == QuantizeStage.GATHER:
+                    saver = self._calibrate_savers[self._graph_id]
+                    saver.save_batch(
+                        {name: data.asnumpy() for name, data in 
step_context["datas"].items()}
+                    )
+                    for name, data in step_context["datas"].items():
+                        self.debug_tensor(data, name, "any", "ctx_gathered")
+                super()._execute_before_forward(step_context)
+
+            def _quantize_tensor(
+                self,
+                tensor_ctx: Dict[str, str],
+                name: str,
+                consumer: str,
+                strategys: List[ToolStrategy],
+            ) -> Dict[str, str]:
+                """Quantize tensor
+
+                Parameters
+                -------
+                tensor_ctx: dict<str, str>
+                    Tensor describe items.
+                name: str
+                    The name of the tensor.
+                consumer: str
+                    The name of the consumer.
+                strategys: list<ToolStrategy>
+                    The strategys for the tensor.
+
+                Returns
+                -------
+                tensor_ctx: dict<str, str>
+                    Tensor items with processed.
+                """
+
+                if not self._use_range and name not in self._quantized_tensors:
+                    self._quantized_tensors.add(name)
+                    return super()._quantize_tensor(tensor_ctx, name, 
consumer, strategys)
+                return tensor_ctx
+
+            def calibrate(self) -> dict:
+                """Calibrate the datas
+
+                Returns
+                -------
+                plan: dict
+                    The calibrated plan.
+                """
+
+                for r_file, graph in zip(self._range_files, self._graphs):
+                    self._range_to_plan(graph, r_file)
+                self._calibrated, self._forward_cnt = True, 0
+                self.change_stage("quantize")
+                return self._plan
+
+            def config_generate(self, generate_config: Dict[str, Any]) -> 
Dict[str, Any]:
+                """Update the generate configs
+
+                Parameters
+                ----------
+                generate_config: dict<str, Any>
+                    The generate_config.
+
+                Returns
+                -------
+                generate_config: dict<str, Any>
+                    The updated generate_config.
+                """
+
+                if self._calibrated:
+                    if self._use_range:
+                        for config, r_file in zip(generate_config["codegen"], 
self._range_files):
+                            if os.path.isfile(r_file):
+                                config.update({"range_file": r_file, 
"precision": "int8"})
+                elif self._stage == QuantizeStage.GATHER and self._forward_cnt 
> 0:
+                    for config, saver, r_file in zip(
+                        generate_config["codegen"], self._calibrate_savers, 
self._range_files
+                    ):
+                        saver.finalize()
+                        self._logger.debug(
+                            "%ssave %d datas to %s",
+                            self.msg_mark(in_forward=False),
+                            self._forward_cnt,
+                            saver.folder,
+                        )
+                        config.update(
+                            {"dataset": saver.folder, "range_file": r_file, 
"precision": "int8"}
+                        )
+                    self.change_stage(QuantizeStage.CALIBRATE)
+                return generate_config
+
+            def _plan_to_range(self, graph: MSCGraph, range_file: str, 
title="MSCCalibrate"):
+                """Extract plan config to range_file
+
+                Parameters
+                ----------
+                plan: dict
+                    The plan.
+                graph: MSCGraph
+                    The graph.
+                range_file: str
+                    The output range_file path.
+                title: str
+                    The title of the range file.
+                """
+
+                def _scale_to_hex(scale):
+                    return hex(struct.unpack("<I", struct.pack("<f", scale / 
127))[0])[2:]
+
+                recorded = set()
+                with open(range_file, "w") as f:
+                    f.write(title + "\n")
+                    for name, info in self._plan.items():
+                        t_name, _ = self.from_tensor_id(name)
+                        if not graph.find_tensor(t_name):
+                            continue
+                        if t_name not in recorded:
+                            f.write("{}: {}\n").format(t_name, 
_scale_to_hex(info["scale"]))
+                            recorded.add(t_name)
+                self._logger.debug(
+                    "Graph[%s](%s) extract %d plan to range %s",
+                    graph.name,
+                    self._stage,
+                    len(recorded),
+                    range_file,
+                )
+
+            def _range_to_plan(self, graph: MSCGraph, range_file: str):
+                """Extract scale in range_file to plan
+
+                Parameters
+                ----------
+                graph: MSCGraph
+                    The graph.
+                range_file: str
+                    The input range_file path.
+                """
+
+                range_num = 0
+                with open(range_file, "r") as f:
+                    f.readline()
+                    line = f.readline()
+                    while line:
+                        name, scale = line.split(": ")
+                        scale = scale.strip()
+                        if scale == "0":
+                            value = 0.0
+                        else:
+                            value = struct.unpack("!f", 
bytes.fromhex(scale))[0] * 127
+                        range_num += 1
+                        consumers = graph.find_consumers(name)
+                        if consumers:
+                            for c in consumers:
+                                self._plan[self.to_tensor_id(name, c.name)] = {
+                                    "scale": value,
+                                    "use_range": True,
+                                }
+                        else:
+                            self._plan[self.to_tensor_id(name, "exit")] = {
+                                "scale": value,
+                                "use_range": True,
+                            }
+                        line = f.readline()
+                self._logger.debug(
+                    "Graph[%s](%s) extract %d range to plan from %s",
+                    graph.name,
+                    self._stage,
+                    range_num,
+                    range_file,
+                )
+
+            @classmethod
+            def framework(cls):
+                return MSCFramework.TENSORRT
+
+        return Quantizer
+
+
+factory = TensorRTQuantizerFactory()
+tools = msc_utils.get_registered_tool_cls(MSCFramework.MSC, 
ToolType.QUANTIZER, tool_style="all")
+for tool in tools.values():
+    msc_utils.register_tool_cls(factory.create(tool))
diff --git a/python/tvm/contrib/msc/framework/torch/tools/__init__.py 
b/python/tvm/contrib/msc/framework/torch/tools/__init__.py
index dda1e13822..f8fee73d69 100644
--- a/python/tvm/contrib/msc/framework/torch/tools/__init__.py
+++ b/python/tvm/contrib/msc/framework/torch/tools/__init__.py
@@ -17,4 +17,5 @@
 """tvm.contrib.msc.framework.torch.tools"""
 
 from .prune import *
+from .quantize import *
 from .track import *
diff --git a/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py 
b/python/tvm/contrib/msc/framework/torch/tools/quantize/__init__.py
similarity index 88%
copy from python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py
copy to python/tvm/contrib/msc/framework/torch/tools/quantize/__init__.py
index d25cfd4e67..9391900edf 100644
--- a/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py
+++ b/python/tvm/contrib/msc/framework/torch/tools/quantize/__init__.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""tvm.contrib.msc.framework.tensorflow.tools"""
+"""tvm.contrib.msc.framework.torch.tools.quantize"""
 
-from .prune import *
-from .track import *
+from .quantizer import *
+from .method import *
diff --git a/python/tvm/contrib/msc/framework/torch/tools/quantize/method.py 
b/python/tvm/contrib/msc/framework/torch/tools/quantize/method.py
new file mode 100644
index 0000000000..6f82a796e1
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/torch/tools/quantize/method.py
@@ -0,0 +1,237 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument
+"""tvm.contrib.msc.framework.torch.tools.quantize.method"""
+
+import numpy as np
+import torch
+from tvm.contrib.msc.core.tools.quantize import QuantizeMethod, BaseQuantizer
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
+
+
+class TorchQuantizeMethod(QuantizeMethod):
+    """Default quantize method for torch"""
+
+    @staticmethod
+    def _logic_round(data: torch.Tensor) -> torch.Tensor:
+        """Round to the nearest integer, ties toward positive infinity (2.5 -> 3, -2.5 -> -2)
+
+        Parameters
+        ----------
+        data: torch.Tensor
+            The (already scaled and clamped) data.
+
+        Returns
+        -------
+        data: torch.Tensor
+            The rounded data.
+        """
+
+        # Negative ties are rounded to their ceil separately and zeroed out for the generic
+        # pass below, which rounds fractions >= 0.5 up and fractions < 0.5 down.
+        negative_tie = torch.logical_and(data < 0, (data - torch.floor(data)) == 0.5)
+        negative_ceil = torch.where(negative_tie, torch.ceil(data), 0)
+        data = torch.where(negative_tie, 0, data)
+        data = torch.where((data - torch.floor(data)) >= 0.5, torch.ceil(data), data)
+        data = torch.where((data - torch.floor(data)) < 0.5, torch.floor(data), data)
+        return data + negative_ceil
+
+    @classmethod
+    def amplify_data(
+        cls,
+        data: torch.Tensor,
+        scale: float,
+        min_val: float,
+        max_val: float,
+        rounding: str = "round",
+    ) -> torch.Tensor:
+        """Scale the data, round it and clamp it into [min_val, max_val]
+
+        Parameters
+        ----------
+        data: torch.Tensor
+            The source data.
+        scale: float or torch.Tensor
+            The scale factor, a scalar or a tensor broadcastable to data.
+        min_val: float
+            The min.
+        max_val: float
+            The max.
+        rounding: str
+            The round method, one of null|floor|ceil|round|trunc|logic_round.
+
+        Returns
+        -------
+        data: torch.Tensor
+            The processed data.
+
+        Raises
+        ------
+        TypeError
+            If the rounding method is not supported.
+        """
+
+        if rounding == "logic_round":
+            # logic_round clamps before rounding, the other methods round before clamping
+            return cls._logic_round(torch.clamp(data * scale, min_val, max_val))
+        round_funcs = {
+            "null": lambda x: x,
+            "floor": torch.floor,
+            "ceil": torch.ceil,
+            "round": torch.round,
+            "trunc": torch.trunc,
+        }
+        if rounding not in round_funcs:
+            raise TypeError("Unexpected rounding " + str(rounding))
+        return torch.clamp(round_funcs[rounding](data * scale), min_val, max_val)
+
+    @classmethod
+    def gather_maxmin(
+        cls,
+        quantizer: BaseQuantizer,
+        data: torch.Tensor,
+        name: str,
+        consumer: str,
+        plan: dict,
+        nbits: int = 8,
+    ) -> dict:
+        """Gather the data by max/min
+
+        Parameters
+        ----------
+        quantizer: BaseQuantizer
+            The quantizer
+        data: torch.Tensor
+            The source data.
+        name: str
+            The name of the tensor.
+        consumer: str
+            The name of the consumer.
+        plan: dict
+            The pre-calibrated plan. Its abs_max/max/min lists are extended in place.
+        nbits: int
+            The number bits for quantize.
+
+        Returns
+        -------
+        plan: dict
+            The plan of the tensor, with one more entry in each list and calibrated=False.
+        """
+
+        abs_max_list = plan.get("abs_max_list", [])
+        abs_max_list.append(float(torch.abs(data).max()))
+        max_list = plan.get("max_list", [])
+        max_list.append(float(data.max()))
+        min_list = plan.get("min_list", [])
+        min_list.append(float(data.min()))
+        return {
+            "abs_max_list": abs_max_list,
+            "max_list": max_list,
+            "min_list": min_list,
+            "calibrated": False,
+        }
+
+    @classmethod
+    def gather_max_per_channel(
+        cls,
+        quantizer: BaseQuantizer,
+        data: torch.Tensor,
+        name: str,
+        consumer: str,
+        plan: dict,
+        nbits: int = 8,
+        channel: str = "O",
+        auto_unsign: bool = False,
+    ) -> dict:
+        """Gather the data by max_per_channel
+
+        Parameters
+        ----------
+        quantizer: BaseQuantizer
+            The quantizer
+        data: torch.Tensor
+            The source data.
+        name: str
+            The name of the tensor.
+        consumer: str
+            The name of the consumer.
+        plan: dict
+            The pre-calibrated plan.
+        nbits: int
+            The number bits for quantize.
+        channel: str
+            The channel reference, resolved to an axis through the tensor layout.
+        auto_unsign: bool
+            Whether to use auto unsign (unsigned range when data has no negative value).
+
+        Returns
+        -------
+        plan: dict
+            The calibrated plan: one scale per channel, sign, axis and calibrated=True.
+        """
+
+        weight = quantizer.find_tensor(name)
+        axis = weight.layout_of(channel)
+        # Abs max of every slice along axis in one reduction (and a single host sync)
+        # instead of one reduction per channel.
+        channel_max = (
+            torch.abs(data).movedim(axis, 0).reshape(data.shape[axis], -1).amax(dim=1).tolist()
+        )
+        # bool() keeps sign a plain python bool in both branches; a 0-dim tensor is not
+        # JSON serializable when the plan is dumped.
+        sign = bool(data.min() < 0) if auto_unsign else True
+        valid_range = 2 ** (nbits - int(sign)) - 1
+        scale = [valid_range / float(m) for m in channel_max]
+        return {"scale": scale, "sign": sign, "axis": axis, "calibrated": True}
+
+    @classmethod
+    def quantize_normal(
+        cls,
+        quantizer: BaseQuantizer,
+        data: torch.Tensor,
+        name: str,
+        consumer: str,
+        scale: float,
+        nbits: int = 8,
+        axis: int = -1,
+        sign: bool = True,
+        rounding: str = "round",
+        epsilon: float = 1.0 / (1 << 24),
+    ) -> torch.Tensor:
+        """Fake-quantize the data: amplify, round and clamp by scale, then scale back
+
+        Parameters
+        ----------
+        quantizer: BaseQuantizer
+            The quantizer
+        data: torch.Tensor
+            The source data.
+        name: str
+            The name of the tensor.
+        consumer: str
+            The name of the consumer.
+        scale: float
+            The scale factor
+        nbits: int
+            The number bits for quantize.
+        axis: int
+            The axis.
+        sign: bool
+            Whether to use sign.
+        rounding: str
+            The rounding method.
+        epsilon: float
+            The epsilon for get scale.
+
+        Returns
+        -------
+        data: torch.Tensor
+            The processed tensor, in the same value range as the input.
+        """
+
+        valid_range = 2 ** (nbits - int(sign)) - 1
+        min_val = -valid_range if sign else 0
+        # the scale tensor is built once per (name, consumer) and cached on the quantizer
+        scale_tensor = quantizer._get_tensor_cache(name, consumer, "scale_tensor")
+        if scale_tensor is None:
+            scale_tensor = cls.get_scale_tensor(data, scale, axis, epsilon)
+            if isinstance(scale_tensor, np.ndarray):
+                scale_tensor = torch.from_numpy(scale_tensor).to(data.device)
+            quantizer._save_tensor_cache(name, consumer, "scale_tensor", scale_tensor)
+        data = cls.amplify_data(data, scale_tensor, min_val, valid_range, rounding)
+        return data / scale_tensor
+
+    @classmethod
+    def framework(cls):
+        return MSCFramework.TORCH
+
+
+msc_utils.register_tool_method(TorchQuantizeMethod)
diff --git a/python/tvm/contrib/msc/framework/torch/tools/quantize/quantizer.py 
b/python/tvm/contrib/msc/framework/torch/tools/quantize/quantizer.py
new file mode 100644
index 0000000000..0e5c599b87
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/torch/tools/quantize/quantizer.py
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.framework.torch.tools.quantize.quantizer"""
+
+from tvm.contrib.msc.core.tools.tool import ToolType
+from tvm.contrib.msc.core.tools.quantize import BaseQuantizer
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
+
+
+class TorchQuantizerFactory:
+    """Factory that adapts generic MSC quantizer classes to the torch framework"""
+
+    def create(self, base_cls: BaseQuantizer) -> BaseQuantizer:
+        """Derive a torch-specific quantizer class from a generic one
+
+        Parameters
+        ----------
+        base_cls: BaseQuantizer
+            The generic quantizer class to specialise.
+
+        Returns
+        -------
+        quantizer_cls: BaseQuantizer
+            A subclass of base_cls whose framework() reports torch.
+        """
+
+        torch_framework = MSCFramework.TORCH
+
+        class Quantizer(base_cls):
+            """Adaptive quantizer for torch"""
+
+            @classmethod
+            def framework(cls):
+                return torch_framework
+
+        return Quantizer
+
+
+# Wrap each MSC quantizer class (all tool styles) with the torch factory and register it.
+factory = TorchQuantizerFactory()
+tools = msc_utils.get_registered_tool_cls(
+    MSCFramework.MSC, ToolType.QUANTIZER, tool_style="all"
+)
+for tool in tools.values():
+    msc_utils.register_tool_cls(factory.create(tool))
diff --git a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py 
b/python/tvm/contrib/msc/framework/tvm/tools/__init__.py
index 226ae3102d..ddfd41f3c8 100644
--- a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py
+++ b/python/tvm/contrib/msc/framework/tvm/tools/__init__.py
@@ -17,4 +17,5 @@
 """tvm.contrib.msc.framework.tvm.tools"""
 
 from .prune import *
+from .quantize import *
 from .track import *
diff --git a/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py 
b/python/tvm/contrib/msc/framework/tvm/tools/quantize/__init__.py
similarity index 88%
copy from python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py
copy to python/tvm/contrib/msc/framework/tvm/tools/quantize/__init__.py
index d25cfd4e67..0026724989 100644
--- a/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py
+++ b/python/tvm/contrib/msc/framework/tvm/tools/quantize/__init__.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""tvm.contrib.msc.framework.tensorflow.tools"""
+"""tvm.contrib.msc.framework.tvm.tools.quantize"""
 
-from .prune import *
-from .track import *
+from .quantizer import *
+from .method import *
diff --git a/python/tvm/contrib/msc/framework/tvm/tools/quantize/method.py 
b/python/tvm/contrib/msc/framework/tvm/tools/quantize/method.py
new file mode 100644
index 0000000000..9966e9c1af
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/tvm/tools/quantize/method.py
@@ -0,0 +1,204 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument
+"""tvm.contrib.msc.framework.tvm.tools.quantize.method"""
+
+from typing import Tuple
+import numpy as np
+
+import tvm
+from tvm.relax import op as relax_op
+from tvm.contrib.msc.core.tools.quantize import QuantizeMethod, BaseQuantizer
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
+from tvm.contrib.msc.core import _ffi_api
+
+
+class TVMQuantizeMethod(QuantizeMethod):
+    """Default quantize method for tvm"""
+
+    @staticmethod
+    def _get_name_prefix(quantizer: BaseQuantizer, name: str, consumer: str) -> str:
+        """Get the prefix used to name the constants and vars emitted for a tensor
+
+        Parameters
+        ----------
+        quantizer: BaseQuantizer
+            The quantizer
+        name: str
+            The name of the tensor.
+        consumer: str
+            The name of the consumer.
+
+        Returns
+        -------
+        name_prefix: str
+            The tensor name if the cache is processed, else the tensor id of (name, consumer).
+        """
+
+        if quantizer._cache_processed:
+            return name
+        return quantizer.to_tensor_id(name, consumer)
+
+    @classmethod
+    def get_quantize_cache(
+        cls,
+        quantizer: BaseQuantizer,
+        data: tvm.relax.Var,
+        name: str,
+        consumer: str,
+        scale: float,
+        axis: int = -1,
+        epsilon: float = 1.0 / (1 << 24),
+    ) -> Tuple[tvm.relax.Constant, tvm.relax.Constant]:
+        """Get the scale and zero point constants of a tensor, creating them on first use
+
+        Parameters
+        ----------
+        quantizer: BaseQuantizer
+            The quantizer
+        data: tvm.relax.Var
+            The source data.
+        name: str
+            The name of the tensor.
+        consumer: str
+            The name of the consumer.
+        scale: float
+            The scale factor
+        axis: int
+            The axis.
+        epsilon: float
+            The epsilon for get scale.
+
+        Returns
+        -------
+        scale_tensor: tvm.relax.Constant
+            The scale_tensor, cast to the dtype of the source tensor.
+        zero_point: tvm.relax.Constant
+            The zero_point, an int8 zero constant shaped like scale_tensor.
+        """
+
+        name_prefix = cls._get_name_prefix(quantizer, name, consumer)
+        scale_tensor = quantizer._get_tensor_cache(name, consumer, "scale_tensor")
+        zero_point = quantizer._get_tensor_cache(name, consumer, "zero_point")
+        if scale_tensor is None:
+            scale_tensor = cls.get_scale_tensor(data, scale, axis, epsilon, expand_dims=False)
+            if isinstance(scale_tensor, float):
+                scale_tensor = np.array(scale_tensor)
+            scale_tensor = scale_tensor.astype(quantizer.find_tensor(name).dtype_name)
+            zero_point = np.zeros_like(scale_tensor).astype("int8")
+            # the spans carry a "name" attr so the constants keep readable names
+            scale_span = _ffi_api.SpanCreateWithAttr("name", name_prefix + "_scale")
+            scale_tensor = tvm.relax.Constant(tvm.nd.array(scale_tensor), span=scale_span)
+            zp_span = _ffi_api.SpanCreateWithAttr("name", name_prefix + "_zero_point")
+            zero_point = tvm.relax.Constant(tvm.nd.array(zero_point), span=zp_span)
+            quantizer._save_tensor_cache(name, consumer, "scale_tensor", scale_tensor)
+            quantizer._save_tensor_cache(name, consumer, "zero_point", zero_point)
+        return scale_tensor, zero_point
+
+    @classmethod
+    def quantize_normal(
+        cls,
+        quantizer: BaseQuantizer,
+        data: tvm.relax.Var,
+        name: str,
+        consumer: str,
+        scale: float,
+        nbits: int = 8,
+        axis: int = -1,
+        sign: bool = True,
+        rounding: str = "round",
+        epsilon: float = 1.0 / (1 << 24),
+    ) -> tvm.relax.Var:
+        """Quantize the data by emitting a relax quantize op
+
+        Parameters
+        ----------
+        quantizer: BaseQuantizer
+            The quantizer
+        data: tvm.relax.Var
+            The source data.
+        name: str
+            The name of the tensor.
+        consumer: str
+            The name of the consumer.
+        scale: float
+            The scale factor
+        nbits: int
+            The number bits for quantize, only 8 is supported.
+        axis: int
+            The axis.
+        sign: bool
+            Whether to use sign, not used by the tvm implementation (output is int8).
+        rounding: str
+            The rounding method, not used by the tvm implementation.
+        epsilon: float
+            The epsilon for get scale.
+
+        Returns
+        -------
+        data: tvm.relax.Var
+            The emitted quantized var.
+
+        Raises
+        ------
+        TypeError
+            If nbits is not 8.
+        """
+
+        if nbits == 8:
+            dtype = "int8"
+        else:
+            raise TypeError("Unexpected nbits " + str(nbits))
+        name_prefix = cls._get_name_prefix(quantizer, name, consumer)
+        scale_tensor, zero_point = cls.get_quantize_cache(
+            quantizer, data, name, consumer, scale, axis, epsilon
+        )
+        expr = relax_op.quantize(data, scale_tensor, zero_point, axis, dtype)
+        return quantizer._block_builder.emit(expr, name_hint=name_prefix + "_quantize")
+
+    @classmethod
+    def dequantize_normal(
+        cls,
+        quantizer: BaseQuantizer,
+        data: tvm.relax.Var,
+        name: str,
+        consumer: str,
+        scale: float = -1.0,
+        nbits: int = 8,
+        axis: int = -1,
+        sign: bool = True,
+        rounding: str = "round",
+        epsilon: float = 1.0 / (1 << 24),
+    ) -> tvm.relax.Var:
+        """Dequantize the data by emitting a relax dequantize op
+
+        Parameters
+        ----------
+        quantizer: BaseQuantizer
+            The quantizer
+        data: tvm.relax.Var
+            The source data.
+        name: str
+            The name of the tensor.
+        consumer: str
+            The name of the consumer.
+        scale: float
+            The scale factor
+        nbits: int
+            The number bits for quantize, not used by the tvm implementation.
+        axis: int
+            The axis.
+        sign: bool
+            Whether to use sign, not used by the tvm implementation.
+        rounding: str
+            The rounding method, not used by the tvm implementation.
+        epsilon: float
+            The epsilon for get scale.
+
+        Returns
+        -------
+        data: tvm.relax.Var
+            The emitted dequantized var, in the dtype of the source tensor.
+        """
+
+        name_prefix = cls._get_name_prefix(quantizer, name, consumer)
+        scale_tensor, zero_point = cls.get_quantize_cache(
+            quantizer, data, name, consumer, scale, axis, epsilon
+        )
+        expr = relax_op.dequantize(
+            data, scale_tensor, zero_point, axis, quantizer.find_tensor(name).dtype
+        )
+        return quantizer._block_builder.emit(expr, name_hint=name_prefix + "_dequantize")
+
+    @classmethod
+    def framework(cls):
+        return MSCFramework.TVM
+
+
+msc_utils.register_tool_method(TVMQuantizeMethod)
diff --git a/python/tvm/contrib/msc/framework/tvm/tools/quantize/quantizer.py 
b/python/tvm/contrib/msc/framework/tvm/tools/quantize/quantizer.py
new file mode 100644
index 0000000000..d4680b9088
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/tvm/tools/quantize/quantizer.py
@@ -0,0 +1,167 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument
+"""tvm.contrib.msc.framework.tvm.tools.quantize.quantizer"""
+
+from typing import List, Union
+
+import tvm
+from tvm.contrib.msc.core.tools.tool import ToolType, ToolStrategy
+from tvm.contrib.msc.core.tools.quantize import BaseQuantizer
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
+
+
+class TVMQuantizerFactory(object):
+    """Quantizer factory for tvm"""
+
+    def create(self, base_cls: BaseQuantizer) -> BaseQuantizer:
+        """Create adaptive quantizer
+
+        Parameters
+        ----------
+        base_cls: BaseQuantizer
+            The base quantizer class
+
+        Returns
+        -------
+        quantizer_cls: BaseQuantizer
+            The quantizer class.
+        """
+
+        class Quantizer(base_cls):
+            """Adaptive quantizer for tvm
+
+            Before calibration, the non-weight tensors to gather are appended to the model
+            outputs while building, and their runtime values are gathered after forward.
+            """
+
+            def _execute_before_build(self, block_builder: tvm.relax.BlockBuilder):
+                """Execute before model build
+
+                Parameters
+                ----------
+                block_builder: tvm.relax.BlockBuilder
+                    The block builder, kept to emit quantize/dequantize exprs.
+                """
+
+                self._block_builder = block_builder
+                # tensor name -> {"consumers": [...], "tensor": var}, filled by _process_tensor
+                self._gather_tensors, self._gather_names = {}, []
+                super()._execute_before_build(block_builder)
+
+            def _execute_after_build(
+                self, output: Union[tvm.relax.Var, List[tvm.relax.DataflowVar]]
+            ) -> List[tvm.relax.Var]:
+                """Execute after model build
+
+                Parameters
+                ----------
+                output: var or list<var>
+                    The output var of the model.
+
+                Returns
+                -------
+                outputs: list<var>
+                    The modified outputs var, with the tensors to gather appended if the
+                    quantizer is not calibrated.
+                """
+
+                if self._calibrated:
+                    return super()._execute_after_build(output)
+                # sorted names give the extra outputs a deterministic order, which
+                # _execute_after_forward relies on to match data with names
+                self._gather_names = sorted(self._gather_tensors.keys())
+                gather_tensors = [self._gather_tensors[o]["tensor"] for o in self._gather_names]
+                if isinstance(output, tvm.relax.Var):
+                    return super()._execute_after_build([output] + gather_tensors)
+                # list() also accepts tuple-like outputs, which cannot be added to a list
+                return super()._execute_after_build(list(output) + gather_tensors)
+
+            def _execute_after_forward(
+                self, outputs: List[tvm.runtime.NDArray]
+            ) -> Union[tvm.runtime.NDArray, List[tvm.runtime.NDArray]]:
+                """Execute after model forward
+
+                Parameters
+                ----------
+                outputs: list<tvm.runtime.NDArray>
+                    The output datas, model outputs followed by the gathered tensors.
+
+                Returns
+                -------
+                output: tvm.runtime.NDArray or list<tvm.runtime.NDArray>
+                    The model outputs only, unwrapped when there is a single output.
+                """
+
+                if self._calibrated:
+                    return super()._execute_after_forward(outputs)
+                output_num = len(outputs) - len(self._gather_names)
+                for data, name in zip(outputs[output_num:], self._gather_names):
+                    info = self._gather_tensors[name]
+                    for consumer in info["consumers"]:
+                        strategys = self._get_tensor_strategys(name, consumer)
+                        self._gather_tensor(data, name, consumer, strategys)
+                if output_num == 1:
+                    return super()._execute_after_forward(outputs[0])
+                return super()._execute_after_forward(outputs[:output_num])
+
+            def _process_tensor(
+                self,
+                tensor: tvm.relax.DataflowVar,
+                name: str,
+                consumer: str,
+                scope: str,
+                strategys: List[ToolStrategy],
+            ) -> tvm.relax.DataflowVar:
+                """Process tensor
+
+                Parameters
+                ----------
+                tensor: tvm.relax.DataflowVar
+                    Tensor in framework
+                name: str
+                    The name of the tensor.
+                consumer: str
+                    The name of the consumer.
+                scope: str
+                    The scope mark teacher| student| null.
+                strategys: list<ToolStrategy>
+                    The strategys for the tensor.
+
+                Returns
+                -------
+                tensor: tvm.relax.DataflowVar
+                    The processed tensor.
+                """
+
+                if not self._calibrated:
+                    if self.is_weight(name):
+                        # weights are gathered directly from their data, no forward needed
+                        # NOTE(review): this returns the _gather_tensor result rather than
+                        # `tensor`; confirm it is usable in place of the weight var.
+                        return self._gather_tensor(self.get_data(name), name, consumer, strategys)
+                    # activations are recorded here and exposed as outputs after build
+                    if name not in self._gather_tensors:
+                        self._gather_tensors[name] = {
+                            "consumers": [consumer],
+                            "tensor": tensor,
+                        }
+                        self._gather_names.append(name)
+                    else:
+                        self._gather_tensors[name]["consumers"].append(consumer)
+                    return tensor
+                return self._quantize_tensor(tensor, name, consumer, strategys)
+
+            @classmethod
+            def framework(cls):
+                return MSCFramework.TVM
+
+        return Quantizer
+
+
+# Wrap each MSC quantizer class (all tool styles) with the tvm factory and register it.
+factory = TVMQuantizerFactory()
+tools = msc_utils.get_registered_tool_cls(
+    MSCFramework.MSC, ToolType.QUANTIZER, tool_style="all"
+)
+for tool in tools.values():
+    msc_utils.register_tool_cls(factory.create(tool))
diff --git a/python/tvm/contrib/msc/pipeline/manager.py 
b/python/tvm/contrib/msc/pipeline/manager.py
index bbd6d452ad..8a37ef951f 100644
--- a/python/tvm/contrib/msc/pipeline/manager.py
+++ b/python/tvm/contrib/msc/pipeline/manager.py
@@ -378,6 +378,10 @@ class BaseManager(object):
         if _tool_enabled(ToolType.PRUNER):
             self._apply_tool(ToolType.PRUNER, stage_config)
 
+        # run quantize
+        if _tool_enabled(ToolType.QUANTIZER):
+            self._apply_tool(ToolType.QUANTIZER, stage_config)
+
         # optimize and get the runner
         msc_utils.time_stamp(MSCStage.OPTIMIZE)
         return self._create_runner(
diff --git a/tests/python/contrib/test_msc/test_tools.py 
b/tests/python/contrib/test_msc/test_tools.py
index f396b81ea4..7e981d348b 100644
--- a/tests/python/contrib/test_msc/test_tools.py
+++ b/tests/python/contrib/test_msc/test_tools.py
@@ -77,7 +77,42 @@ def get_tool_config(tool_type):
             "strategys": [{"method": "per_channel", "density": 0.8}],
         }
     elif tool_type == ToolType.QUANTIZER:
-        raise NotImplementedError("Quantizer is not supported")
+        # pylint: disable=import-outside-toplevel
+        from tvm.contrib.msc.core.tools.quantize import QuantizeStage
+
+        config = {
+            "plan_file": "msc_quantizer.json",
+            "strategys": [
+                {
+                    "method": "gather_maxmin",
+                    "op_types": ["nn.conv2d", "msc.linear"],
+                    "tensor_types": ["input", "output"],
+                    "stages": [QuantizeStage.GATHER],
+                },
+                {
+                    "method": "gather_max_per_channel",
+                    "op_types": ["nn.conv2d", "msc.linear"],
+                    "tensor_types": ["weight"],
+                    "stages": [QuantizeStage.GATHER],
+                },
+                {
+                    "method": "calibrate_maxmin",
+                    "op_types": ["nn.conv2d", "msc.linear"],
+                    "tensor_types": ["input", "output"],
+                    "stages": [QuantizeStage.CALIBRATE],
+                },
+                {
+                    "method": "quantize_normal",
+                    "op_types": ["nn.conv2d", "msc.linear"],
+                    "tensor_types": ["input", "weight"],
+                },
+                {
+                    "method": "dequantize_normal",
+                    "op_types": ["nn.conv2d", "msc.linear"],
+                    "tensor_types": ["output"],
+                },
+            ],
+        }
     elif tool_type == ToolType.TRACKER:
         config = {
             "plan_file": "msc_tracker.json",
@@ -183,7 +218,7 @@ def get_model_info(compile_type):
     raise TypeError("Unexpected compile_type " + str(compile_type))
 
 
[email protected]("tool_type", [ToolType.PRUNER, ToolType.TRACKER])
[email protected]("tool_type", [ToolType.PRUNER, ToolType.QUANTIZER, 
ToolType.TRACKER])
 def test_tvm_tool(tool_type):
     """Test tools for tvm"""
 
@@ -194,7 +229,10 @@ def test_tvm_tool(tool_type):
 
 
 @requires_tensorrt
[email protected]("tool_type", [ToolType.PRUNER, ToolType.TRACKER])
[email protected](
+    "tool_type",
+    [ToolType.PRUNER, ToolType.QUANTIZER, ToolType.TRACKER],
+)
 def test_tensorrt_tool(tool_type):
     """Test tools for tensorrt"""
 


Reply via email to