ymwangg commented on a change in pull request #8440:
URL: https://github.com/apache/tvm/pull/8440#discussion_r667278362



##########
File path: python/tvm/relay/frontend/onnx.py
##########
@@ -670,68 +670,67 @@ class MatMul(OnnxOpConverter):
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         assert len(inputs) == 2, "MatMul op take 2 inputs, {} 
given".format(len(inputs))
+        inputs_0 = inputs[0]
+        inputs_1 = inputs[1]
+
         # Need to check input shape as batch matmul must be supported.
-        a_shape = shape_of(inputs[0])
-        a_rank = infer_shape(a_shape)[0]
-        b_shape = shape_of(inputs[1])
-        b_rank = infer_shape(b_shape)[0]
-        # When performing a batch matmul, we need to properly handle N-dim 
shapes.
-        if a_rank > 2 or b_rank > 2:
+        a_shape = infer_shape(inputs_0)
+        b_shape = infer_shape(inputs_1)
 
-            def flatten_to_3d(x, x_shape):
-                ndims = infer_shape(x_shape)[0]
-                newshape = _op.concatenate(
-                    [
-                        _expr.const([-1], 
dtype=infer_type(x_shape).checked_type.dtype),
-                        _op.strided_slice(x_shape, [ndims - 2], [ndims]),
-                    ],
-                    0,
-                )
-                out = _op.reshape(x, fold_constant(newshape))
-                return out
+        # When performing a batch matmul, we need to properly handle N-dim 
shapes.
+        if len(a_shape) > 2 and len(b_shape) > 2:
+            # Convert a into a 3 dimensional tensors.
+            need_reshape_output = False
+            if len(a_shape) != 3:
+                a = _op.reshape(inputs_0, [-1, a_shape[-2], a_shape[-1]])
+                need_reshape_output = True
+            else:
+                a = inputs_0
 
-            # Convert a and b into 3 dimensional tensors.
-            a = flatten_to_3d(inputs[0], a_shape)
-            b = flatten_to_3d(inputs[1], b_shape)
             # Transpose matrix dimensions of b.
-            b = _op.transpose(b, [0, 2, 1])
+            trans_axes = list(range(len(b_shape)))
+            trans_axes[-2], trans_axes[-1] = trans_axes[-1], trans_axes[-2]
+            b = _op.transpose(inputs_1, trans_axes)
+
+            # Convert b into a 3 dimensional tensor. Note that the last two 
dimensions
+            # are transposed.
+            if len(b_shape) != 3:
+                b = _op.reshape(b, [-1, b_shape[-1], b_shape[-2]])
+
             # Perform a batch matmul.
             output = _op.nn.batch_matmul(a, b)
-            # Determine the output batch dimension.
-            if a_rank > b_rank:
-                out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2])
-            elif a_rank < b_rank:
-                out_batch = _op.strided_slice(b_shape, [0], [b_rank - 2])
-            # If its unclear how broadcasting should be applied, the output
-            # shape is determined by choosing the maximum value from each 
input.
-            else:
-                out_batch = _op.concatenate(
-                    [
-                        _op.maximum(
-                            _op.strided_slice(a_shape, [i], [i + 1]),
-                            _op.strided_slice(b_shape, [i], [i + 1]),
-                        )
-                        for i in range(a_rank - 2)
-                    ],
-                    0,
-                )
+
             # Reshape output to original dimensions.
-            final_shape = _op.concatenate(
-                [
-                    out_batch,
-                    _op.strided_slice(
-                        a_shape, [infer_shape(a_shape)[0] - 2], 
[infer_shape(a_shape)[0] - 1]
-                    ),
-                    _op.strided_slice(
-                        b_shape, [infer_shape(b_shape)[0] - 1], 
[infer_shape(b_shape)[0]]
-                    ),
-                ],
-                0,
+            if need_reshape_output:
+                return _op.reshape(output, [*a_shape[:-2], a_shape[-2], 
b_shape[-1]])
+            return output
+        elif len(a_shape) > 2:
+            inputs_0 = _op.reshape(inputs_0, [-1, a_shape[-1]])
+
+        if len(b_shape) > 2:
+            trans_axes = list(range(len(b_shape)))
+            trans_axes[-2], trans_axes[-1] = trans_axes[-1], trans_axes[-2]
+            input_1 = _op.reshape(_op.transpose(inputs_1, trans_axes), [-1, 
b_shape[-2]])
+        elif len(b_shape) == 2:
+            input_1 = _op.transpose(inputs_1, axes=(1, 0))
+        elif len(b_shape) == 1:
+            input_1 = _op.expand_dims(inputs_1, 0, 1)
+
+        out = _op.nn.dense(inputs_0, input_1)
+
+        if len(b_shape) == 1:
+            out = _op.squeeze(out, axis=[-1])
+
+        # Reshape output into a N dimensional tensor when a or b dim > 2
+        if len(a_shape) > 2:
+            out = _op.reshape(out, [*a_shape[:-1], b_shape[-1]])
+        elif len(b_shape) > 2:
+            out = _op.reshape(out, [a_shape[-2], -1, b_shape[-1]])
+            out = _op.reshape(
+                _op.transpose(out, [1, 0, 2]), [*b_shape[:-2], a_shape[-2], 
b_shape[-1]]
             )

Review comment:
       I switched to the original ONNX implementation as I found some issues in 
the PyTorch implementation.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to