This is an automated email from the ASF dual-hosted git repository.
andrewzhaoluo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8f664f50fc [QNN] Fix qnn.dequantize scale and zp shape (#10880)
8f664f50fc is described below
commit 8f664f50fce151d5dbd1f6361f1b73812cf6f922
Author: Sevin F. Varoglu <[email protected]>
AuthorDate: Tue Apr 5 15:18:34 2022 -0700
[QNN] Fix qnn.dequantize scale and zp shape (#10880)
* [QNN] Fix qnn.dequantize scale and zp shape
* Rework
* Add review feedback
---
src/relay/qnn/op/dequantize.cc | 15 +++++++++------
tests/python/relay/test_op_qnn_dequantize.py | 14 ++++++++++++++
2 files changed, 23 insertions(+), 6 deletions(-)
diff --git a/src/relay/qnn/op/dequantize.cc b/src/relay/qnn/op/dequantize.cc
index c843eb3f54..9a9c60d9ea 100644
--- a/src/relay/qnn/op/dequantize.cc
+++ b/src/relay/qnn/op/dequantize.cc
@@ -56,17 +56,20 @@ bool DequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
auto rank = static_cast<int>(data->shape.size());
axis = (axis < 0) ? ((rank > 0) ? data->shape.size() + axis : 0) : axis;
- // If zero point and scale are scalar then axis doesnt matter.
- bool scale_is_scalar = (types[1].as<TensorTypeNode>())->shape.size() == 0;
- bool zp_is_scalar = (types[2].as<TensorTypeNode>())->shape.size() == 0;
-
- if (!(scale_is_scalar && zp_is_scalar)) {
+ // If zero point and scale are scalar or have arbitrary rank with one element,
+ // then axis doesn't matter.
+ bool scale_is_scalar = (types[1].as<TensorTypeNode>())->shape.size() == 0 ||
+ get_const_int((types[1].as<TensorTypeNode>())->Size()) == 1;
+ bool zp_is_scalar = (types[2].as<TensorTypeNode>())->shape.size() == 0 ||
+ get_const_int((types[2].as<TensorTypeNode>())->Size()) == 1;
+
+ if (!scale_is_scalar || !zp_is_scalar) {
ICHECK_LT(axis, rank > 0 ? rank : 1) << "axis " << dequantize_attrs->axis
<< " is out of range";
ICHECK_GE(axis, 0) << "axis " << dequantize_attrs->axis << " is out of range";
}
PrimExpr axis_shape;
- if (rank > 0) {
+ if (!scale_is_scalar || !zp_is_scalar) {
axis_shape = data->shape[axis];
} else {
axis_shape = Integer(1);
diff --git a/tests/python/relay/test_op_qnn_dequantize.py b/tests/python/relay/test_op_qnn_dequantize.py
index 70ea05fe18..b332bd94f3 100644
--- a/tests/python/relay/test_op_qnn_dequantize.py
+++ b/tests/python/relay/test_op_qnn_dequantize.py
@@ -128,6 +128,20 @@ def test_channelwise_axis_0():
)
+def test_per_tensor_vector_args():
+ data = np.array([0, 1, 2, 3, 4, 251, 252, 253, 254, 255]).astype("uint8")
+ output = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]).astype("float32")
+
+ quant_args = {
+ "in_zero_point": np.array([127]).astype("int32"),
+ "in_scale": np.array([0.5]).astype("float32"),
+ }
+
+ dequantize_test_driver(
+ in_dtype="uint8", quant_args=quant_args, in_data=data, verify_output_data=output, axis=-1
+ )
+
+
def test_dynamic_dequantize():
x = relay.var("x", shape=(1, 2, 3, 4), dtype="int8")
scale_var = relay.var("scale", shape=(), dtype="float32")