ymwangg commented on a change in pull request #8440:
URL: https://github.com/apache/tvm/pull/8440#discussion_r668122649



##########
File path: python/tvm/relay/frontend/onnx.py
##########
@@ -670,68 +670,67 @@ class MatMul(OnnxOpConverter):
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         assert len(inputs) == 2, "MatMul op take 2 inputs, {} given".format(len(inputs))
+        inputs_0 = inputs[0]
+        inputs_1 = inputs[1]
+
         # Need to check input shape as batch matmul must be supported.
-        a_shape = shape_of(inputs[0])
-        a_rank = infer_shape(a_shape)[0]
-        b_shape = shape_of(inputs[1])
-        b_rank = infer_shape(b_shape)[0]
-        # When performing a batch matmul, we need to properly handle N-dim shapes.
-        if a_rank > 2 or b_rank > 2:
+        a_shape = infer_shape(inputs_0)
+        b_shape = infer_shape(inputs_1)
 
-            def flatten_to_3d(x, x_shape):
-                ndims = infer_shape(x_shape)[0]
-                newshape = _op.concatenate(
-                    [
-                        _expr.const([-1], dtype=infer_type(x_shape).checked_type.dtype),
-                        _op.strided_slice(x_shape, [ndims - 2], [ndims]),
-                    ],
-                    0,
-                )
-                out = _op.reshape(x, fold_constant(newshape))
-                return out
+        # When performing a batch matmul, we need to properly handle N-dim shapes.
+        if len(a_shape) > 2 and len(b_shape) > 2:
+            # Convert a into a 3 dimensional tensors.
+            need_reshape_output = False
+            if len(a_shape) != 3:
+                a = _op.reshape(inputs_0, [-1, a_shape[-2], a_shape[-1]])
+                need_reshape_output = True
+            else:
+                a = inputs_0
 
-            # Convert a and b into 3 dimensional tensors.
-            a = flatten_to_3d(inputs[0], a_shape)
-            b = flatten_to_3d(inputs[1], b_shape)
             # Transpose matrix dimensions of b.
-            b = _op.transpose(b, [0, 2, 1])
+            trans_axes = list(range(len(b_shape)))
+            trans_axes[-2], trans_axes[-1] = trans_axes[-1], trans_axes[-2]
+            b = _op.transpose(inputs_1, trans_axes)
+
+            # Convert b into a 3 dimensional tensor. Note that the last two dimensions
+            # are transposed.
+            if len(b_shape) != 3:
+                b = _op.reshape(b, [-1, b_shape[-1], b_shape[-2]])
+
             # Perform a batch matmul.
             output = _op.nn.batch_matmul(a, b)
-            # Determine the output batch dimension.
-            if a_rank > b_rank:
-                out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2])
-            elif a_rank < b_rank:
-                out_batch = _op.strided_slice(b_shape, [0], [b_rank - 2])
-            # If its unclear how broadcasting should be applied, the output
-            # shape is determined by choosing the maximum value from each input.
-            else:
-                out_batch = _op.concatenate(
-                    [
-                        _op.maximum(
-                            _op.strided_slice(a_shape, [i], [i + 1]),
-                            _op.strided_slice(b_shape, [i], [i + 1]),
-                        )
-                        for i in range(a_rank - 2)
-                    ],
-                    0,
-                )
+
             # Reshape output to original dimensions.
