This is an automated email from the ASF dual-hosted git repository.
andrewzhaoluo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c7d7164c42 [QNN] support zero points as variable scalar for QnnBatchMatMul op (#13469)
c7d7164c42 is described below
commit c7d7164c421ead653d1a300a9610e8d9da14722b
Author: Valery Chernov <[email protected]>
AuthorDate: Thu Dec 1 21:18:57 2022 +0300
[QNN] support zero points as variable scalar for QnnBatchMatMul op (#13469)
* support zero points as variable scalar
* lint fix
* error logging was added for unsupported case when zero point is N-d tensor
* fix misprinting
* remove unnecessary TODOs
Co-authored-by: Valery Chernov <[email protected]>
---
src/relay/qnn/op/batch_matmul.cc | 78 ++++++++++++++++++++++++++--------------
1 file changed, 52 insertions(+), 26 deletions(-)
diff --git a/src/relay/qnn/op/batch_matmul.cc b/src/relay/qnn/op/batch_matmul.cc
index be5a314e80..a948d9387d 100644
--- a/src/relay/qnn/op/batch_matmul.cc
+++ b/src/relay/qnn/op/batch_matmul.cc
@@ -96,20 +96,42 @@ Expr BatchMatmulFirstTerm(const Expr& quantized_x, const Expr& quantized_y,
}
Expr BatchMatmulSecondTerm(const Expr& x_quantized_data, const Expr& y_zero_point) {
- Array<Integer> axes = {2};
- return Multiply(y_zero_point, Sum(Cast(x_quantized_data, DataType::Int(32)), axes, true, false));
+ if (IsScalar(y_zero_point)) {
+ Array<Integer> axes = {2};
+ return Multiply(y_zero_point,
+ Sum(Cast(x_quantized_data, DataType::Int(32)), axes, true, false));
+ } else {
+ LOG(FATAL) << "Tensor zero point (non-scalar) is not supported";
+ return Expr();
+ }
}
Expr BatchMatmulThirdTerm(const Expr& y_quantized_data, const Expr& x_zero_point,
int broadcast_dim_size) {
- Array<Integer> axes = {2};
- auto reducemult =
- Multiply(x_zero_point, Sum(Cast(y_quantized_data, DataType::Int(32)), axes, true, false));
- Array<Integer> newshape;
-
- // dimension of 0 in reshape copies old dimension size
- newshape = {0, 1, broadcast_dim_size};
- return Reshape(reducemult, newshape);
+ if (IsScalar(x_zero_point)) {
+ Array<Integer> axes = {2};
+ auto reducemult =
+ Multiply(x_zero_point, Sum(Cast(y_quantized_data, DataType::Int(32)), axes, true, false));
+ Array<Integer> newshape;
+
+ // dimension of 0 in reshape copies old dimension size
+ newshape = {0, 1, broadcast_dim_size};
+ return Reshape(reducemult, newshape);
+ } else {
+ LOG(FATAL) << "Tensor zero point (non-scalar) is not supported";
+ return Expr();
+ }
+}
+
+Expr BatchMatmulFourthTerm(Expr x_zero_point, Expr y_zero_point, int reduction_dim_size) {
+ if (IsScalar(x_zero_point) && IsScalar(y_zero_point)) {
+ auto zero_point_mul = Multiply(x_zero_point, y_zero_point);
+ auto const_scale = MakeConstantScalar(DataType::Int(32), reduction_dim_size);
+ return Multiply(zero_point_mul, const_scale);
+ } else {
+ LOG(FATAL) << "Tensor zero point (non-scalar) is not supported";
+ return Expr();
+ }
}
Expr BatchMatmulFourthTerm(int x_zero_point_int, int y_zero_point_int, int reduction_dim_size) {
@@ -175,27 +197,31 @@ Expr QnnBatchMatmulCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
const auto* qnn_batch_matmul_attrs = attrs.as<BatchMatmulAttrs>();
- // Extract the integer zero points.
- auto y_zero_point_int = GetScalarFromConstant<int>(y_zero_point);
- auto x_zero_point_int = GetScalarFromConstant<int>(x_zero_point);
-
// Get all the terms as described in the comments.
auto term1 = BatchMatmulFirstTerm(quantized_x, quantized_y, qnn_batch_matmul_attrs);
auto term2 = BatchMatmulSecondTerm(quantized_x, y_zero_point);
auto term3 = BatchMatmulThirdTerm(quantized_y, x_zero_point, broadcast_dim_size);
- auto term4 = BatchMatmulFourthTerm(x_zero_point_int, y_zero_point_int, reduction_dim_size);
-
- // Combine those 4 terms depending on the zero points to get the best lowering.
- if (x_zero_point_int == 0 && y_zero_point_int == 0) {
- // term 2, 3 and 4 become zero.
- return term1;
- } else if (x_zero_point_int == 0 && y_zero_point_int != 0) {
- // term 3 and term 4 become zero.
- return Subtract(term1, term2);
- } else if (x_zero_point_int != 0 && y_zero_point_int == 0) {
- // term 2 and term 4 become zero.
- return Subtract(term1, term3);
+
+ if (IsConstScalar(x_zero_point) && IsConstScalar(y_zero_point)) {
+ // Extract the integer zero points.
+ auto y_zero_point_int = GetScalarFromConstant<int>(y_zero_point);
+ auto x_zero_point_int = GetScalarFromConstant<int>(x_zero_point);
+ auto term4 = BatchMatmulFourthTerm(x_zero_point_int, y_zero_point_int, reduction_dim_size);
+ // Combine those 4 terms depending on the zero points to get the best lowering.
+ if (x_zero_point_int == 0 && y_zero_point_int == 0) {
+ // term 2, 3 and 4 become zero.
+ return term1;
+ } else if (x_zero_point_int == 0 && y_zero_point_int != 0) {
+ // term 3 and term 4 become zero.
+ return Subtract(term1, term2);
+ } else if (x_zero_point_int != 0 && y_zero_point_int == 0) {
+ // term 2 and term 4 become zero.
+ return Subtract(term1, term3);
+ } else {
+ return BatchMatmulCombineTerms(term1, term2, term3, term4);
+ }
} else {
+ auto term4 = BatchMatmulFourthTerm(x_zero_point, y_zero_point, reduction_dim_size);
return BatchMatmulCombineTerms(term1, term2, term3, term4);
}
}