t-vi commented on a change in pull request #6449:
URL: https://github.com/apache/incubator-tvm/pull/6449#discussion_r487271315



##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -127,8 +128,22 @@ def _is_quantized_tensor(data, prelude):
 # operator implementation
 def _elemwise(name):
     def _impl(inputs, input_types):
-        data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2])
-        return get_relay_op(name)(data0, data1)
+        dtype0, dtype1 = input_types[:2]
+        if isinstance(inputs[0], _expr.Expr):
+            dtype0 = _infer_type(inputs[0]).checked_type.dtype
+        if isinstance(inputs[1], _expr.Expr):
+            dtype1 = _infer_type(inputs[1]).checked_type.dtype
+

Review comment:
I think we would eventually want to rely more on type propagation. However,
the issue here is that PyTorch's default dtype for integral tensors is int64.
I don't think we should be hacking around that, really, because we're bound to
end up with cases where int64 is the right thing to have. If I understood the
discussions on the forum correctly, the idea was to downcast 64-bit indexing
to 32-bit when that is considered safe.
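
To make "considered safe" concrete, here is a minimal sketch in plain Python of
the kind of range check such a downcast might rely on. It is only an
illustration, not TVM's or PyTorch's implementation: the names `fits_in_int32`
and `choose_index_dtype` are hypothetical, and a real pass would reason about
shapes and index expressions rather than concrete values.

```python
# Hypothetical helpers for illustration only; not part of TVM or PyTorch.
INT32_MIN = -(2 ** 31)
INT32_MAX = 2 ** 31 - 1


def fits_in_int32(values):
    """Return True if every integer in `values` is representable as int32."""
    return all(INT32_MIN <= v <= INT32_MAX for v in values)


def choose_index_dtype(values):
    """Pick "int32" only when no value would overflow; otherwise keep "int64"."""
    return "int32" if fits_in_int32(values) else "int64"


print(choose_index_dtype([0, 1, 127]))   # small indices fit in 32 bits
print(choose_index_dtype([0, 2 ** 31]))  # 2**31 is one past INT32_MAX
```

The point of the sketch is that int64 stays the default and int32 is chosen
only when the check passes, rather than narrowing unconditionally.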





