This is an automated email from the ASF dual-hosted git repository.

jiangjiajun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e005f8574c [Frontend][PaddlePaddle] PaddlePaddle model with NCHW data format that supports quantization (#16651)
e005f8574c is described below

commit e005f8574ca0208d75e9fd0790caa1a06d95af94
Author: Zheng-Bicheng <[email protected]>
AuthorDate: Thu Mar 7 17:39:00 2024 +0800

    [Frontend][PaddlePaddle] PaddlePaddle model with NCHW data format that supports quantization (#16651)
    
    * support conv2d when data_format is NHWC
    
    * modify the annotation
    
    * Do not convert input data when processing quantization conv_2d nodes
    
    * Fix code formatting issues
    
    * fixed error code format
    
    * update dequantize and quantize
    
    * fixed bug when model is fp32 model
    
    * update dequantize and quantize
    
    * update for paddle quantize model when format is NCHW
---
 python/tvm/relay/frontend/paddlepaddle.py | 83 +++++++++++++++++++++++++++----
 1 file changed, 74 insertions(+), 9 deletions(-)

diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py
index bb72d30352..b00bb43d46 100755
--- a/python/tvm/relay/frontend/paddlepaddle.py
+++ b/python/tvm/relay/frontend/paddlepaddle.py
@@ -31,6 +31,7 @@ from .. import expr as _expr
 from .. import function as _function
 from .. import ty as _ty
 from .. import op as _op
+from .. import qnn as _qnn
 from .common import (
     autopad,
     fold_constant,
@@ -314,9 +315,9 @@ def convert_conv2d(g, op, block):
     strides = op.attr("strides")
 
     kernel = g.get_node(op.input("Filter")[0])
-    kernel_layout = "OIHW"
     input_x = g.get_node(op.input("Input")[0])
     data_layout = op.attr("data_format")
+    kernel_layout = "OIHW" if data_layout == "NCHW" else "HWIO"
     out_channels, _, k_h, k_w = infer_shape(kernel)
     if padding_algorithm == "VALID":
         paddings = [0, 0]
@@ -336,9 +337,15 @@ def convert_conv2d(g, op, block):
         msg = f'Value {padding_algorithm} in attribute "padding" of operator Conv is not "valid."'
         raise tvm.error.OpAttributeInvalid(msg)
 
-    if data_layout == "NHWC":
-        kernel_layout = "HWIO"
-        # PaddlePaddle wieght layout is "OIHW", tvm need "HWIO" when op data_format is "NHWC".
+    is_quantized = op.has_attr("quantization_type")
+    # PaddlePaddle weight layout is "OIHW", TVM needs "HWIO" when op data_format is "NHWC".
+    # There are two situations when converting the data format of weights:
+    # 1 Conv_2d is not a quantized OP, its weight information is the weights themselves.
+    #   We directly convert the weight information when processing conv_2d.
+    # 2 Conv_2d is a quantized OP, and its weight information is the output of
+    #   the quantize_linear operator. Therefore, the weight information needs to be
+    #   transformed when processing the quantize_linear operator.
+    if (not is_quantized) and (data_layout == "NHWC"):
         kernel_data = g.get_params(op.input("Filter")[0])
         kernel_data = kernel_data.asnumpy()
         kernel_data = kernel_data.transpose((2, 3, 1, 0))
@@ -1626,7 +1633,7 @@ def convert_pool3d(g, op, block):
         raise tvm.error.OpAttributeInvalid(msg.format(padding_algorithm))
 
     # handle with special case
-    # while kernel size less than input size
+    # while kernel size is larger than input size
     # shrink kernel size to input size
     if (
         not isinstance(in_h, _op.Expr)
@@ -1812,6 +1819,59 @@ def convert_roi_align(g, op, block):
     g.add_node(op.output("Out")[0], out)
 
 
+def convert_dequantize_linear(g, op, block):
+    """Operator converter for dequantize_linear."""
+
+    data_node_name = op.input("X")[0]
+    data_node = g.get_node(data_node_name)
+
+    # paddle_scale = tvm_scale * 127
+    paddle_quantize_scale = g.get_params(op.input("Scale")[0]).asnumpy()
+    tvm_quantize_scale = paddle_quantize_scale / 127.0
+
+    tvm_quantize_zp = g.get_params(op.input("ZeroPoint")[0]).asnumpy()
+
+    tvm_quantize_axis = op.attr("quant_axis")
+    if tvm_quantize_axis == -1:
+        tvm_quantize_axis = 0
+
+    if len(infer_shape(data_node)) < 2:
+        tvm_quantize_axis = 0
+
+    out = _qnn.op.dequantize(
+        data=data_node,
+        input_scale=_op.const(tvm_quantize_scale, "float32"),
+        input_zero_point=_op.const(tvm_quantize_zp, "int32"),
+        axis=tvm_quantize_axis,
+    )
+    g.add_node(op.output("Y")[0], out)
+
+
+def convert_quantize_linear(g, op, block):
+    """Operator converter for quantize_linear."""
+
+    data_node_name = op.input("X")[0]
+    data_node = g.get_node(data_node_name)
+
+    # paddle_scale = tvm_scale * 127
+    paddle_quantize_scale = g.get_params(op.input("Scale")[0]).asnumpy()
+    tvm_quantize_scale = paddle_quantize_scale / 127.0
+
+    tvm_quantize_zp = g.get_params(op.input("ZeroPoint")[0]).asnumpy()
+    tvm_quantize_axis = op.attr("quant_axis")
+
+    if tvm_quantize_axis == -1:
+        tvm_quantize_axis = 0
+
+    out = _qnn.op.quantize(
+        data=data_node,
+        output_scale=_op.const(tvm_quantize_scale, "float32"),
+        output_zero_point=_op.const(tvm_quantize_zp, "int32"),
+        axis=tvm_quantize_axis,
+    )
+    g.add_node(op.output("Y")[0], out)
+
+
 def convert_rnn(g, op, block):
     """Operator converter for rnn."""
 
@@ -2386,11 +2446,11 @@ def convert_slice(g, op, block):
 def convert_softmax(g, op, block):
     """Operator converter for softmax."""
 
+    x = g.get_node(op.input("X")[0])
     axis = op.attr("axis")
     input_shape = block.var(op.input("X")[0]).shape
     if axis < 0:
         axis = len(input_shape) + axis
-    x = g.get_node(op.input("X")[0])
     m = _op.max(x, axis, keepdims=True)
     e = _op.exp(x - m)
     out = e / _op.sum(e, axis, keepdims=True)
@@ -2905,6 +2965,9 @@ _convert_map = {
     "unstack": convert_unstack,
     "where": convert_where,
     "where_index": convert_where_index,
+    # Quantized
+    "dequantize_linear": convert_dequantize_linear,
+    "quantize_linear": convert_quantize_linear,
 }
 
 
@@ -2938,7 +3001,7 @@ class GraphProto:
 
         if name is None:
             return self.params
-        assert name in self.params
+        assert name in self.params, f"The name({name}) is not in params"
         return self.params[name]
 
     def extract_parameters(self, program, scope=None):
@@ -2947,9 +3010,12 @@ class GraphProto:
         self.params = {}
         variables = program.global_block().vars
         for name in variables:
-            var = program.global_block().var(name)
             if name.endswith("feed") or name.endswith("fetch"):
                 continue
+            # This judgment will cause the PaddleInference model
+            # exported by PaddleSlim to skip some operators
+            # that need to be read in NHWC format.
+            var = program.global_block().var(name)
             if not var.persistable:
                 continue
             if isinstance(scope, dict):
@@ -3018,7 +3084,6 @@ class GraphProto:
             for op in block.ops:
                 if op.type == "fetch":
                     output_names.append(op.input("X")[0])
-
         outputs = [self.nodes[name] for name in output_names]
         outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
 

Reply via email to