This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a3e03a3  [Torch] Frontend update to support PyTorch 1.10 (#9664)
a3e03a3 is described below

commit a3e03a37231b68d5ea2701e1c52088f4ebd20057
Author: masahi <[email protected]>
AuthorDate: Tue Dec 7 18:28:36 2021 +0900

    [Torch] Frontend update to support PyTorch 1.10 (#9664)
    
    * wip
    
    * fixed converting maskrcnn
    
    * fixed nll loss
    
    * fixed linspace
    
    * fixed deformable conv2d
    
    * control flow and rnn test had no problem
    
    * swap more import orders
    
    * qmv3 test is having weird segfault
    
    * cleanup
    
    * black
---
 python/tvm/relay/frontend/pytorch.py               | 43 +++++++++++++++++-----
 python/tvm/relay/frontend/pytorch_utils.py         |  4 +-
 tests/python/frontend/pytorch/test_forward.py      |  2 +
 .../frontend/pytorch/test_object_detection.py      |  6 +--
 tests/python/frontend/pytorch/test_rnns.py         |  2 +-
 5 files changed, 41 insertions(+), 16 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index d9b8f90..fbc5f6e 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -408,11 +408,16 @@ class PyTorchOpConverter:
         ):
             return data
 
+        if target_begin is None and target_end is None:
+            return data
+
         # Process begin
         begin = [0] * ndim
-        begin[dim] = target_begin
 
-        if not isinstance(begin[dim], int):
+        if target_begin is not None:
+            begin[dim] = target_begin
+
+        if target_begin is not None and not isinstance(begin[dim], int):
             tmp = []
             for b in begin:
                 if isinstance(b, int):
@@ -455,7 +460,7 @@ class PyTorchOpConverter:
                     )
         else:
             end = _op.cast(_op.shape_of(data), axis_dtype)
-            if not isinstance(target_end, tvm.tir.Any):
+            if target_end is not None and not isinstance(target_end, tvm.tir.Any):
                 ttype = self.infer_type(target_end).dtype
                 if str(ttype) != axis_dtype:
                     target_end = _op.cast(target_end, axis_dtype)
@@ -745,7 +750,13 @@ class PyTorchOpConverter:
         else:
             stop = start + step
 
-        dtype = "float32" if inputs[3] is not None else _convert_dtype_value(inputs[3])
+        if inputs[3] is None:
+            import torch
+
+            dtype = _convert_data_type(str(torch.get_default_dtype()))
+        else:
+            dtype = _convert_dtype_value(inputs[3])
+
         start = _create_typed_const(start, dtype)
         stop = _create_typed_const(stop, dtype)
         step = _create_typed_const(step, dtype)
@@ -2065,16 +2076,25 @@ class PyTorchOpConverter:
         data = inputs[0]
         weight = inputs[1]
         offset = inputs[2]
-        strides = (inputs[4], inputs[5])
-        padding = (inputs[6], inputs[7])
-        dilation = (inputs[8], inputs[9])
-        groups = inputs[10]
-        deformable_groups = inputs[11]
+
+        if len(inputs) > 12:
+            strides_offset = 5
+            bias = inputs[4]
+            logging.warning("mask argument in deformable conv2d is not supported and ignored")
+        else:
+            strides_offset = 4
+            bias = inputs[3]
+
+        strides = (inputs[strides_offset], inputs[strides_offset + 1])
+        padding = (inputs[strides_offset + 2], inputs[strides_offset + 3])
+        dilation = (inputs[strides_offset + 4], inputs[strides_offset + 5])
+        groups = inputs[strides_offset + 6]
+        deformable_groups = inputs[strides_offset + 7]
         weight_shape = self.infer_shape(weight)
         output_channels = weight_shape[0]
         kernel_size = (weight_shape[2], weight_shape[3])
 
-        return _op.nn.deformable_conv2d(
+        conv_out = _op.nn.deformable_conv2d(
             data,
             offset,
             weight,
@@ -2087,6 +2107,8 @@ class PyTorchOpConverter:
             kernel_size,
         )
 
+        return _op.nn.bias_add(conv_out, bias)
+
     def unbind(self, inputs, input_types):
         data = inputs[0]
         axis = int(inputs[1])
@@ -3059,6 +3081,7 @@ class PyTorchOpConverter:
             "aten::_unique2": self.unique,
             "aten::nll_loss": self.nll_loss,
             "aten::nll_loss2d": self.nll_loss,
+            "aten::nll_loss_nd": self.nll_loss,
             "aten::flip": self.flip,
             "aten::gru": self.gru,
             "aten::lstm": self.lstm,
diff --git a/python/tvm/relay/frontend/pytorch_utils.py b/python/tvm/relay/frontend/pytorch_utils.py
index 753b1f2..b4fe168 100644
--- a/python/tvm/relay/frontend/pytorch_utils.py
+++ b/python/tvm/relay/frontend/pytorch_utils.py
@@ -34,8 +34,8 @@ def is_version_greater_than(ver):
     import torch
     import re
 
-    return "".join(re.findall(r"(\d+\.)(\d+\.)(\d)", torch.__version__)[0]) > "".join(
-        re.findall(r"(\d+\.)(\d+\.)(\d)", ver)[0]
+    return int("".join(re.findall(r"(\d+)\.(\d+)\.(\d)", torch.__version__)[0])) > int(
+        "".join(re.findall(r"(\d+)\.(\d+)\.(\d)", ver)[0])
     )
 
 
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index b30b0af..0692ab8 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -35,6 +35,8 @@ from tvm.contrib.nvcc import have_fp16
 import pytest
 
 sys.setrecursionlimit(10000)
+torch.backends.cuda.matmul.allow_tf32 = False
+torch.backends.cudnn.allow_tf32 = False
 
 
 def list_ops(expr):
diff --git a/tests/python/frontend/pytorch/test_object_detection.py b/tests/python/frontend/pytorch/test_object_detection.py
index 3bc2bf5c..26ce5bf 100644
--- a/tests/python/frontend/pytorch/test_object_detection.py
+++ b/tests/python/frontend/pytorch/test_object_detection.py
@@ -19,6 +19,9 @@
 import numpy as np
 import cv2
 
+import torch
+import torchvision
+
 import tvm
 
 import tvm.testing
@@ -31,9 +34,6 @@ from tvm.relay.frontend.pytorch_utils import (
 )
 from tvm.contrib.download import download
 
-import torch
-import torchvision
-
 in_size = 300
 
 
diff --git a/tests/python/frontend/pytorch/test_rnns.py b/tests/python/frontend/pytorch/test_rnns.py
index b5784a6..b0180a7 100644
--- a/tests/python/frontend/pytorch/test_rnns.py
+++ b/tests/python/frontend/pytorch/test_rnns.py
@@ -15,9 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import torch
 import tvm
 import tvm.testing
-import torch
 import onnx
 import io
 import sys

Reply via email to the repository's commit notification mailing list.