masahi commented on a change in pull request #7474:
URL: https://github.com/apache/tvm/pull/7474#discussion_r579008341



##########
File path: python/tvm/relay/transform/quantize/_calibrator.py
##########
@@ -0,0 +1,382 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""API for calibrating a quantized function."""
+import numpy as np
+
+import tvm
+from tvm import relay
+from tvm.contrib import graph_runtime
+import tvm.relay.build_module as build_module
+
+
+class QuantizationCalibrator:
+    """The QuantizationCalibrator picks scales and zero points for all qnn ops 
in the quantized
+    module.
+
+    Parameters
+    ----------
+    quantizer : Quantizer
+        Quantizer created with the mod we are calibrating.
+
+    target : String, optional
+        The target to run the quantized function on during calibration.
+
+    ctx : TVMContext, optional
+        The device context (for example, tvm.cpu()) used to run the quantized function during
+        calibration.
+
+    dataset_manager : DatasetManager, optional
+        The dataset manager containing data used to run the graph during
+        data-aware calibration.
+    """
+
+    def __init__(self, quantizer, target="llvm", ctx=tvm.cpu(), 
dataset_manager=None,
+                 show_scale_zps=False):
+        self.quantizer = quantizer
+
+        self.calibration_info = CalibrationInfo(
+            quantizer.tuple_subgraph_func,
+            quantizer.q_tuple_subgraph_func,
+            quantizer.partition_infos,
+            dataset_manager,
+            target,
+            ctx,
+        )
+
+        self.show_scale_zps = show_scale_zps
+
+    def calibrate(self):
+        """Picks the scales and zero points for all qnn ops in the quantized 
graph, using the
+        calibrate_pattern function from the quantizer.
+
+        Returns
+        -------
+        calibrated_func : relay.Function
+            The quantized function with the values for scales and zero points 
substituted into the
+            function.
+        """
+        # Create a map of DFPatternCallback to QuantizerPattern
+        pattern_map = {pattern.pattern: pattern for pattern in 
self.quantizer.patterns}
+
+        for partition_info in self.calibration_info.partition_infos:
+            # Set the partition info so we can access it from the callback
+            self.calibration_info.set_current_partition_info(partition_info)
+            quantizer_pattern = pattern_map[partition_info.pattern]
+
+            # Get the values for scales and ZPs in this layer and store them
+            scale_zps = 
quantizer_pattern.calibrate_pattern(self.calibration_info)
+            if self.show_scale_zps:
+                self.report_scale_zps(scale_zps)
+            self.calibration_info.update_scale_zp_map(scale_zps)
+
+        calibrated_func = build_module.bind_params_by_name(
+            self.quantizer.q_tuple_subgraph_func, 
self.calibration_info.scale_zp_value_map
+        )
+
+        # If num_orig_outputs is -1, original output wasn't a tuple
+        params = calibrated_func.params
+        if self.quantizer.num_orig_outputs == -1:
+            calibrated_func = relay.Function(params, 
calibrated_func.body.fields[0])
+        else:
+            new_body = relay.Tuple(calibrated_func.body.fields[0 : 
self.quantizer.num_orig_outputs])
+            calibrated_func = relay.Function(params, new_body)
+
+        return calibrated_func
+
+    def report_scale_zps(self, scale_zp_map):
+        """Prints the scales and zero points out.
+
+        Parameters
+        ----------
+        scale_zp_map : dict of str to value
+            The map from names of scale and zero point variables to their 
assigned values.
+        """
+        for key, value in scale_zp_map.items():
+            print("Set ", key, " variable to ", value)
+
+
+class CalibrationInfo:
+    """Helper class that passes the information needed to pick scales and zero points into
+    calibrate_pattern. The state of CalibrationInfo is updated by QuantizationCalibrator.
+
+    Parameters
+    ----------
+    tuple_subgraph_func : relay.Function
+        A function whose output is a tuple that contains values we will need 
to access during
+        calibration.
+
+    q_tuple_subgraph_func : relay.Function
+        A quantized version of the tuple_subgraph_func. Note that to run this 
function, you
+        must pass in values for scales and zero points.
+
+    partition_infos : List[PatternCalibrationInfo]
+        A list of objects that correspond to every pattern matched during quantization. Each
+        contains scale and zero point variables, and indices into the tuple functions.
+
+    dataset_manager : DatasetManager
+        The dataset manager containing data used to run the graph during 
data-aware calibration.
+
+    target : String
+        The target to run the quantized function on during calibration.
+
+    ctx : TVMContext
+        The device context (for example, tvm.cpu()) used to run the quantized function during
+        calibration.
+    """
+
+    def __init__(
+        self,
+        tuple_subgraph_func,
+        q_tuple_subgraph_func,
+        partition_infos,
+        dataset_manager,
+        target,
+        ctx,
+    ):
+        self.tuple_subgraph_func = tuple_subgraph_func
+        self.q_tuple_subgraph_func = q_tuple_subgraph_func
+        self.dataset_manager = dataset_manager
+        self.partition_infos = partition_infos
+        self.target = target
+        self.ctx = ctx
+
+        self.partition_info = None
+        self.input_scale_zps = None
+
+        tuple_subgraph_mod = 
tvm.ir.IRModule.from_expr(self.tuple_subgraph_func)
+        q_tuple_subgraph_mod = 
tvm.ir.IRModule.from_expr(self.q_tuple_subgraph_func)
+
+        self.tuple_subgraph_graphmodule = None
+        self.q_tuple_subgraph_graphmodule = None
+        self.init_subgraph_graphmodules(tuple_subgraph_mod, 
q_tuple_subgraph_mod)
+
+        self.scale_zp_value_map = {}
+        self.initialize_scale_zp_map()
+
+    def init_subgraph_graphmodules(self, tuple_subgraph_mod, 
q_tuple_subgraph_mod):
+        """Builds the tuple subgraphs so they can be run during calibration.
+
+        Parameters
+        ----------
+        tuple_subgraph_mod : tvm.ir.IRModule
+            Module wrapping tuple_subgraph_func.
+
+        q_tuple_subgraph_mod : tvm.ir.IRModule
+            Module wrapping q_tuple_subgraph_func.
+        """
+        # AlterOpLayout is disabled because it inserts some pads and other ops
+        with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
+            tuple_subgraph_lib = relay.build(tuple_subgraph_mod, 
target=self.target)
+            q_tuple_subgraph_lib = relay.build(q_tuple_subgraph_mod, 
target=self.target)
+
+        ts_graph_mod = 
graph_runtime.GraphModule(tuple_subgraph_lib["default"](self.ctx))
+        q_ts_graph_mod = 
graph_runtime.GraphModule(q_tuple_subgraph_lib["default"](self.ctx))
+        self.tuple_subgraph_graphmodule = ts_graph_mod
+        self.q_tuple_subgraph_graphmodule = q_ts_graph_mod
+
+    def initialize_scale_zp_map(self):
+        """Initializes scales to 1 and zero points to zero. These values will 
only be used
+        to calculate values in the tuple subgraph that are not returned to the 
user."""

Review comment:
       We should be careful with how we initialize the params here. A scale of 1 doesn't make sense, since it would essentially clamp the entire floating-point range to [-128, 127], so the outputs from the first run will likely be garbage.
   
   Does the choice of initialization affect the accuracy of the final model? If so, we should use more sensible values by default.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to