This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a commit to branch aluo/fix-qnn-batch-matmul
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 24baa3e8e2d4fc89422866e3a5d9ebfcac777b47
Author: Andrew Zhao Luo <[email protected]>
AuthorDate: Wed Aug 3 17:12:12 2022 -0700

    initial commit
---
 src/relay/qnn/op/batch_matmul.cc               | 20 ++++++--
 tests/python/relay/test_op_qnn_batch_matmul.py | 64 ++++++++++++++++++--------
 2 files changed, 60 insertions(+), 24 deletions(-)

diff --git a/src/relay/qnn/op/batch_matmul.cc b/src/relay/qnn/op/batch_matmul.cc
index 4b0bcacaca..be5a314e80 100644
--- a/src/relay/qnn/op/batch_matmul.cc
+++ b/src/relay/qnn/op/batch_matmul.cc
@@ -106,7 +106,9 @@ Expr BatchMatmulThirdTerm(const Expr& y_quantized_data, 
const Expr& x_zero_point
   auto reducemult =
       Multiply(x_zero_point, Sum(Cast(y_quantized_data, DataType::Int(32)), 
axes, true, false));
   Array<Integer> newshape;
-  newshape = {1, 1, broadcast_dim_size};
+
+  // dimension of 0 in reshape copies old dimension size
+  newshape = {0, 1, broadcast_dim_size};
   return Reshape(reducemult, newshape);
 }
 
@@ -199,10 +201,18 @@ Expr QnnBatchMatmulCanonicalize(const Attrs& attrs, const 
Array<Expr>& new_args,
 }
 
 RELAY_REGISTER_OP("qnn.batch_matmul")
