ymwangg commented on a change in pull request #8440:
URL: https://github.com/apache/tvm/pull/8440#discussion_r667278362
##########
File path: python/tvm/relay/frontend/onnx.py
##########
@@ -670,68 +670,67 @@ class MatMul(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
assert len(inputs) == 2, "MatMul op take 2 inputs, {} given".format(len(inputs))
+ inputs_0 = inputs[0]
+ inputs_1 = inputs[1]
+
# Need to check input shape as batch matmul must be supported.
- a_shape = shape_of(inputs[0])
- a_rank = infer_shape(a_shape)[0]
- b_shape = shape_of(inputs[1])
- b_rank = infer_shape(b_shape)[0]
-         # When performing a batch matmul, we need to properly handle N-dim shapes.
- if a_rank > 2 or b_rank > 2:
+ a_shape = infer_shape(inputs_0)
+ b_shape = infer_shape(inputs_1)
- def flatten_to_3d(x, x_shape):
- ndims = infer_shape(x_shape)[0]
- newshape = _op.concatenate(
- [
-                     _expr.const([-1], dtype=infer_type(x_shape).checked_type.dtype),
- _op.strided_slice(x_shape, [ndims - 2], [ndims]),
- ],
- 0,
- )
- out = _op.reshape(x, fold_constant(newshape))
- return out
+         # When performing a batch matmul, we need to properly handle N-dim shapes.
+ if len(a_shape) > 2 and len(b_shape) > 2:
+ # Convert a into a 3 dimensional tensors.
+ need_reshape_output = False
+ if len(a_shape) != 3:
+ a = _op.reshape(inputs_0, [-1, a_shape[-2], a_shape[-1]])
+ need_reshape_output = True
+ else:
+ a = inputs_0
- # Convert a and b into 3 dimensional tensors.
- a = flatten_to_3d(inputs[0], a_shape)
- b = flatten_to_3d(inputs[1], b_shape)
# Transpose matrix dimensions of b.
- b = _op.transpose(b, [0, 2, 1])
+ trans_axes = list(range(len(b_shape)))
+ trans_axes[-2], trans_axes[-1] = trans_axes[-1], trans_axes[-2]
+ b = _op.transpose(inputs_1, trans_axes)
+
+             # Convert b into a 3 dimensional tensor. Note that the last two dimensions
+ # are transposed.
+ if len(b_shape) != 3:
+ b = _op.reshape(b, [-1, b_shape[-1], b_shape[-2]])
+
# Perform a batch matmul.
output = _op.nn.batch_matmul(a, b)
- # Determine the output batch dimension.
- if a_rank > b_rank:
- out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2])
- elif a_rank < b_rank:
- out_batch = _op.strided_slice(b_shape, [0], [b_rank - 2])
- # If its unclear how broadcasting should be applied, the output
-             # shape is determined by choosing the maximum value from each input.
- else:
- out_batch = _op.concatenate(
- [
- _op.maximum(
- _op.strided_slice(a_shape, [i], [i + 1]),
- _op.strided_slice(b_shape, [i], [i + 1]),
- )
- for i in range(a_rank - 2)
- ],
- 0,
- )
+
# Reshape output to original dimensions.
- final_shape = _op.concatenate(
- [
- out_batch,
- _op.strided_slice(
-                         a_shape, [infer_shape(a_shape)[0] - 2], [infer_shape(a_shape)[0] - 1]
- ),
- _op.strided_slice(
-                         b_shape, [infer_shape(b_shape)[0] - 1], [infer_shape(b_shape)[0]]
- ),
- ],
- 0,
+ if need_reshape_output:
+                 return _op.reshape(output, [*a_shape[:-2], a_shape[-2], b_shape[-1]])
+ return output
+ elif len(a_shape) > 2:
+ inputs_0 = _op.reshape(inputs_0, [-1, a_shape[-1]])
+
+ if len(b_shape) > 2:
+ trans_axes = list(range(len(b_shape)))
+ trans_axes[-2], trans_axes[-1] = trans_axes[-1], trans_axes[-2]
+             input_1 = _op.reshape(_op.transpose(inputs_1, trans_axes), [-1, b_shape[-2]])
+ elif len(b_shape) == 2:
+ input_1 = _op.transpose(inputs_1, axes=(1, 0))
+ elif len(b_shape) == 1:
+ input_1 = _op.expand_dims(inputs_1, 0, 1)
+
+ out = _op.nn.dense(inputs_0, input_1)
+
+ if len(b_shape) == 1:
+ out = _op.squeeze(out, axis=[-1])
+
+ # Reshape output into a N dimensional tensor when a or b dim > 2
+ if len(a_shape) > 2:
+ out = _op.reshape(out, [*a_shape[:-1], b_shape[-1]])
+ elif len(b_shape) > 2:
+ out = _op.reshape(out, [a_shape[-2], -1, b_shape[-1]])
+ out = _op.reshape(
+                 _op.transpose(out, [1, 0, 2]), [*b_shape[:-2], a_shape[-2], b_shape[-1]]
)
Review comment:
I switched to the original ONNX implementation as I found some issues in the PyTorch implementation.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]