jcf94 commented on a change in pull request #7845:
URL: https://github.com/apache/tvm/pull/7845#discussion_r612988079
##########
File path: tests/python/frontend/pytorch/test_forward.py
##########
@@ -162,7 +162,7 @@ def measure_latency(model, input_shapes, output_shapes,
thresh, dryruns=40):
return est
-def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5,
atol=1e-5):
+def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5,
atol=1e-5, expected_ops=[]):
Review comment:
Good addition!
##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -1606,18 +1606,30 @@ def matmul(self, inputs, input_types):
if need_reshape_output:
return _op.reshape(output, [*a_shape[:-2], a_shape[-2],
b_shape[-1]])
return output
+ elif len(a_shape) > 2:
+ inputs_0 = _op.reshape(inputs_0, [-1, a_shape[-1]])
- # Otherwise a simple dense op will get the job done.
- if len(b_shape) == 1:
+ if len(b_shape) > 2:
+ trans_axes = list(range(len(b_shape)))
+ trans_axes[-2], trans_axes[-1] = trans_axes[-1], trans_axes[-2]
+ input_1 = _op.reshape(_op.transpose(inputs_1, trans_axes), [-1,
b_shape[-2]])
+ elif len(b_shape) == 1:
input_1 = _op.expand_dims(inputs_1, 0, 1)
- else:
+ elif len(b_shape) == 2:
input_1 = _op.transpose(inputs_1, axes=(1, 0))
out = _op.nn.dense(inputs_0, input_1)
if len(b_shape) == 1:
out = _op.squeeze(out, axis=[-1])
+ # Reshape output into a N dimensional tensor when a or b dim > 2
+ if len(a_shape) > 2:
+ out = _op.reshape(out, [*a_shape[:-1], b_shape[-1]])
+ elif len(b_shape) > 2:
+ out = _op.reshape(out, [a_shape[-2], -1, b_shape[-1]])
+ out = _op.reshape(_op.transpose(out, [1, 0, 2]), [*b_shape[:-2],
a_shape[-2], b_shape[-1]])
+
Review comment:
Based on @comaniac's comments, I think it would be better to also merge
these conditions together:
if len(a_shape) > 2:
elif len(b_shape) > 2:
elif len(b_shape) == 1:
##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -1606,18 +1606,30 @@ def matmul(self, inputs, input_types):
if need_reshape_output:
return _op.reshape(output, [*a_shape[:-2], a_shape[-2],
b_shape[-1]])
return output
+ elif len(a_shape) > 2:
+ inputs_0 = _op.reshape(inputs_0, [-1, a_shape[-1]])
- # Otherwise a simple dense op will get the job done.
- if len(b_shape) == 1:
+ if len(b_shape) > 2:
+ trans_axes = list(range(len(b_shape)))
+ trans_axes[-2], trans_axes[-1] = trans_axes[-1], trans_axes[-2]
+ input_1 = _op.reshape(_op.transpose(inputs_1, trans_axes), [-1,
b_shape[-2]])
+ elif len(b_shape) == 1:
input_1 = _op.expand_dims(inputs_1, 0, 1)
- else:
+ elif len(b_shape) == 2:
input_1 = _op.transpose(inputs_1, axes=(1, 0))
Review comment:
Not strictly necessary, but it would be better if the order of these
conditions were changed to:
if len(b_shape) > 2:
elif len(b_shape) == 2:
elif len(b_shape) == 1:
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]