This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 21a00f544e [Bugfix][Relax][Pytorch] Bugfix of conv_transpose1d and conv_transpose2d (#17968)
21a00f544e is described below

commit 21a00f544e52ad977fcc7a0b7f6ff5645c125bdd
Author: kavin-mcw <[email protected]>
AuthorDate: Wed May 14 15:53:57 2025 +0530

    [Bugfix][Relax][Pytorch] Bugfix of conv_transpose1d and conv_transpose2d (#17968)
    
    * Bugfix conv_transpose1d and conv_transpose2d
    
    * fix lint issue
    
    * Update tests to reflect changes
    
    * lint fix
---
 .../tvm/relax/frontend/torch/base_fx_graph_translator.py | 16 ++++++++++++----
 python/tvm/relax/frontend/torch/fx_translator.py         |  2 ++
 .../python/relax/test_frontend_from_exported_program.py  | 12 ++++++++----
 tests/python/relax/test_frontend_from_fx.py              | 12 ++++++++----
 4 files changed, 30 insertions(+), 12 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index f789eb8af3..5c8d7095e5 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -733,6 +733,7 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         padding: Optional[Tuple],
         dilation: Optional[Tuple],
         groups: Optional[Tuple],
+        output_padding: Optional[Tuple],
     ) -> relax.Var:
         conv1d_transpose = self.block_builder.emit(
             relax.op.nn.conv1d_transpose(
@@ -742,8 +743,9 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
                 padding=padding,
                 dilation=dilation,
                 groups=groups,
+                output_padding=output_padding,
                 data_layout="NCW",
-                kernel_layout="OIW",
+                kernel_layout="IOW",
                 out_dtype="float32",
             )
         )
@@ -762,8 +764,9 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         bias = args[2] if len(args) > 2 else None
         stride = args[3] if len(args) > 3 else 1
         padding = args[4] if len(args) > 4 else 0
-        dilation = args[5] if len(args) > 5 else 1
+        output_padding = args[5] if len(args) > 5 else 0
         groups = args[6] if len(args) > 6 else 1
+        dilation = args[7] if len(args) > 7 else 1
         return self._conv_transpose1d_impl(
             x,
             weight,
@@ -772,6 +775,7 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
             padding=padding,
             dilation=dilation,
             groups=groups,
+            output_padding=output_padding,
         )
 
     def _conv_transpose2d_impl(
@@ -783,6 +787,7 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         padding: Optional[Tuple],
         dilation: Optional[Tuple],
         groups: Optional[Tuple],
+        output_padding: Optional[Tuple],
     ) -> relax.Var:
         conv2d_transpose = self.block_builder.emit(
             relax.op.nn.conv2d_transpose(
@@ -792,8 +797,9 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
                 padding=padding,
                 dilation=dilation,
                 groups=groups,
+                output_padding=output_padding,
                 data_layout="NCHW",
-                kernel_layout="OIHW",
+                kernel_layout="IOHW",
                 out_dtype="float32",
             )
         )
@@ -812,8 +818,9 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         bias = args[2] if len(args) > 2 else None
         stride = args[3] if len(args) > 3 else 1
         padding = args[4] if len(args) > 4 else 0
-        dilation = args[5] if len(args) > 5 else 1
+        output_padding = args[5] if len(args) > 5 else 0
         groups = args[6] if len(args) > 6 else 1
+        dilation = args[7] if len(args) > 7 else 1
         return self._conv_transpose2d_impl(
             x,
             weight,
@@ -822,6 +829,7 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
             padding=padding,
             dilation=dilation,
             groups=groups,
+            output_padding=output_padding,
         )
 
     def _conv1d_impl(
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index fc12f877e0..97a2b51e49 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -294,6 +294,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             padding=module.padding,
             dilation=module.dilation,
             groups=module.groups,
+            output_padding=module.output_padding,
         )
 
     def _conv_transpose2d_module(self, node: fx.Node) -> relax.Var:
@@ -310,6 +311,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             padding=module.padding,
             dilation=module.dilation,
             groups=module.groups,
+            output_padding=module.output_padding,
         )
 
     def _conv1d_module(self, node: fx.Node) -> relax.Var:
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index f01da1336e..80da6fcf19 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1870,9 +1870,10 @@ def test_conv_transpose1d():
                     w1,
                     strides=[1],
                     padding=[0, 0],
+                    output_padding=[0],
                     dilation=[1],
                     data_layout="NCW",
-                    kernel_layout="OIW",
+                    kernel_layout="IOW",
                     out_layout="NCW",
                     out_dtype="float32",
                 )
@@ -1904,9 +1905,10 @@ def test_conv_transpose1d():
                     w1,
                     strides=[1],
                     padding=[0, 0],
+                    output_padding=[0],
                     dilation=[1],
                     data_layout="NCW",
-                    kernel_layout="OIW",
+                    kernel_layout="IOW",
                     out_layout="NCW",
                     out_dtype="float32",
                 )
@@ -1962,9 +1964,10 @@ def test_conv_transpose2d():
                     w1,
                     strides=[1, 1],
                     padding=[0, 0, 0, 0],
+                    output_padding=[0, 0],
                     dilation=[1, 1],
                     data_layout="NCHW",
-                    kernel_layout="OIHW",
+                    kernel_layout="IOHW",
                     out_layout="NCHW",
                     out_dtype="float32",
                 )
@@ -1996,9 +1999,10 @@ def test_conv_transpose2d():
                     w1,
                     strides=[1, 1],
                     padding=[0, 0, 0, 0],
+                    output_padding=[0, 0],
                     dilation=[1, 1],
                     data_layout="NCHW",
-                    kernel_layout="OIHW",
+                    kernel_layout="IOHW",
                     out_layout="NCHW",
                     out_dtype="float32",
                 )
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index f507071b07..7fb2bed328 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -168,9 +168,10 @@ def test_conv1d_transpose():
                     w1,
                     strides=[1],
                     padding=[0, 0],
+                    output_padding=[0],
                     dilation=[1],
                     data_layout="NCW",
-                    kernel_layout="OIW",
+                    kernel_layout="IOW",
                     out_layout="NCW",
                     out_dtype="float32",
                 )
@@ -202,9 +203,10 @@ def test_conv1d_transpose():
                     w1,
                     strides=[1],
                     padding=[0, 0],
+                    output_padding=[0],
                     dilation=[1],
                     data_layout="NCW",
-                    kernel_layout="OIW",
+                    kernel_layout="IOW",
                     out_layout="NCW",
                     out_dtype="float32",
                 )
@@ -352,9 +354,10 @@ def test_conv2d_transpose():
                     w1,
                     strides=[1, 1],
                     padding=[0, 0, 0, 0],
+                    output_padding=[0, 0],
                     dilation=[1, 1],
                     data_layout="NCHW",
-                    kernel_layout="OIHW",
+                    kernel_layout="IOHW",
                     out_layout="NCHW",
                     out_dtype="float32",
                 )
@@ -386,9 +389,10 @@ def test_conv2d_transpose():
                     w1,
                     strides=[1, 1],
                     padding=[0, 0, 0, 0],
+                    output_padding=[0, 0],
                     dilation=[1, 1],
                     data_layout="NCHW",
-                    kernel_layout="OIHW",
+                    kernel_layout="IOHW",
                     out_layout="NCHW",
                     out_dtype="float32",
                 )

Reply via email to