comaniac commented on a change in pull request #8440:
URL: https://github.com/apache/tvm/pull/8440#discussion_r667245328



##########
File path: python/tvm/relay/frontend/onnx.py
##########
@@ -670,68 +670,67 @@ class MatMul(OnnxOpConverter):
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         assert len(inputs) == 2, "MatMul op take 2 inputs, {} 
given".format(len(inputs))
+        inputs_0 = inputs[0]
+        inputs_1 = inputs[1]
+
         # Need to check input shape as batch matmul must be supported.
-        a_shape = shape_of(inputs[0])
-        a_rank = infer_shape(a_shape)[0]
-        b_shape = shape_of(inputs[1])
-        b_rank = infer_shape(b_shape)[0]
-        # When performing a batch matmul, we need to properly handle N-dim 
shapes.
-        if a_rank > 2 or b_rank > 2:
+        a_shape = infer_shape(inputs_0)
+        b_shape = infer_shape(inputs_1)
 
-            def flatten_to_3d(x, x_shape):
-                ndims = infer_shape(x_shape)[0]
-                newshape = _op.concatenate(
-                    [
-                        _expr.const([-1], 
dtype=infer_type(x_shape).checked_type.dtype),
-                        _op.strided_slice(x_shape, [ndims - 2], [ndims]),
-                    ],
-                    0,
-                )
-                out = _op.reshape(x, fold_constant(newshape))
-                return out
+        # When performing a batch matmul, we need to properly handle N-dim 
shapes.
+        if len(a_shape) > 2 and len(b_shape) > 2:
+            # Convert a into a 3 dimensional tensors.
+            need_reshape_output = False
+            if len(a_shape) != 3:
+                a = _op.reshape(inputs_0, [-1, a_shape[-2], a_shape[-1]])
+                need_reshape_output = True
+            else:
+                a = inputs_0
 
-            # Convert a and b into 3 dimensional tensors.
-            a = flatten_to_3d(inputs[0], a_shape)
-            b = flatten_to_3d(inputs[1], b_shape)
             # Transpose matrix dimensions of b.
-            b = _op.transpose(b, [0, 2, 1])
+            trans_axes = list(range(len(b_shape)))
+            trans_axes[-2], trans_axes[-1] = trans_axes[-1], trans_axes[-2]
+            b = _op.transpose(inputs_1, trans_axes)
+
+            # Convert b into a 3 dimensional tensor. Note that the last two 
dimensions
+            # are transposed.
+            if len(b_shape) != 3:
+                b = _op.reshape(b, [-1, b_shape[-1], b_shape[-2]])
+
             # Perform a batch matmul.
             output = _op.nn.batch_matmul(a, b)
-            # Determine the output batch dimension.
-            if a_rank > b_rank:
-                out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2])
-            elif a_rank < b_rank:
-                out_batch = _op.strided_slice(b_shape, [0], [b_rank - 2])
-            # If its unclear how broadcasting should be applied, the output
-            # shape is determined by choosing the maximum value from each 
input.
-            else:
-                out_batch = _op.concatenate(
-                    [
-                        _op.maximum(
-                            _op.strided_slice(a_shape, [i], [i + 1]),
-                            _op.strided_slice(b_shape, [i], [i + 1]),
-                        )
-                        for i in range(a_rank - 2)
-                    ],
-                    0,
-                )
+
             # Reshape output to original dimensions.
