AndrewZhaoLuo commented on a change in pull request #9946:
URL: https://github.com/apache/tvm/pull/9946#discussion_r791061268



##########
File path: python/tvm/relay/frontend/common.py
##########
@@ -938,9 +939,14 @@ def ensure_scalar_shape(x):
         return x
 
     num_elem = np.prod(x_shape)
-    assert num_elem == 1, "Cannot squeeze tensor shape {} to scalar form.".format(x_shape)
-
-    return _op.squeeze(x)
+    if num_elem == 1:
+        return _op.squeeze(x)
+    else:
+        if force_assert:
+            assert num_elem == 1, "Cannot squeeze tensor shape {} to scalar form.".format(x_shape)
+        else:
+            return x

Review comment:
       This changes the behavior of `ensure_scalar_shape`: it can now return non-scalar tensors. Instead, I would leave `ensure_scalar_shape` as is and, in `QLinearMatMul`'s `try_resolve_to_const`, simply wrap the `ensure_scalar_shape` calls in a check that the number of elements is 1.
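A minimal sketch of the suggested split, assuming NumPy arrays stand in for Relay expressions (`x.shape` and `np.squeeze` replace the shape inference and `_op.squeeze` used in the diff). Both functions are hypothetical NumPy analogues named after the helpers above, not TVM's actual code: the helper keeps its strict "scalar or fail" contract, and the element-count check moves to the caller.

```python
import numpy as np


def ensure_scalar_shape(x: np.ndarray) -> np.ndarray:
    """NumPy stand-in for the helper, keeping its strict contract:
    the result is always 0-d, otherwise the assert fires."""
    if x.ndim == 0:
        return x  # already a scalar
    num_elem = int(np.prod(x.shape))
    assert num_elem == 1, "Cannot squeeze tensor shape {} to scalar form.".format(x.shape)
    return np.squeeze(x)


def try_resolve_to_const(x: np.ndarray) -> np.ndarray:
    """Hypothetical caller-side version: the element-count check lives here,
    so ensure_scalar_shape is only called on single-element tensors."""
    if int(np.prod(x.shape)) == 1:
        return ensure_scalar_shape(x)
    return x  # multi-element tensors (e.g. 1-D) pass through unchanged


print(try_resolve_to_const(np.array([[0.5]])).shape)  # ()
print(try_resolve_to_const(np.array([0.1, 0.2])).shape)  # (2,)
```

With this split, every caller of `ensure_scalar_shape` can still rely on getting a scalar back, and only the caller that needs to accept 1-D inputs handles them.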

##########
File path: python/tvm/relay/frontend/common.py
##########
@@ -938,9 +939,14 @@ def ensure_scalar_shape(x):
         return x
 
     num_elem = np.prod(x_shape)
-    assert num_elem == 1, "Cannot squeeze tensor shape {} to scalar form.".format(x_shape)
-
-    return _op.squeeze(x)
+    if num_elem == 1:
+        return _op.squeeze(x)
+    else:
+        if force_assert:
+            assert num_elem == 1, "Cannot squeeze tensor shape {} to scalar form.".format(x_shape)

Review comment:
       In this code path, the function returns `None`.
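A minimal sketch of when that happens, using a simplified copy of the branch structure above (the NumPy/Relay calls are replaced by placeholder string return values, an assumption for illustration). With assertions enabled, the `force_assert` branch always raises, since `num_elem != 1` there; when assertions are stripped, as under `python -O`, the branch reaches the end of the function without a `return` and yields `None`. The built-in `compile(..., optimize=1)` strips `assert` statements the same way `-O` does, so both behaviors can be shown in one process.

```python
import textwrap

# Simplified copy of the PR's branch structure; placeholder strings
# stand in for _op.squeeze(x) and x.
SRC = textwrap.dedent('''
    def ensure_scalar_shape(num_elem, force_assert=False):
        if num_elem == 1:
            return "squeezed"
        else:
            if force_assert:
                assert num_elem == 1, "Cannot squeeze to scalar form."
            else:
                return "unchanged"
''')


def load(optimize):
    """Compile SRC at the given optimization level and return the function."""
    namespace = {}
    exec(compile(SRC, "<sketch>", "exec", optimize=optimize), namespace)
    return namespace["ensure_scalar_shape"]


strict = load(optimize=0)    # asserts kept: the force_assert path raises
stripped = load(optimize=1)  # asserts removed (like `python -O`): path falls through

print(stripped(4, force_assert=True))  # None
```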

##########
File path: python/tvm/relay/frontend/onnx.py
##########
@@ -3804,10 +3804,9 @@ def _impl_v10(cls, inputs, attr, params):
         #
         # This function attempts to present 'x' in a form that meets both of those
         # requirements.
-        def try_resolve_to_const_scalar(x, dtype_override=None):
+        def try_resolve_to_const(x, dtype_override=None, allow1D=False):
             x2 = try_resolve_var_to_const(x, params)
-            x3 = ensure_scalar_shape(x2)

Review comment:
       Echoing the comment above: instead of adding a new flag to `ensure_scalar_shape` that breaks its invariants, move the check you wrote out to here.




-- 
This is an automated message from the Apache Git Service.