This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c7d7164c42 [QNN] support zero points as variable scalar for QnnBatchMatMul op (#13469)
c7d7164c42 is described below

commit c7d7164c421ead653d1a300a9610e8d9da14722b
Author: Valery Chernov <[email protected]>
AuthorDate: Thu Dec 1 21:18:57 2022 +0300

    [QNN] support zero points as variable scalar for QnnBatchMatMul op (#13469)
    
    * support zero points as variable scalar
    
    * lint fix
    
    * add error logging for the unsupported case where a zero point is an N-d tensor
    
    * fix typo
    
    * remove unnecessary TODOs
    
    Co-authored-by: Valery Chernov <[email protected]>
---
 src/relay/qnn/op/batch_matmul.cc | 78 ++++++++++++++++++++++++++--------------
 1 file changed, 52 insertions(+), 26 deletions(-)

diff --git a/src/relay/qnn/op/batch_matmul.cc b/src/relay/qnn/op/batch_matmul.cc
index be5a314e80..a948d9387d 100644
--- a/src/relay/qnn/op/batch_matmul.cc
+++ b/src/relay/qnn/op/batch_matmul.cc
@@ -96,20 +96,42 @@ Expr BatchMatmulFirstTerm(const Expr& quantized_x, const 
Expr& quantized_y,
 }
 
 Expr BatchMatmulSecondTerm(const Expr& x_quantized_data, const Expr& 
y_zero_point) {
-  Array<Integer> axes = {2};
-  return Multiply(y_zero_point, Sum(Cast(x_quantized_data, DataType::Int(32)), 
axes, true, false));
+  if (IsScalar(y_zero_point)) {
+    Array<Integer> axes = {2};
+    return Multiply(y_zero_point,
+                    Sum(Cast(x_quantized_data, DataType::Int(32)), axes, true, 
false));
+  } else {
+    LOG(FATAL) << "Tensor zero point (non-scalar) is not supported";
+    return Expr();
+  }
 }
 
 Expr BatchMatmulThirdTerm(const Expr& y_quantized_data, const Expr& 
x_zero_point,
                           int broadcast_dim_size) {
-  Array<Integer> axes = {2};
-  auto reducemult =
-      Multiply(x_zero_point, Sum(Cast(y_quantized_data, DataType::Int(32)), 
axes, true, false));
-  Array<Integer> newshape;
-
-  // dimension of 0 in reshape copies old dimension size
-  newshape = {0, 1, broadcast_dim_size};
-  return Reshape(reducemult, newshape);
+  if (IsScalar(x_zero_point)) {
+    Array<Integer> axes = {2};
+    auto reducemult =
+        Multiply(x_zero_point, Sum(Cast(y_quantized_data, DataType::Int(32)), 
axes, true, false));
+    Array<Integer> newshape;
+
+    // dimension of 0 in reshape copies old dimension size
+    newshape = {0, 1, broadcast_dim_size};
+    return Reshape(reducemult, newshape);
+  } else {
+    LOG(FATAL) << "Tensor zero point (non-scalar) is not supported";
+    return Expr();
+  }
+}
+
+Expr BatchMatmulFourthTerm(Expr x_zero_point, Expr y_zero_point, int 
reduction_dim_size) {
+  if (IsScalar(x_zero_point) && IsScalar(y_zero_point)) {
+    auto zero_point_mul = Multiply(x_zero_point, y_zero_point);
+    auto const_scale = MakeConstantScalar(DataType::Int(32), 
reduction_dim_size);
+    return Multiply(zero_point_mul, const_scale);
+  } else {
+    LOG(FATAL) << "Tensor zero point (non-scalar) is not supported";
+    return Expr();
+  }
 }
 
 Expr BatchMatmulFourthTerm(int x_zero_point_int, int y_zero_point_int, int 
reduction_dim_size) {
@@ -175,27 +197,31 @@ Expr QnnBatchMatmulCanonicalize(const Attrs& attrs, const 
Array<Expr>& new_args,
 
   const auto* qnn_batch_matmul_attrs = attrs.as<BatchMatmulAttrs>();
 
-  // Extract the integer zero points.
-  auto y_zero_point_int = GetScalarFromConstant<int>(y_zero_point);
-  auto x_zero_point_int = GetScalarFromConstant<int>(x_zero_point);
-
   // Get all the terms as described in the comments.
   auto term1 = BatchMatmulFirstTerm(quantized_x, quantized_y, 
qnn_batch_matmul_attrs);
   auto term2 = BatchMatmulSecondTerm(quantized_x, y_zero_point);
   auto term3 = BatchMatmulThirdTerm(quantized_y, x_zero_point, 
broadcast_dim_size);
-  auto term4 = BatchMatmulFourthTerm(x_zero_point_int, y_zero_point_int, 
reduction_dim_size);
-
-  // Combine those 4 terms depending on the zero points to get the best 
lowering.
-  if (x_zero_point_int == 0 && y_zero_point_int == 0) {
-    // term 2, 3 and 4 become zero.
-    return term1;
-  } else if (x_zero_point_int == 0 && y_zero_point_int != 0) {
-    // term 3 and term 4 become zero.
-    return Subtract(term1, term2);
-  } else if (x_zero_point_int != 0 && y_zero_point_int == 0) {
-    // term 2 and term 4 become zero.
-    return Subtract(term1, term3);
+
+  if (IsConstScalar(x_zero_point) && IsConstScalar(y_zero_point)) {
+    // Extract the integer zero points.
+    auto y_zero_point_int = GetScalarFromConstant<int>(y_zero_point);
+    auto x_zero_point_int = GetScalarFromConstant<int>(x_zero_point);
+    auto term4 = BatchMatmulFourthTerm(x_zero_point_int, y_zero_point_int, 
reduction_dim_size);
+    // Combine those 4 terms depending on the zero points to get the best 
lowering.
+    if (x_zero_point_int == 0 && y_zero_point_int == 0) {
+      // term 2, 3 and 4 become zero.
+      return term1;
+    } else if (x_zero_point_int == 0 && y_zero_point_int != 0) {
+      // term 3 and term 4 become zero.
+      return Subtract(term1, term2);
+    } else if (x_zero_point_int != 0 && y_zero_point_int == 0) {
+      // term 2 and term 4 become zero.
+      return Subtract(term1, term3);
+    } else {
+      return BatchMatmulCombineTerms(term1, term2, term3, term4);
+    }
   } else {
+    auto term4 = BatchMatmulFourthTerm(x_zero_point, y_zero_point, 
reduction_dim_size);
     return BatchMatmulCombineTerms(term1, term2, term3, term4);
   }
 }

Reply via email to