This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8f664f50fc [QNN] Fix qnn.dequantize scale and zp shape (#10880)
8f664f50fc is described below

commit 8f664f50fce151d5dbd1f6361f1b73812cf6f922
Author: Sevin F. Varoglu <[email protected]>
AuthorDate: Tue Apr 5 15:18:34 2022 -0700

    [QNN] Fix qnn.dequantize scale and zp shape (#10880)
    
    * [QNN] Fix qnn.dequantize scale and zp shape
    
    * Rework
    
    * Add review feedback
---
 src/relay/qnn/op/dequantize.cc               | 15 +++++++++------
 tests/python/relay/test_op_qnn_dequantize.py | 14 ++++++++++++++
 2 files changed, 23 insertions(+), 6 deletions(-)

diff --git a/src/relay/qnn/op/dequantize.cc b/src/relay/qnn/op/dequantize.cc
index c843eb3f54..9a9c60d9ea 100644
--- a/src/relay/qnn/op/dequantize.cc
+++ b/src/relay/qnn/op/dequantize.cc
@@ -56,17 +56,20 @@ bool DequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   auto rank = static_cast<int>(data->shape.size());
   axis = (axis < 0) ? ((rank > 0) ? data->shape.size() + axis : 0) : axis;
 
-  // If zero point and scale are scalar then axis doesnt matter.
-  bool scale_is_scalar = (types[1].as<TensorTypeNode>())->shape.size() == 0;
-  bool zp_is_scalar = (types[2].as<TensorTypeNode>())->shape.size() == 0;
-
-  if (!(scale_is_scalar && zp_is_scalar)) {
+  // If zero point and scale are scalar or have arbitrary rank with one 
element,
+  // then axis doesn't matter.
+  bool scale_is_scalar = (types[1].as<TensorTypeNode>())->shape.size() == 0 ||
+                         
get_const_int((types[1].as<TensorTypeNode>())->Size()) == 1;
+  bool zp_is_scalar = (types[2].as<TensorTypeNode>())->shape.size() == 0 ||
+                      get_const_int((types[2].as<TensorTypeNode>())->Size()) 
== 1;
+
+  if (!scale_is_scalar || !zp_is_scalar) {
     ICHECK_LT(axis, rank > 0 ? rank : 1) << "axis " << dequantize_attrs->axis 
<< " is out of range";
     ICHECK_GE(axis, 0) << "axis " << dequantize_attrs->axis << " is out of 
range";
   }
 
   PrimExpr axis_shape;
-  if (rank > 0) {
+  if (!scale_is_scalar || !zp_is_scalar) {
     axis_shape = data->shape[axis];
   } else {
     axis_shape = Integer(1);
diff --git a/tests/python/relay/test_op_qnn_dequantize.py b/tests/python/relay/test_op_qnn_dequantize.py
index 70ea05fe18..b332bd94f3 100644
--- a/tests/python/relay/test_op_qnn_dequantize.py
+++ b/tests/python/relay/test_op_qnn_dequantize.py
@@ -128,6 +128,20 @@ def test_channelwise_axis_0():
     )
 
 
+def test_per_tensor_vector_args():
+    data = np.array([0, 1, 2, 3, 4, 251, 252, 253, 254, 255]).astype("uint8")
+    output = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 
64]).astype("float32")
+
+    quant_args = {
+        "in_zero_point": np.array([127]).astype("int32"),
+        "in_scale": np.array([0.5]).astype("float32"),
+    }
+
+    dequantize_test_driver(
+        in_dtype="uint8", quant_args=quant_args, in_data=data, 
verify_output_data=output, axis=-1
+    )
+
+
 def test_dynamic_dequantize():
     x = relay.var("x", shape=(1, 2, 3, 4), dtype="int8")
     scale_var = relay.var("scale", shape=(), dtype="float32")

Reply via email to