This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 939a42b  [Torch] Fix cast to long (#6301)
939a42b is described below

commit 939a42b4e976a41e8513b720421d3c3678493715
Author: masahi <[email protected]>
AuthorDate: Wed Aug 19 20:50:36 2020 +0900

    [Torch] Fix cast to long (#6301)
    
    * [Torch] fix cast to long
    
    * retrigger
---
 python/tvm/relay/frontend/pytorch.py          |  7 +++++--
 tests/python/frontend/pytorch/test_forward.py | 16 ++++++----------
 2 files changed, 11 insertions(+), 12 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 235cec0..85dd5f4 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -31,6 +31,7 @@ from .. import expr as _expr
 from .. import op as _op
 from ..ty import TupleType, TensorType, Any
 from ..loops import while_loop
+from .. import transform
 from .common import get_relay_op
 from .common import infer_shape as _infer_shape
 from .common import infer_value as _infer_value
@@ -1507,14 +1508,16 @@ def _to():
         cast_func = {
             6: float,
             3: int,
+            4: int
         }
         cast_func_expr = {
             6: lambda x: _op.cast(x, "float32"),
             3: lambda x: _op.cast(x, "int32"),
+            4: lambda x: _op.cast(x, "int64"),
         }
         if inputs[1] in cast_func and not isinstance(data, _expr.Expr):
             return cast_func[inputs[1]](data)
-        elif inputs[1] in cast_func and isinstance(data, _expr.Expr):
+        elif inputs[1] in cast_func_expr and isinstance(data, _expr.Expr):
             return cast_func_expr[inputs[1]](data)
         return data
 
@@ -2668,4 +2671,4 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None, default_d
 
     mod["main"] = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0])
 
-    return mod, tvm_params
+    return transform.RemoveUnusedFunctions()(mod), tvm_params
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 2302f0f..d5b4ed2 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1296,31 +1296,27 @@ def test_upsample():
 def test_to():
     """ test for aten::to(...) """
     class ToCPU(Module):
-        def __init__(self):
-            super().__init__()
-
         def forward(self, x):
             return x.to("cpu")
 
     class ToFloat(Module):
-        def __init__(self):
-            super().__init__()
-
         def forward(self, x):
             return x.float()
 
     class ToInt(Module):
-        def __init__(self):
-            super().__init__()
-
         def forward(self, x):
             return x.int()
 
+    class ToLong(Module):
+        def forward(self, x):
+            return x.long()
+
     verify_model(ToCPU().eval(), torch.rand((1, 3, 32, 32)))
     verify_model(ToFloat().eval(), torch.zeros((1, 3, 32, 32), dtype=torch.int))
     verify_model(ToFloat().eval(), torch.tensor(2, dtype=torch.int))
     verify_model(ToInt().eval(), torch.zeros((1, 3, 32, 32)))
-    verify_model(ToInt().eval(), torch.tensor(2.0))
+    verify_model(ToInt().eval(), torch.tensor(0.8))
+    verify_model(ToLong().eval(), torch.tensor(0.8))
 
 
 def test_adaptive_pool3d():

Reply via email to