-            final_shape = _op.concatenate(
-                [
-                    out_batch,
-                    _op.strided_slice(
-                        a_shape, [infer_shape(a_shape)[0] - 2], 
[infer_shape(a_shape)[0] - 1]
-                    ),
-                    _op.strided_slice(
-                        b_shape, [infer_shape(b_shape)[0] - 1], 
[infer_shape(b_shape)[0]]
-                    ),
-                ],
-                0,
+            if need_reshape_output:
+                return _op.reshape(output, [*a_shape[:-2], a_shape[-2], 
b_shape[-1]])
+            return output
+        elif len(a_shape) > 2:
+            inputs_0 = _op.reshape(inputs_0, [-1, a_shape[-1]])

Review comment:
       Use a plain `if` here instead of `elif`; then you can get rid of the no-else-return error.
   ```suggestion
   
           if len(a_shape) > 2:
               inputs_0 = _op.reshape(inputs_0, [-1, a_shape[-1]])
   ```

##########
File path: python/tvm/relay/frontend/onnx.py
##########
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=invalid-name, import-self, len-as-condition, 
unused-argument, too-many-lines
-# pylint: disable=import-outside-toplevel
+# pylint: disable=import-outside-toplevel, no-else-return

Review comment:
       Do not disable a lint rule like this one.

##########
File path: python/tvm/relay/frontend/onnx.py
##########
@@ -670,68 +670,67 @@ class MatMul(OnnxOpConverter):
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         assert len(inputs) == 2, "MatMul op take 2 inputs, {} 
given".format(len(inputs))
+        inputs_0 = inputs[0]
+        inputs_1 = inputs[1]
+
         # Need to check input shape as batch matmul must be supported.
-        a_shape = shape_of(inputs[0])
-        a_rank = infer_shape(a_shape)[0]
-        b_shape = shape_of(inputs[1])
-        b_rank = infer_shape(b_shape)[0]
-        # When performing a batch matmul, we need to properly handle N-dim 
shapes.
-        if a_rank > 2 or b_rank > 2:
+        a_shape = infer_shape(inputs_0)
+        b_shape = infer_shape(inputs_1)
 
-            def flatten_to_3d(x, x_shape):
-                ndims = infer_shape(x_shape)[0]
-                newshape = _op.concatenate(
-                    [
-                        _expr.const([-1], 
dtype=infer_type(x_shape).checked_type.dtype),
-                        _op.strided_slice(x_shape, [ndims - 2], [ndims]),
-                    ],
-                    0,
-                )
-                out = _op.reshape(x, fold_constant(newshape))
-                return out
+        # When performing a batch matmul, we need to properly handle N-dim 
shapes.
+        if len(a_shape) > 2 and len(b_shape) > 2:
+            # Convert a into a 3 dimensional tensors.
+            need_reshape_output = False
+            if len(a_shape) != 3:
+                a = _op.reshape(inputs_0, [-1, a_shape[-2], a_shape[-1]])
+                need_reshape_output = True
+            else:
+                a = inputs_0
 
-            # Convert a and b into 3 dimensional tensors.
-            a = flatten_to_3d(inputs[0], a_shape)
-            b = flatten_to_3d(inputs[1], b_shape)
             # Transpose matrix dimensions of b.
-            b = _op.transpose(b, [0, 2, 1])
+            trans_axes = list(range(len(b_shape)))
+            trans_axes[-2], trans_axes[-1] = trans_axes[-1], trans_axes[-2]
+            b = _op.transpose(inputs_1, trans_axes)
+
+            # Convert b into a 3 dimensional tensor. Note that the last two 
dimensions
+            # are transposed.
+            if len(b_shape) != 3:
+                b = _op.reshape(b, [-1, b_shape[-1], b_shape[-2]])
+
             # Perform a batch matmul.
             output = _op.nn.batch_matmul(a, b)
-            # Determine the output batch dimension.
-            if a_rank > b_rank:
-                out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2])
-            elif a_rank < b_rank:
-                out_batch = _op.strided_slice(b_shape, [0], [b_rank - 2])
-            # If its unclear how broadcasting should be applied, the output
-            # shape is determined by choosing the maximum value from each 
input.
-            else:
-                out_batch = _op.concatenate(
-                    [
-                        _op.maximum(
-                            _op.strided_slice(a_shape, [i], [i + 1]),
-                            _op.strided_slice(b_shape, [i], [i + 1]),
-                        )
-                        for i in range(a_rank - 2)
-                    ],
-                    0,
-                )
+
             # Reshape output to original dimensions.
-            final_shape = _op.concatenate(
-                [
-                    out_batch,
-                    _op.strided_slice(
-                        a_shape, [infer_shape(a_shape)[0] - 2], 
[infer_shape(a_shape)[0] - 1]
-                    ),
-                    _op.strided_slice(
-                        b_shape, [infer_shape(b_shape)[0] - 1], 
[infer_shape(b_shape)[0]]
-                    ),
-                ],
-                0,
+            if need_reshape_output:
+                return _op.reshape(output, [*a_shape[:-2], a_shape[-2], 
b_shape[-1]])
+            return output
+        elif len(a_shape) > 2:
+            inputs_0 = _op.reshape(inputs_0, [-1, a_shape[-1]])
+
+        if len(b_shape) > 2:
+            trans_axes = list(range(len(b_shape)))
+            trans_axes[-2], trans_axes[-1] = trans_axes[-1], trans_axes[-2]
+            input_1 = _op.reshape(_op.transpose(inputs_1, trans_axes), [-1, 
b_shape[-2]])
+        elif len(b_shape) == 2:
+            input_1 = _op.transpose(inputs_1, axes=(1, 0))
+        elif len(b_shape) == 1:
+            input_1 = _op.expand_dims(inputs_1, 0, 1)
+
+        out = _op.nn.dense(inputs_0, input_1)
+
+        if len(b_shape) == 1:
+            out = _op.squeeze(out, axis=[-1])
+
+        # Reshape output into a N dimensional tensor when a or b dim > 2
+        if len(a_shape) > 2:
+            out = _op.reshape(out, [*a_shape[:-1], b_shape[-1]])
+        elif len(b_shape) > 2:
+            out = _op.reshape(out, [a_shape[-2], -1, b_shape[-1]])
+            out = _op.reshape(
+                _op.transpose(out, [1, 0, 2]), [*b_shape[:-2], a_shape[-2], 
b_shape[-1]]
             )

Review comment:
       The pattern `reshape - transpose - reshape` cannot be optimized by 
follow-up passes. Can we use just one `reshape` and one `transpose`?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to