This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 939a42b [Torch] Fix cast to long (#6301)
939a42b is described below
commit 939a42b4e976a41e8513b720421d3c3678493715
Author: masahi <[email protected]>
AuthorDate: Wed Aug 19 20:50:36 2020 +0900
[Torch] Fix cast to long (#6301)
* [Torch] fix cast to long
* retrigger
---
python/tvm/relay/frontend/pytorch.py | 7 +++++--
tests/python/frontend/pytorch/test_forward.py | 16 ++++++----------
2 files changed, 11 insertions(+), 12 deletions(-)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 235cec0..85dd5f4 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -31,6 +31,7 @@ from .. import expr as _expr
from .. import op as _op
from ..ty import TupleType, TensorType, Any
from ..loops import while_loop
+from .. import transform
from .common import get_relay_op
from .common import infer_shape as _infer_shape
from .common import infer_value as _infer_value
@@ -1507,14 +1508,16 @@ def _to():
cast_func = {
6: float,
3: int,
+ 4: int
}
cast_func_expr = {
6: lambda x: _op.cast(x, "float32"),
3: lambda x: _op.cast(x, "int32"),
+ 4: lambda x: _op.cast(x, "int64"),
}
if inputs[1] in cast_func and not isinstance(data, _expr.Expr):
return cast_func[inputs[1]](data)
- elif inputs[1] in cast_func and isinstance(data, _expr.Expr):
+ elif inputs[1] in cast_func_expr and isinstance(data, _expr.Expr):
return cast_func_expr[inputs[1]](data)
return data
@@ -2668,4 +2671,4 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None, default_d
mod["main"] = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0])
- return mod, tvm_params
+ return transform.RemoveUnusedFunctions()(mod), tvm_params
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 2302f0f..d5b4ed2 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1296,31 +1296,27 @@ def test_upsample():
def test_to():
""" test for aten::to(...) """
class ToCPU(Module):
- def __init__(self):
- super().__init__()
-
def forward(self, x):
return x.to("cpu")
class ToFloat(Module):
- def __init__(self):
- super().__init__()
-
def forward(self, x):
return x.float()
class ToInt(Module):
- def __init__(self):
- super().__init__()
-
def forward(self, x):
return x.int()
+ class ToLong(Module):
+ def forward(self, x):
+ return x.long()
+
verify_model(ToCPU().eval(), torch.rand((1, 3, 32, 32)))
verify_model(ToFloat().eval(), torch.zeros((1, 3, 32, 32), dtype=torch.int))
verify_model(ToFloat().eval(), torch.tensor(2, dtype=torch.int))
verify_model(ToInt().eval(), torch.zeros((1, 3, 32, 32)))
- verify_model(ToInt().eval(), torch.tensor(2.0))
+ verify_model(ToInt().eval(), torch.tensor(0.8))
+ verify_model(ToLong().eval(), torch.tensor(0.8))
def test_adaptive_pool3d():