This is an automated email from the ASF dual-hosted git repository.
jiangjiajun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e005f8574c [Frontend][PaddlePaddle] PaddlePaddle model with NCHW data
format that supports quantization (#16651)
e005f8574c is described below
commit e005f8574ca0208d75e9fd0790caa1a06d95af94
Author: Zheng-Bicheng <[email protected]>
AuthorDate: Thu Mar 7 17:39:00 2024 +0800
[Frontend][PaddlePaddle] PaddlePaddle model with NCHW data format that
supports quantization (#16651)
* support conv2d when data_format is NHWC
* modify the annotation
* Do not convert input data when processing quantization conv_2d nodes
* Fix code formatting issues
* fixed error code format
* update dequantize and quantize
* fixed bug when model is fp32 model
* update dequantize and quantize
* update for paddle quantize model when format is NCHW
---
python/tvm/relay/frontend/paddlepaddle.py | 83 +++++++++++++++++++++++++++----
1 file changed, 74 insertions(+), 9 deletions(-)
diff --git a/python/tvm/relay/frontend/paddlepaddle.py
b/python/tvm/relay/frontend/paddlepaddle.py
index bb72d30352..b00bb43d46 100755
--- a/python/tvm/relay/frontend/paddlepaddle.py
+++ b/python/tvm/relay/frontend/paddlepaddle.py
@@ -31,6 +31,7 @@ from .. import expr as _expr
from .. import function as _function
from .. import ty as _ty
from .. import op as _op
+from .. import qnn as _qnn
from .common import (
autopad,
fold_constant,
@@ -314,9 +315,9 @@ def convert_conv2d(g, op, block):
strides = op.attr("strides")
kernel = g.get_node(op.input("Filter")[0])
- kernel_layout = "OIHW"
input_x = g.get_node(op.input("Input")[0])
data_layout = op.attr("data_format")
+ kernel_layout = "OIHW" if data_layout == "NCHW" else "HWIO"
out_channels, _, k_h, k_w = infer_shape(kernel)
if padding_algorithm == "VALID":
paddings = [0, 0]
@@ -336,9 +337,15 @@ def convert_conv2d(g, op, block):
msg = f'Value {padding_algorithm} in attribute "padding" of operator
Conv is not "valid."'
raise tvm.error.OpAttributeInvalid(msg)
- if data_layout == "NHWC":
- kernel_layout = "HWIO"
- # PaddlePaddle weight layout is "OIHW", tvm needs "HWIO" when op
data_format is "NHWC".
+ is_quantized = op.has_attr("quantization_type")
+ # PaddlePaddle weight layout is "OIHW", tvm needs "HWIO" when op
data_format is "NHWC".
+ # There are two situations when converting the data format of weights:
+ # 1 Conv_2d is not a quantized OP, its weight information is the weights
themselves.
+ # We directly convert the weight information when processing conv_2d.
+ # 2 Conv_2d is a quantized OP, and its weight information is the output of
+ # the quantize_linear operator. Therefore, the weight information needs
to be
+ # transformed when processing the quantize_linear operator.
+ if (not is_quantized) and (data_layout == "NHWC"):
kernel_data = g.get_params(op.input("Filter")[0])
kernel_data = kernel_data.asnumpy()
kernel_data = kernel_data.transpose((2, 3, 1, 0))
@@ -1626,7 +1633,7 @@ def convert_pool3d(g, op, block):
raise tvm.error.OpAttributeInvalid(msg.format(padding_algorithm))
# handle with special case
- # while kernel size less than input size
+ # while kernel size more than input size
# shrink kernel size to input size
if (
not isinstance(in_h, _op.Expr)
@@ -1812,6 +1819,59 @@ def convert_roi_align(g, op, block):
g.add_node(op.output("Out")[0], out)
+def convert_dequantize_linear(g, op, block):
+ """Operator converter for dequantize_linear."""
+
+ data_node_name = op.input("X")[0]
+ data_node = g.get_node(data_node_name)
+
+ # paddle_scale = tvm_scale * 127
+ paddle_quantize_scale = g.get_params(op.input("Scale")[0]).asnumpy()
+ tvm_quantize_scale = paddle_quantize_scale / 127.0
+
+ tvm_quantize_zp = g.get_params(op.input("ZeroPoint")[0]).asnumpy()
+
+ tvm_quantize_axis = op.attr("quant_axis")
+ if tvm_quantize_axis == -1:
+ tvm_quantize_axis = 0
+
+ if len(infer_shape(data_node)) < 2:
+ tvm_quantize_axis = 0
+
+ out = _qnn.op.dequantize(
+ data=data_node,
+ input_scale=_op.const(tvm_quantize_scale, "float32"),
+ input_zero_point=_op.const(tvm_quantize_zp, "int32"),
+ axis=tvm_quantize_axis,
+ )
+ g.add_node(op.output("Y")[0], out)
+
+
+def convert_quantize_linear(g, op, block):
+ """Operator converter for quantize_linear."""
+
+ data_node_name = op.input("X")[0]
+ data_node = g.get_node(data_node_name)
+
+ # paddle_scale = tvm_scale * 127
+ paddle_quantize_scale = g.get_params(op.input("Scale")[0]).asnumpy()
+ tvm_quantize_scale = paddle_quantize_scale / 127.0
+
+ tvm_quantize_zp = g.get_params(op.input("ZeroPoint")[0]).asnumpy()
+ tvm_quantize_axis = op.attr("quant_axis")
+
+ if tvm_quantize_axis == -1:
+ tvm_quantize_axis = 0
+
+ out = _qnn.op.quantize(
+ data=data_node,
+ output_scale=_op.const(tvm_quantize_scale, "float32"),
+ output_zero_point=_op.const(tvm_quantize_zp, "int32"),
+ axis=tvm_quantize_axis,
+ )
+ g.add_node(op.output("Y")[0], out)
+
+
def convert_rnn(g, op, block):
"""Operator converter for rnn."""
@@ -2386,11 +2446,11 @@ def convert_slice(g, op, block):
def convert_softmax(g, op, block):
"""Operator converter for softmax."""
+ x = g.get_node(op.input("X")[0])
axis = op.attr("axis")
input_shape = block.var(op.input("X")[0]).shape
if axis < 0:
axis = len(input_shape) + axis
- x = g.get_node(op.input("X")[0])
m = _op.max(x, axis, keepdims=True)
e = _op.exp(x - m)
out = e / _op.sum(e, axis, keepdims=True)
@@ -2905,6 +2965,9 @@ _convert_map = {
"unstack": convert_unstack,
"where": convert_where,
"where_index": convert_where_index,
+ # Quantized
+ "dequantize_linear": convert_dequantize_linear,
+ "quantize_linear": convert_quantize_linear,
}
@@ -2938,7 +3001,7 @@ class GraphProto:
if name is None:
return self.params
- assert name in self.params
+ assert name in self.params, f"The name({name}) is not in params"
return self.params[name]
def extract_parameters(self, program, scope=None):
@@ -2947,9 +3010,12 @@ class GraphProto:
self.params = {}
variables = program.global_block().vars
for name in variables:
- var = program.global_block().var(name)
if name.endswith("feed") or name.endswith("fetch"):
continue
+ # This judgment will cause the PaddleInference model
+ # exported by PaddleSlim to skip some operators
+ # that need to be read in NHWC format.
+ var = program.global_block().var(name)
if not var.persistable:
continue
if isinstance(scope, dict):
@@ -3018,7 +3084,6 @@ class GraphProto:
for op in block.ops:
if op.type == "fetch":
output_names.append(op.input("X")[0])
-
outputs = [self.nodes[name] for name in output_names]
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)