-            final_shape = _op.concatenate(
-                [
-                    out_batch,
-                    _op.strided_slice(
-                        a_shape, [infer_shape(a_shape)[0] - 2], [infer_shape(a_shape)[0] - 1]
-                    ),
-                    _op.strided_slice(
-                        b_shape, [infer_shape(b_shape)[0] - 1], [infer_shape(b_shape)[0]]
-                    ),
-                ],
-                0,
+            if need_reshape_output:
+                return _op.reshape(output, [*a_shape[:-2], a_shape[-2], b_shape[-1]])
+            return output
+        elif len(a_shape) > 2:
+            inputs_0 = _op.reshape(inputs_0, [-1, a_shape[-1]])

Review comment:
       thanks

##########
File path: python/tvm/relay/frontend/onnx.py
##########
@@ -670,68 +670,67 @@ class MatMul(OnnxOpConverter):
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         assert len(inputs) == 2, "MatMul op take 2 inputs, {} given".format(len(inputs))
+        inputs_0 = inputs[0]
+        inputs_1 = inputs[1]
+
         # Need to check input shape as batch matmul must be supported.
-        a_shape = shape_of(inputs[0])
-        a_rank = infer_shape(a_shape)[0]
-        b_shape = shape_of(inputs[1])
-        b_rank = infer_shape(b_shape)[0]
-        # When performing a batch matmul, we need to properly handle N-dim shapes.
-        if a_rank > 2 or b_rank > 2:
+        a_shape = infer_shape(inputs_0)
+        b_shape = infer_shape(inputs_1)
 
-            def flatten_to_3d(x, x_shape):
-                ndims = infer_shape(x_shape)[0]
-                newshape = _op.concatenate(
-                    [
-                        _expr.const([-1], dtype=infer_type(x_shape).checked_type.dtype),
-                        _op.strided_slice(x_shape, [ndims - 2], [ndims]),
-                    ],
-                    0,
-                )
-                out = _op.reshape(x, fold_constant(newshape))
-                return out
+        # When performing a batch matmul, we need to properly handle N-dim shapes.
+        if len(a_shape) > 2 and len(b_shape) > 2:
+            # Convert a into a 3 dimensional tensors.
+            need_reshape_output = False
+            if len(a_shape) != 3:
+                a = _op.reshape(inputs_0, [-1, a_shape[-2], a_shape[-1]])
+                need_reshape_output = True
+            else:
+                a = inputs_0
 
-            # Convert a and b into 3 dimensional tensors.
-            a = flatten_to_3d(inputs[0], a_shape)
-            b = flatten_to_3d(inputs[1], b_shape)
             # Transpose matrix dimensions of b.
-            b = _op.transpose(b, [0, 2, 1])
+            trans_axes = list(range(len(b_shape)))
+            trans_axes[-2], trans_axes[-1] = trans_axes[-1], trans_axes[-2]
+            b = _op.transpose(inputs_1, trans_axes)
+
+            # Convert b into a 3 dimensional tensor. Note that the last two dimensions
+            # are transposed.
+            if len(b_shape) != 3:
+                b = _op.reshape(b, [-1, b_shape[-1], b_shape[-2]])
+
             # Perform a batch matmul.
             output = _op.nn.batch_matmul(a, b)
-            # Determine the output batch dimension.
-            if a_rank > b_rank:
-                out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2])
-            elif a_rank < b_rank:
-                out_batch = _op.strided_slice(b_shape, [0], [b_rank - 2])
-            # If its unclear how broadcasting should be applied, the output
-            # shape is determined by choosing the maximum value from each input.
-            else:
-                out_batch = _op.concatenate(
-                    [
-                        _op.maximum(
-                            _op.strided_slice(a_shape, [i], [i + 1]),
-                            _op.strided_slice(b_shape, [i], [i + 1]),
-                        )
-                        for i in range(a_rank - 2)
-                    ],
-                    0,
-                )
+
             # Reshape output to original dimensions.
-            final_shape = _op.concatenate(
-                [
-                    out_batch,
-                    _op.strided_slice(
-                        a_shape, [infer_shape(a_shape)[0] - 2], [infer_shape(a_shape)[0] - 1]
-                    ),
-                    _op.strided_slice(
-                        b_shape, [infer_shape(b_shape)[0] - 1], [infer_shape(b_shape)[0]]
-                    ),
-                ],
-                0,
+            if need_reshape_output:
+                return _op.reshape(output, [*a_shape[:-2], a_shape[-2], b_shape[-1]])
+            return output
+        elif len(a_shape) > 2:
+            inputs_0 = _op.reshape(inputs_0, [-1, a_shape[-1]])

Review comment:
       Thanks!




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to