This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 21a00f544e [Bugfix][Relax][Pytorch] Bugfix of conv_transpose1d and
conv_transpose2d (#17968)
21a00f544e is described below
commit 21a00f544e52ad977fcc7a0b7f6ff5645c125bdd
Author: kavin-mcw <[email protected]>
AuthorDate: Wed May 14 15:53:57 2025 +0530
[Bugfix][Relax][Pytorch] Bugfix of conv_transpose1d and conv_transpose2d
(#17968)
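* Fix conv_transpose1d and conv_transpose2d: forward output_padding, use the
IOW/IOHW kernel layout, and read output_padding, groups and dilation from
positional args 5, 6 and 7 respectively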
* fix lint issue
* Update tests to reflect changes
* lint fix
---
.../tvm/relax/frontend/torch/base_fx_graph_translator.py | 16 ++++++++++++----
python/tvm/relax/frontend/torch/fx_translator.py | 2 ++
.../python/relax/test_frontend_from_exported_program.py | 12 ++++++++----
tests/python/relax/test_frontend_from_fx.py | 12 ++++++++----
4 files changed, 30 insertions(+), 12 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index f789eb8af3..5c8d7095e5 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -733,6 +733,7 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
padding: Optional[Tuple],
dilation: Optional[Tuple],
groups: Optional[Tuple],
+ output_padding: Optional[Tuple],
) -> relax.Var:
conv1d_transpose = self.block_builder.emit(
relax.op.nn.conv1d_transpose(
@@ -742,8 +743,9 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
padding=padding,
dilation=dilation,
groups=groups,
+ output_padding=output_padding,
data_layout="NCW",
- kernel_layout="OIW",
+ kernel_layout="IOW",
out_dtype="float32",
)
)
@@ -762,8 +764,9 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
bias = args[2] if len(args) > 2 else None
stride = args[3] if len(args) > 3 else 1
padding = args[4] if len(args) > 4 else 0
- dilation = args[5] if len(args) > 5 else 1
+ output_padding = args[5] if len(args) > 5 else 0
groups = args[6] if len(args) > 6 else 1
+ dilation = args[7] if len(args) > 7 else 1
return self._conv_transpose1d_impl(
x,
weight,
@@ -772,6 +775,7 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
padding=padding,
dilation=dilation,
groups=groups,
+ output_padding=output_padding,
)
def _conv_transpose2d_impl(
@@ -783,6 +787,7 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
padding: Optional[Tuple],
dilation: Optional[Tuple],
groups: Optional[Tuple],
+ output_padding: Optional[Tuple],
) -> relax.Var:
conv2d_transpose = self.block_builder.emit(
relax.op.nn.conv2d_transpose(
@@ -792,8 +797,9 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
padding=padding,
dilation=dilation,
groups=groups,
+ output_padding=output_padding,
data_layout="NCHW",
- kernel_layout="OIHW",
+ kernel_layout="IOHW",
out_dtype="float32",
)
)
@@ -812,8 +818,9 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
bias = args[2] if len(args) > 2 else None
stride = args[3] if len(args) > 3 else 1
padding = args[4] if len(args) > 4 else 0
- dilation = args[5] if len(args) > 5 else 1
+ output_padding = args[5] if len(args) > 5 else 0
groups = args[6] if len(args) > 6 else 1
+ dilation = args[7] if len(args) > 7 else 1
return self._conv_transpose2d_impl(
x,
weight,
@@ -822,6 +829,7 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
padding=padding,
dilation=dilation,
groups=groups,
+ output_padding=output_padding,
)
def _conv1d_impl(
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py
b/python/tvm/relax/frontend/torch/fx_translator.py
index fc12f877e0..97a2b51e49 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -294,6 +294,7 @@ class TorchFXImporter(BaseFXGraphImporter):
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
+ output_padding=module.output_padding,
)
def _conv_transpose2d_module(self, node: fx.Node) -> relax.Var:
@@ -310,6 +311,7 @@ class TorchFXImporter(BaseFXGraphImporter):
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
+ output_padding=module.output_padding,
)
def _conv1d_module(self, node: fx.Node) -> relax.Var:
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index f01da1336e..80da6fcf19 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1870,9 +1870,10 @@ def test_conv_transpose1d():
w1,
strides=[1],
padding=[0, 0],
+ output_padding=[0],
dilation=[1],
data_layout="NCW",
- kernel_layout="OIW",
+ kernel_layout="IOW",
out_layout="NCW",
out_dtype="float32",
)
@@ -1904,9 +1905,10 @@ def test_conv_transpose1d():
w1,
strides=[1],
padding=[0, 0],
+ output_padding=[0],
dilation=[1],
data_layout="NCW",
- kernel_layout="OIW",
+ kernel_layout="IOW",
out_layout="NCW",
out_dtype="float32",
)
@@ -1962,9 +1964,10 @@ def test_conv_transpose2d():
w1,
strides=[1, 1],
padding=[0, 0, 0, 0],
+ output_padding=[0, 0],
dilation=[1, 1],
data_layout="NCHW",
- kernel_layout="OIHW",
+ kernel_layout="IOHW",
out_layout="NCHW",
out_dtype="float32",
)
@@ -1996,9 +1999,10 @@ def test_conv_transpose2d():
w1,
strides=[1, 1],
padding=[0, 0, 0, 0],
+ output_padding=[0, 0],
dilation=[1, 1],
data_layout="NCHW",
- kernel_layout="OIHW",
+ kernel_layout="IOHW",
out_layout="NCHW",
out_dtype="float32",
)
diff --git a/tests/python/relax/test_frontend_from_fx.py
b/tests/python/relax/test_frontend_from_fx.py
index f507071b07..7fb2bed328 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -168,9 +168,10 @@ def test_conv1d_transpose():
w1,
strides=[1],
padding=[0, 0],
+ output_padding=[0],
dilation=[1],
data_layout="NCW",
- kernel_layout="OIW",
+ kernel_layout="IOW",
out_layout="NCW",
out_dtype="float32",
)
@@ -202,9 +203,10 @@ def test_conv1d_transpose():
w1,
strides=[1],
padding=[0, 0],
+ output_padding=[0],
dilation=[1],
data_layout="NCW",
- kernel_layout="OIW",
+ kernel_layout="IOW",
out_layout="NCW",
out_dtype="float32",
)
@@ -352,9 +354,10 @@ def test_conv2d_transpose():
w1,
strides=[1, 1],
padding=[0, 0, 0, 0],
+ output_padding=[0, 0],
dilation=[1, 1],
data_layout="NCHW",
- kernel_layout="OIHW",
+ kernel_layout="IOHW",
out_layout="NCHW",
out_dtype="float32",
)
@@ -386,9 +389,10 @@ def test_conv2d_transpose():
w1,
strides=[1, 1],
padding=[0, 0, 0, 0],
+ output_padding=[0, 0],
dilation=[1, 1],
data_layout="NCHW",
- kernel_layout="OIHW",
+ kernel_layout="IOHW",
out_layout="NCHW",
out_dtype="float32",
)