-    .describe(R"code(Applies a linear transformation: :math:`Z = XY`.
-- **data**: quantized(int8, unit8) `(x1, x2, ..., xn, input_dim)`
-- **weight**: quantized(int8, unit8) `(units, input_dim)`
-- **out**: quantized(int32) `(x1, x2, ..., xn, units)`.
+    .describe(R"code(Compute batch matrix multiplication of `tensor_a` and 
`tensor_b`.
+
+Note we expect tensor_b to be transposed to copy the standard nn.batch_matmul conventions.
+
+.. math::
+
+  batch\_matmul(A, B)[i, :, :] = matmul(A[i, :, :], B[i, :, :]^T)
+
+- **data**: quantized(int8, uint8) `(i, m, k)`
+- **weight**: quantized(int8, uint8) `(i, n, k)`
+- **out**: quantized(int32) `(i, m, n)`.
+
 )code" TVM_ADD_FILELINE)
     .set_attrs_type<BatchMatmulAttrs>()
     .set_num_inputs(6)
diff --git a/tests/python/relay/test_op_qnn_batch_matmul.py 
b/tests/python/relay/test_op_qnn_batch_matmul.py
index 91648aca3d..8e0d962352 100644
--- a/tests/python/relay/test_op_qnn_batch_matmul.py
+++ b/tests/python/relay/test_op_qnn_batch_matmul.py
@@ -71,9 +71,13 @@ def make_configuration(
 
 
 def make_int_configuration(
-    xzero_point_zero=True, yzero_point_zero=True, requantize_output=False, 
per_channel=False
+    xzero_point_zero=True,
+    yzero_point_zero=True,
+    requantize_output=False,
+    per_channel=False,
+    batch_size=1,
 ):
-    x_shape, y_shape, output_shape = (1, 4, 5), (1, 3, 5), (1, 4, 3)
+    x_shape, y_shape, output_shape = (batch_size, 4, 5), (batch_size, 3, 5), 
(batch_size, 4, 3)
     if xzero_point_zero == True:
         x_zero_point = 0
     else:
@@ -86,6 +90,7 @@ def make_int_configuration(
 
     in_dtype = "int8"
     out_dtype = "int32" if not requantize_output else "int8"
+
     quantized_x_np = (
         np.array(
             [
@@ -110,12 +115,16 @@ def make_int_configuration(
                 17,
                 -21,
             ]
-        )  # sum = 3
+        )[  # sum = 3
+            np.newaxis, np.newaxis, :
+        ]
+        .repeat(batch_size, axis=1)
         .astype(in_dtype)
         .reshape(x_shape)
     )
     quantized_y_np = (
-        np.array([1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 1, 3, 5, 7, 9])
+        np.array([1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 1, 3, 5, 7, 
9])[np.newaxis, np.newaxis, :]
+        .repeat(batch_size, axis=1)
         .astype(in_dtype)
         .reshape(y_shape)
     )
@@ -143,8 +152,13 @@ def make_int_configuration(
         if requantize_output
         else None
     )
-
-    output = output.astype(out_dtype).reshape(output_shape)
+    # Outputs are for batch size 1, make batch size n version
+    output = (
+        output[np.newaxis, np.newaxis, :]
+        .repeat(batch_size, axis=1)
+        .astype(out_dtype)
+        .reshape(output_shape)
+    )
     return make_configuration(
         quantized_x=quantized_x_np,
         quantized_y=quantized_y_np,
@@ -206,37 +220,49 @@ def qnn_batch_matmul_driver(test_configuration):
 
 def test_qnn_batch_matmul_xzp0_yzp0():
     with TempOpAttr("qnn.batch_matmul", "FTVMQnnLegalize", 
legalize_qnn_batch_matmul):
-
-        int32_output_params = make_int_configuration(xzero_point_zero=True, 
yzero_point_zero=True)
-        qnn_batch_matmul_driver(int32_output_params)
+        for batch_size in [1, 4, 7]:
+            int32_output_params = make_int_configuration(
+                xzero_point_zero=True, yzero_point_zero=True, 
batch_size=batch_size
+            )
+            qnn_batch_matmul_driver(int32_output_params)
 
 
 def test_qnn_batch_matmul_xzp0():
     with TempOpAttr("qnn.batch_matmul", "FTVMQnnLegalize", 
legalize_qnn_batch_matmul):
-
-        int32_output_params = make_int_configuration(xzero_point_zero=True, 
yzero_point_zero=False)
-        qnn_batch_matmul_driver(int32_output_params)
+        for batch_size in [1, 4, 7]:
+            int32_output_params = make_int_configuration(
+                xzero_point_zero=True, yzero_point_zero=False, 
batch_size=batch_size
+            )
+            qnn_batch_matmul_driver(int32_output_params)
 
 
 def test_qnn_batch_matmul_yzp0():
     with TempOpAttr("qnn.batch_matmul", "FTVMQnnLegalize", 
legalize_qnn_batch_matmul):
 
-        int32_output_params = make_int_configuration(xzero_point_zero=False, 
yzero_point_zero=True)
-        qnn_batch_matmul_driver(int32_output_params)
+        for batch_size in [1, 4, 7]:
+            int32_output_params = make_int_configuration(
+                xzero_point_zero=False, yzero_point_zero=True, 
batch_size=batch_size
+            )
+            qnn_batch_matmul_driver(int32_output_params)
 
 
 def test_qnn_batch_matmul():
     with TempOpAttr("qnn.batch_matmul", "FTVMQnnLegalize", 
legalize_qnn_batch_matmul):
+        for batch_size in [1, 4, 7]:
 
-        int32_output_params = make_int_configuration(xzero_point_zero=False, 
yzero_point_zero=False)
-        qnn_batch_matmul_driver(int32_output_params)
+            int32_output_params = make_int_configuration(
+                xzero_point_zero=False, yzero_point_zero=False, 
batch_size=batch_size
+            )
+            qnn_batch_matmul_driver(int32_output_params)
 
 
 def test_qnn_batch_matmul_with_requantized_output():
     with TempOpAttr("qnn.dense", "FTVMQnnLegalize", legalize_qnn_batch_matmul):
-
-        int8_requantized_output_params = 
make_int_configuration(requantize_output=True)
-        qnn_batch_matmul_driver(int8_requantized_output_params)
+        for batch_size in [1, 4, 7]:
+            int8_requantized_output_params = make_int_configuration(
+                requantize_output=True, batch_size=batch_size
+            )
+            qnn_batch_matmul_driver(int8_requantized_output_params)
 
 
 if __name__ == "__main__":

Reply via email to