siju-samuel commented on a change in pull request #5834:
URL: https://github.com/apache/incubator-tvm/pull/5834#discussion_r442032927



##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -1733,12 +1780,19 @@ def _convert_dtype_value(val):
                                0:"torch.unit8",
                                None:"torch.int64"} # Default is torch.int64
     if val in convert_torch_dtype_map:
-        return convert_torch_dtype_map[val]
+        return _convert_data_type(convert_torch_dtype_map[val])
     else:
         msg = "Torch data type value %d is not handled yet." % (val)
         raise NotImplementedError(msg)
 
-def _convert_data_type(input_type):
+def _convert_data_type(input_type, default_dtype=None):
+    """converts the PyTorch scalar type input_type to a TVM dtype.
+       optionally, default_dtype can be a TVM dtype that is used
+       if input_type is None (but not when it is unknown)"""
+    if input_type is None and default_dtype is not None:
+        return default_dtype
+
+    input_type = input_type.lower()
     if input_type in ["double", "torch.float64"]:
         return "float64"
     elif input_type in ["float", "torch.float32"]:

Review comment:
       Add "float32" to this list as well.

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -1755,12 +1809,21 @@ def _convert_data_type(input_type):
         return "int8"

Review comment:
       Add "int64" to the corresponding list here as well.

##########
File path: tests/python/frontend/pytorch/test_forward.py
##########
@@ -2363,6 +2366,23 @@ def forward(self, *args):
     t2 = torch.rand([1, 3]).float()
     verify_model(Addcmul2().float().eval(), input_data=[input_data, t1, t2])
 
+def test_forward_traced_function():
+    def fn(t1, t2):
+        return t1 + t2
+
+    tensor1 = torch.randn(3, 4)
+    tensor2 = torch.randn(3, 4)
+    verify_model(fn, input_data=[tensor1, tensor2])
+
+def test_forward_dtypes():
+    def fn(t1, t2):
+        return 2.5 * t1 + t2
+
+    for dt in [torch.int32, torch.int64, torch.double]:
+        tensor1 = torch.randn(3, 4).to(dtype=dt)
+        tensor2 = torch.randn(3, 4).to(dtype=dt)
+        verify_model(fn, input_data=[tensor1, tensor2])
+

Review comment:
       Add `test_forward_traced_function` and `test_forward_dtypes` to the [main](https://github.com/apache/incubator-tvm/blob/f305b31d6343f207b913eb1aafc8d07782445e33/tests/python/frontend/pytorch/test_forward.py#L2528) block as well.

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -115,64 +115,70 @@ def inplace_add_to_add(op_name):
     return False
 
 
+
 # operator implementation
 def _elemwise(name):
     def _impl(inputs, input_types):
-        # TODO: Figure out a better way to get typing to work for tensor + 
scalar
-        type0 = input_types[0]
-        if isinstance(inputs[1], _expr.Expr):
-            type0 = input_types[1]
-
-        type1 = input_types[1]
-        if isinstance(inputs[0], _expr.Expr):
-            type1 = input_types[0]
-
-        data0 = _convert_elemwise_input(inputs[0], type0)
-        data1 = _convert_elemwise_input(inputs[1], type1)
-
+        data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2])

Review comment:
       I was trying to run the gp2 model and got an error with the new modifications.
   You can download the code from
   https://gist.github.com/siju-samuel/34e63e0719e06679b5c3688bce7a0515
   
   The error is:
   
   ```
   Traceback (most recent call last):
     File "gp2.py", line 26, in <module>
       mod, params = relay.frontend.from_pytorch(scripted_model, input_shapes)
     File "/home/siju/workspace/tvm/python/tvm/relay/frontend/pytorch.py", line 2645, in from_pytorch
       ret = convert_operators(_get_operator_nodes(graph.nodes()),
     File "/home/siju/workspace/tvm/python/tvm/relay/frontend/pytorch.py", line 2555, in convert_operators
       relay_out = relay_op(inputs, _get_input_types(op_node, default_dtype=default_dtype))
     File "/home/siju/workspace/tvm/python/tvm/relay/frontend/pytorch.py", line 1694, in _impl
       return _elemwise("add")(inputs, input_types)
     File "/home/siju/workspace/tvm/python/tvm/relay/frontend/pytorch.py", line 151, in _impl
       return get_relay_op(name)(data0, data1)
     File "/home/siju/workspace/tvm/python/tvm/relay/op/tensor.py", line 513, in add
       return _make.add(lhs, rhs)
     File "/home/siju/workspace/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 225, in __call__
       raise get_last_ffi_error()
   tvm._ffi.base.TVMError: Traceback (most recent call last):
     [bt] (4) /home/siju/workspace/tvm/build/libtvm.so(TVMFuncCall+0x69) [0x7f983831cc09]
     [bt] (3) /home/siju/workspace/tvm/build/libtvm.so(+0xa5be6b) [0x7f9837f31e6b]
     [bt] (2) /home/siju/workspace/tvm/build/libtvm.so(tvm::runtime::TVMMovableArgValue_::operator tvm::RelayExpr<tvm::RelayExpr, void>() const+0x63) [0x7f9837e0ecf3]
     [bt] (1) /home/siju/workspace/tvm/build/libtvm.so(tvm::RelayExpr tvm::runtime::TVMPODValue_::AsObjectRef<tvm::RelayExpr>() const+0x1a6) [0x7f9837ab21b6]
     [bt] (0) /home/siju/workspace/tvm/build/libtvm.so(+0x5cc26b) [0x7f9837aa226b]
     File "/home/siju/workspace/tvm/include/tvm/runtime/packed_func.h", line 1423
   TVMError: Check failed: type_code_ == kTVMObjectHandle (0 vs. 8) : expected Object but get int
   ```




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to