This is an automated email from the ASF dual-hosted git repository. andrewzhaoluo pushed a commit to branch aluo/fix-qnn-batch-matmul in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 24baa3e8e2d4fc89422866e3a5d9ebfcac777b47 Author: Andrew Zhao Luo <[email protected]> AuthorDate: Wed Aug 3 17:12:12 2022 -0700 initial commit --- src/relay/qnn/op/batch_matmul.cc | 20 ++++++-- tests/python/relay/test_op_qnn_batch_matmul.py | 64 ++++++++++++++++++-------- 2 files changed, 60 insertions(+), 24 deletions(-) diff --git a/src/relay/qnn/op/batch_matmul.cc b/src/relay/qnn/op/batch_matmul.cc index 4b0bcacaca..be5a314e80 100644 --- a/src/relay/qnn/op/batch_matmul.cc +++ b/src/relay/qnn/op/batch_matmul.cc @@ -106,7 +106,9 @@ Expr BatchMatmulThirdTerm(const Expr& y_quantized_data, const Expr& x_zero_point auto reducemult = Multiply(x_zero_point, Sum(Cast(y_quantized_data, DataType::Int(32)), axes, true, false)); Array<Integer> newshape; - newshape = {1, 1, broadcast_dim_size}; + + // dimension of 0 in reshape copies old dimension size + newshape = {0, 1, broadcast_dim_size}; return Reshape(reducemult, newshape); } @@ -199,10 +201,18 @@ Expr QnnBatchMatmulCanonicalize(const Attrs& attrs, const Array<Expr>& new_args, } RELAY_REGISTER_OP("qnn.batch_matmul") - .describe(R"code(Applies a linear transformation: :math:`Z = XY`. -- **data**: quantized(int8, unit8) `(x1, x2, ..., xn, input_dim)` -- **weight**: quantized(int8, unit8) `(units, input_dim)` -- **out**: quantized(int32) `(x1, x2, ..., xn, units)`. + .describe(R"code(Compute batch matrix multiplication of `tensor_a` and `tensor_b`. + +Note we expect tensor_b to be transposed to copy the standard nn.batch_matmul conventions. + +.. math:: + + batch\_matmul(A, B)[i, :, :] = matmul(A[i, :, :], B[i, :, :]^T) + +- **data**: quantized(int8, unit8) `(i, m, k)` +- **weight**: quantized(int8, unit8) `(i, n, k)` +- **out**: quantized(int32) `(i, m, n)`. + )code" TVM_ADD_FILELINE) .set_attrs_type<BatchMatmulAttrs>() .set_num_inputs(6) diff --git a/tests/python/relay/test_op_qnn_batch_matmul.py b/tests/python/relay/test_op_qnn_batch_matmul.py index 91648aca3d..8e0d962352 100644 --- a/tests/python/relay/test_op_qnn_batch_matmul.py +++ b/tests/python/relay/test_op_qnn_batch_matmul.py @@ -71,9 +71,13 @@ def make_configuration( def make_int_configuration( - xzero_point_zero=True, yzero_point_zero=True, requantize_output=False, per_channel=False + xzero_point_zero=True, + yzero_point_zero=True, + requantize_output=False, + per_channel=False, + batch_size=1, ): - x_shape, y_shape, output_shape = (1, 4, 5), (1, 3, 5), (1, 4, 3) + x_shape, y_shape, output_shape = (batch_size, 4, 5), (batch_size, 3, 5), (batch_size, 4, 3) if xzero_point_zero == True: x_zero_point = 0 else: @@ -86,6 +90,7 @@ def make_int_configuration( in_dtype = "int8" out_dtype = "int32" if not requantize_output else "int8" + quantized_x_np = ( np.array( [ @@ -110,12 +115,16 @@ def make_int_configuration( 17, -21, ] - ) # sum = 3 + )[ # sum = 3 + np.newaxis, np.newaxis, : + ] + .repeat(batch_size, axis=1) .astype(in_dtype) .reshape(x_shape) ) quantized_y_np = ( - np.array([1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 1, 3, 5, 7, 9]) + np.array([1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 1, 3, 5, 7, 9])[np.newaxis, np.newaxis, :] + .repeat(batch_size, axis=1) .astype(in_dtype) .reshape(y_shape) ) @@ -143,8 +152,13 @@ def make_int_configuration( if requantize_output else None ) - - output = output.astype(out_dtype).reshape(output_shape) + # Outputs are for batch size 1, make batch size n version + output = ( + output[np.newaxis, np.newaxis, :] + .repeat(batch_size, axis=1) + .astype(out_dtype) + .reshape(output_shape) + ) return make_configuration( quantized_x=quantized_x_np, quantized_y=quantized_y_np, @@ -206,37 +220,49 @@ def qnn_batch_matmul_driver(test_configuration): def test_qnn_batch_matmul_xzp0_yzp0(): with TempOpAttr("qnn.batch_matmul", "FTVMQnnLegalize", legalize_qnn_batch_matmul): - - int32_output_params = make_int_configuration(xzero_point_zero=True, yzero_point_zero=True) - qnn_batch_matmul_driver(int32_output_params) + for batch_size in [1, 4, 7]: + int32_output_params = make_int_configuration( + xzero_point_zero=True, yzero_point_zero=True, batch_size=batch_size + ) + qnn_batch_matmul_driver(int32_output_params) def test_qnn_batch_matmul_xzp0(): with TempOpAttr("qnn.batch_matmul", "FTVMQnnLegalize", legalize_qnn_batch_matmul): - - int32_output_params = make_int_configuration(xzero_point_zero=True, yzero_point_zero=False) - qnn_batch_matmul_driver(int32_output_params) + for batch_size in [1, 4, 7]: + int32_output_params = make_int_configuration( + xzero_point_zero=True, yzero_point_zero=False, batch_size=batch_size + ) + qnn_batch_matmul_driver(int32_output_params) def test_qnn_batch_matmul_yzp0(): with TempOpAttr("qnn.batch_matmul", "FTVMQnnLegalize", legalize_qnn_batch_matmul): - int32_output_params = make_int_configuration(xzero_point_zero=False, yzero_point_zero=True) - qnn_batch_matmul_driver(int32_output_params) + for batch_size in [1, 4, 7]: + int32_output_params = make_int_configuration( + xzero_point_zero=False, yzero_point_zero=True, batch_size=batch_size + ) + qnn_batch_matmul_driver(int32_output_params) def test_qnn_batch_matmul(): with TempOpAttr("qnn.batch_matmul", "FTVMQnnLegalize", legalize_qnn_batch_matmul): + for batch_size in [1, 4, 7]: - int32_output_params = make_int_configuration(xzero_point_zero=False, yzero_point_zero=False) - qnn_batch_matmul_driver(int32_output_params) + int32_output_params = make_int_configuration( + xzero_point_zero=False, yzero_point_zero=False, batch_size=batch_size + ) + qnn_batch_matmul_driver(int32_output_params) def test_qnn_batch_matmul_with_requantized_output(): with TempOpAttr("qnn.dense", "FTVMQnnLegalize", legalize_qnn_batch_matmul): - - int8_requantized_output_params = make_int_configuration(requantize_output=True) - qnn_batch_matmul_driver(int8_requantized_output_params) + for batch_size in [1, 4, 7]: + int8_requantized_output_params = make_int_configuration( + requantize_output=True, batch_size=batch_size + ) + qnn_batch_matmul_driver(int8_requantized_output_params) if __name__ == "__main__":
