rohanmukh commented on a change in pull request #8251:
URL: https://github.com/apache/tvm/pull/8251#discussion_r652209025



##########
File path: python/tvm/relay/frontend/tensorflow_ops.py
##########
@@ -1162,6 +1163,9 @@ def _impl(inputs, attr, params, mod):
         adj_x = attr["adj_x"]
         adj_y = attr["adj_y"]
         input_x = _op.transpose(input_x, axes=[0, 2, 1]) if adj_x else input_x
+        shape_y = _infer_shape(input_y, mod)
+        if len(shape_y) < 3:
+            input_y = _op.reshape(input_y, (1, orig_shape_y[-2], orig_shape_y[-1]))

Review comment:
       Thanks @comaniac. This line is needed when `ndim=len(shape_x)` is <= 3;
I have test cases that fail without it, for example
`_test_batch_matmul((1, 8, 64), (64, 1), "float32", False, False)`. However,
the case you mentioned can also occur for certain input configurations, so I
refactored the logic to avoid it.
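To illustrate why the reshape matters for the failing test case cited above, here is a minimal NumPy sketch. `strict_batch_matmul` and `promote_rhs` are hypothetical helpers, not TVM APIs. The first assumes, as Relay's `nn.batch_matmul` did at the time, that both operands must be 3-D; it models only that rank constraint and uses `np.matmul` for the arithmetic, so it does not reproduce Relay's transposed-`y` layout. The second mirrors the `_op.reshape` call in the diff.

```python
import numpy as np


def strict_batch_matmul(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for a batch matmul that only accepts 3-D
    operands: (B, M, K) @ (B, K, N) -> (B, M, N)."""
    if x.ndim != 3 or y.ndim != 3:
        raise ValueError(f"expected 3-D inputs, got {x.shape} and {y.shape}")
    return np.matmul(x, y)


def promote_rhs(y: np.ndarray) -> np.ndarray:
    """Mirror of the diff: give a rank < 3 right-hand side a unit batch dim.
    Like the diff, this assumes y is at least 2-D (it reads shape[-2])."""
    if y.ndim < 3:
        y = y.reshape(1, y.shape[-2], y.shape[-1])
    return y


# Shapes taken from the failing test case: x is 3-D, y is 2-D.
x = np.random.rand(1, 8, 64).astype("float32")
y = np.random.rand(64, 1).astype("float32")

# Passing y as-is would raise ValueError; promoting it first succeeds.
out = strict_batch_matmul(x, promote_rhs(y))
print(out.shape)  # (1, 8, 1)
```

Without the promotion step, the rank check rejects the 2-D `y`, which is the kind of failure the test case exposes when the reshape line is removed.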




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.


