This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a3e03a3 [Torch] Frontend update to support PyTorch 1.10 (#9664)
a3e03a3 is described below
commit a3e03a37231b68d5ea2701e1c52088f4ebd20057
Author: masahi <[email protected]>
AuthorDate: Tue Dec 7 18:28:36 2021 +0900
[Torch] Frontend update to support PyTorch 1.10 (#9664)
* wip
* fixed converting maskrcnn
* fixed nll loss
* fixed linspace
* fixed deformable conv2d
* control flow and rnn test had no problem
* swap more import orders
* qmv3 test is having weird segfault
* cleanup
* black
---
python/tvm/relay/frontend/pytorch.py | 43 +++++++++++++++++-----
python/tvm/relay/frontend/pytorch_utils.py | 4 +-
tests/python/frontend/pytorch/test_forward.py | 2 +
.../frontend/pytorch/test_object_detection.py | 6 +--
tests/python/frontend/pytorch/test_rnns.py | 2 +-
5 files changed, 41 insertions(+), 16 deletions(-)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index d9b8f90..fbc5f6e 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -408,11 +408,16 @@ class PyTorchOpConverter:
):
return data
+ if target_begin is None and target_end is None:
+ return data
+
# Process begin
begin = [0] * ndim
- begin[dim] = target_begin
- if not isinstance(begin[dim], int):
+ if target_begin is not None:
+ begin[dim] = target_begin
+
+ if target_begin is not None and not isinstance(begin[dim], int):
tmp = []
for b in begin:
if isinstance(b, int):
@@ -455,7 +460,7 @@ class PyTorchOpConverter:
)
else:
end = _op.cast(_op.shape_of(data), axis_dtype)
- if not isinstance(target_end, tvm.tir.Any):
+ if target_end is not None and not isinstance(target_end,
tvm.tir.Any):
ttype = self.infer_type(target_end).dtype
if str(ttype) != axis_dtype:
target_end = _op.cast(target_end, axis_dtype)
@@ -745,7 +750,13 @@ class PyTorchOpConverter:
else:
stop = start + step
- dtype = "float32" if inputs[3] is not None else
_convert_dtype_value(inputs[3])
+ if inputs[3] is None:
+ import torch
+
+ dtype = _convert_data_type(str(torch.get_default_dtype()))
+ else:
+ dtype = _convert_dtype_value(inputs[3])
+
start = _create_typed_const(start, dtype)
stop = _create_typed_const(stop, dtype)
step = _create_typed_const(step, dtype)
@@ -2065,16 +2076,25 @@ class PyTorchOpConverter:
data = inputs[0]
weight = inputs[1]
offset = inputs[2]
- strides = (inputs[4], inputs[5])
- padding = (inputs[6], inputs[7])
- dilation = (inputs[8], inputs[9])
- groups = inputs[10]
- deformable_groups = inputs[11]
+
+ if len(inputs) > 12:
+ strides_offset = 5
+ bias = inputs[4]
+ logging.warning("mask argument in deformable conv2d is not
supported and ignored")
+ else:
+ strides_offset = 4
+ bias = inputs[3]
+
+ strides = (inputs[strides_offset], inputs[strides_offset + 1])
+ padding = (inputs[strides_offset + 2], inputs[strides_offset + 3])
+ dilation = (inputs[strides_offset + 4], inputs[strides_offset + 5])
+ groups = inputs[strides_offset + 6]
+ deformable_groups = inputs[strides_offset + 7]
weight_shape = self.infer_shape(weight)
output_channels = weight_shape[0]
kernel_size = (weight_shape[2], weight_shape[3])
- return _op.nn.deformable_conv2d(
+ conv_out = _op.nn.deformable_conv2d(
data,
offset,
weight,
@@ -2087,6 +2107,8 @@ class PyTorchOpConverter:
kernel_size,
)
+ return _op.nn.bias_add(conv_out, bias)
+
def unbind(self, inputs, input_types):
data = inputs[0]
axis = int(inputs[1])
@@ -3059,6 +3081,7 @@ class PyTorchOpConverter:
"aten::_unique2": self.unique,
"aten::nll_loss": self.nll_loss,
"aten::nll_loss2d": self.nll_loss,
+ "aten::nll_loss_nd": self.nll_loss,
"aten::flip": self.flip,
"aten::gru": self.gru,
"aten::lstm": self.lstm,
diff --git a/python/tvm/relay/frontend/pytorch_utils.py b/python/tvm/relay/frontend/pytorch_utils.py
index 753b1f2..b4fe168 100644
--- a/python/tvm/relay/frontend/pytorch_utils.py
+++ b/python/tvm/relay/frontend/pytorch_utils.py
@@ -34,8 +34,8 @@ def is_version_greater_than(ver):
import torch
import re
- return "".join(re.findall(r"(\d+\.)(\d+\.)(\d)", torch.__version__)[0]) >
"".join(
- re.findall(r"(\d+\.)(\d+\.)(\d)", ver)[0]
+ return int("".join(re.findall(r"(\d+)\.(\d+)\.(\d)",
torch.__version__)[0])) > int(
+ "".join(re.findall(r"(\d+)\.(\d+)\.(\d)", ver)[0])
)
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index b30b0af..0692ab8 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -35,6 +35,8 @@ from tvm.contrib.nvcc import have_fp16
import pytest
sys.setrecursionlimit(10000)
+torch.backends.cuda.matmul.allow_tf32 = False
+torch.backends.cudnn.allow_tf32 = False
def list_ops(expr):
diff --git a/tests/python/frontend/pytorch/test_object_detection.py b/tests/python/frontend/pytorch/test_object_detection.py
index 3bc2bf5c..26ce5bf 100644
--- a/tests/python/frontend/pytorch/test_object_detection.py
+++ b/tests/python/frontend/pytorch/test_object_detection.py
@@ -19,6 +19,9 @@
import numpy as np
import cv2
+import torch
+import torchvision
+
import tvm
import tvm.testing
@@ -31,9 +34,6 @@ from tvm.relay.frontend.pytorch_utils import (
)
from tvm.contrib.download import download
-import torch
-import torchvision
-
in_size = 300
diff --git a/tests/python/frontend/pytorch/test_rnns.py b/tests/python/frontend/pytorch/test_rnns.py
index b5784a6..b0180a7 100644
--- a/tests/python/frontend/pytorch/test_rnns.py
+++ b/tests/python/frontend/pytorch/test_rnns.py
@@ -15,9 +15,9 @@
# specific language governing permissions and limitations
# under the License.
+import torch
import tvm
import tvm.testing
-import torch
import onnx
import io
import sys