comaniac commented on a change in pull request #8440:
URL: https://github.com/apache/tvm/pull/8440#discussion_r667245328
##########
File path: python/tvm/relay/frontend/onnx.py
##########
@@ -670,68 +670,67 @@ class MatMul(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
assert len(inputs) == 2, "MatMul op take 2 inputs, {}
given".format(len(inputs))
+ inputs_0 = inputs[0]
+ inputs_1 = inputs[1]
+
# Need to check input shape as batch matmul must be supported.
- a_shape = shape_of(inputs[0])
- a_rank = infer_shape(a_shape)[0]
- b_shape = shape_of(inputs[1])
- b_rank = infer_shape(b_shape)[0]
- # When performing a batch matmul, we need to properly handle N-dim
shapes.
- if a_rank > 2 or b_rank > 2:
+ a_shape = infer_shape(inputs_0)
+ b_shape = infer_shape(inputs_1)
- def flatten_to_3d(x, x_shape):
- ndims = infer_shape(x_shape)[0]
- newshape = _op.concatenate(
- [
- _expr.const([-1],
dtype=infer_type(x_shape).checked_type.dtype),
- _op.strided_slice(x_shape, [ndims - 2], [ndims]),
- ],
- 0,
- )
- out = _op.reshape(x, fold_constant(newshape))
- return out
+ # When performing a batch matmul, we need to properly handle N-dim
shapes.
+ if len(a_shape) > 2 and len(b_shape) > 2:
+ # Convert a into a 3 dimensional tensors.
+ need_reshape_output = False
+ if len(a_shape) != 3:
+ a = _op.reshape(inputs_0, [-1, a_shape[-2], a_shape[-1]])
+ need_reshape_output = True
+ else:
+ a = inputs_0
- # Convert a and b into 3 dimensional tensors.
- a = flatten_to_3d(inputs[0], a_shape)
- b = flatten_to_3d(inputs[1], b_shape)
# Transpose matrix dimensions of b.
- b = _op.transpose(b, [0, 2, 1])
+ trans_axes = list(range(len(b_shape)))
+ trans_axes[-2], trans_axes[-1] = trans_axes[-1], trans_axes[-2]
+ b = _op.transpose(inputs_1, trans_axes)
+
+ # Convert b into a 3 dimensional tensor. Note that the last two
dimensions
+ # are transposed.
+ if len(b_shape) != 3:
+ b = _op.reshape(b, [-1, b_shape[-1], b_shape[-2]])
+
# Perform a batch matmul.
output = _op.nn.batch_matmul(a, b)
- # Determine the output batch dimension.
- if a_rank > b_rank:
- out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2])
- elif a_rank < b_rank:
- out_batch = _op.strided_slice(b_shape, [0], [b_rank - 2])
- # If its unclear how broadcasting should be applied, the output
- # shape is determined by choosing the maximum value from each
input.
- else:
- out_batch = _op.concatenate(
- [
- _op.maximum(
- _op.strided_slice(a_shape, [i], [i + 1]),
- _op.strided_slice(b_shape, [i], [i + 1]),
- )
- for i in range(a_rank - 2)
- ],
- 0,
- )
+
# Reshape output to original dimensions.
- final_shape = _op.concatenate(
- [
- out_batch,
- _op.strided_slice(
- a_shape, [infer_shape(a_shape)[0] - 2],
[infer_shape(a_shape)[0] - 1]
- ),
- _op.strided_slice(
- b_shape, [infer_shape(b_shape)[0] - 1],
[infer_shape(b_shape)[0]]
- ),
- ],
- 0,
+ if need_reshape_output:
+ return _op.reshape(output, [*a_shape[:-2], a_shape[-2],
b_shape[-1]])
+ return output
+ elif len(a_shape) > 2:
+ inputs_0 = _op.reshape(inputs_0, [-1, a_shape[-1]])
Review comment:
Then you can get rid of the `no-else-return` error.
```suggestion
if len(a_shape) > 2:
inputs_0 = _op.reshape(inputs_0, [-1, a_shape[-1]])
```
##########
File path: python/tvm/relay/frontend/onnx.py
##########
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, import-self, len-as-condition,
unused-argument, too-many-lines
-# pylint: disable=import-outside-toplevel
+# pylint: disable=import-outside-toplevel, no-else-return
Review comment:
Do not disable a lint rule like this one.
##########
File path: python/tvm/relay/frontend/onnx.py
##########
@@ -670,68 +670,67 @@ class MatMul(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
assert len(inputs) == 2, "MatMul op take 2 inputs, {}
given".format(len(inputs))
+ inputs_0 = inputs[0]
+ inputs_1 = inputs[1]
+
# Need to check input shape as batch matmul must be supported.
- a_shape = shape_of(inputs[0])
- a_rank = infer_shape(a_shape)[0]
- b_shape = shape_of(inputs[1])
- b_rank = infer_shape(b_shape)[0]
- # When performing a batch matmul, we need to properly handle N-dim
shapes.
- if a_rank > 2 or b_rank > 2:
+ a_shape = infer_shape(inputs_0)
+ b_shape = infer_shape(inputs_1)
- def flatten_to_3d(x, x_shape):
- ndims = infer_shape(x_shape)[0]
- newshape = _op.concatenate(
- [
- _expr.const([-1],
dtype=infer_type(x_shape).checked_type.dtype),
- _op.strided_slice(x_shape, [ndims - 2], [ndims]),
- ],
- 0,
- )
- out = _op.reshape(x, fold_constant(newshape))
- return out
+ # When performing a batch matmul, we need to properly handle N-dim
shapes.
+ if len(a_shape) > 2 and len(b_shape) > 2:
+ # Convert a into a 3 dimensional tensors.
+ need_reshape_output = False
+ if len(a_shape) != 3:
+ a = _op.reshape(inputs_0, [-1, a_shape[-2], a_shape[-1]])
+ need_reshape_output = True
+ else:
+ a = inputs_0
- # Convert a and b into 3 dimensional tensors.
- a = flatten_to_3d(inputs[0], a_shape)
- b = flatten_to_3d(inputs[1], b_shape)
# Transpose matrix dimensions of b.
- b = _op.transpose(b, [0, 2, 1])
+ trans_axes = list(range(len(b_shape)))
+ trans_axes[-2], trans_axes[-1] = trans_axes[-1], trans_axes[-2]
+ b = _op.transpose(inputs_1, trans_axes)
+
+ # Convert b into a 3 dimensional tensor. Note that the last two
dimensions
+ # are transposed.
+ if len(b_shape) != 3:
+ b = _op.reshape(b, [-1, b_shape[-1], b_shape[-2]])
+
# Perform a batch matmul.
output = _op.nn.batch_matmul(a, b)
- # Determine the output batch dimension.
- if a_rank > b_rank:
- out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2])
- elif a_rank < b_rank:
- out_batch = _op.strided_slice(b_shape, [0], [b_rank - 2])
- # If its unclear how broadcasting should be applied, the output
- # shape is determined by choosing the maximum value from each
input.
- else:
- out_batch = _op.concatenate(
- [
- _op.maximum(
- _op.strided_slice(a_shape, [i], [i + 1]),
- _op.strided_slice(b_shape, [i], [i + 1]),
- )
- for i in range(a_rank - 2)
- ],
- 0,
- )
+
# Reshape output to original dimensions.
- final_shape = _op.concatenate(
- [
- out_batch,
- _op.strided_slice(
- a_shape, [infer_shape(a_shape)[0] - 2],
[infer_shape(a_shape)[0] - 1]
- ),
- _op.strided_slice(
- b_shape, [infer_shape(b_shape)[0] - 1],
[infer_shape(b_shape)[0]]
- ),
- ],
- 0,
+ if need_reshape_output:
+ return _op.reshape(output, [*a_shape[:-2], a_shape[-2],
b_shape[-1]])
+ return output
+ elif len(a_shape) > 2:
+ inputs_0 = _op.reshape(inputs_0, [-1, a_shape[-1]])
+
+ if len(b_shape) > 2:
+ trans_axes = list(range(len(b_shape)))
+ trans_axes[-2], trans_axes[-1] = trans_axes[-1], trans_axes[-2]
+ input_1 = _op.reshape(_op.transpose(inputs_1, trans_axes), [-1,
b_shape[-2]])
+ elif len(b_shape) == 2:
+ input_1 = _op.transpose(inputs_1, axes=(1, 0))
+ elif len(b_shape) == 1:
+ input_1 = _op.expand_dims(inputs_1, 0, 1)
+
+ out = _op.nn.dense(inputs_0, input_1)
+
+ if len(b_shape) == 1:
+ out = _op.squeeze(out, axis=[-1])
+
+ # Reshape output into a N dimensional tensor when a or b dim > 2
+ if len(a_shape) > 2:
+ out = _op.reshape(out, [*a_shape[:-1], b_shape[-1]])
+ elif len(b_shape) > 2:
+ out = _op.reshape(out, [a_shape[-2], -1, b_shape[-1]])
+ out = _op.reshape(
+ _op.transpose(out, [1, 0, 2]), [*b_shape[:-2], a_shape[-2],
b_shape[-1]]
)
Review comment:
The pattern `reshape - transpose - reshape` cannot be optimized by
follow-up passes. Can we use just one `reshape` and one `transpose`?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]