This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch revert-19842-fix/conv-transpose-dilation-legalize
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 660a337c6b4a49a4b7ab8d1388a3b3efb9ed4741
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue Jun 23 08:49:51 2026 -0400

    Revert "[Relax] Legalize dilated conv_transpose (#19842)"
    
    This reverts commit 9808108e48af413a03ec35e512939a522132176b.
---
 python/tvm/relax/transform/legalize_ops/nn.py      | 89 +++++++++++-----------
 .../python/relax/test_transform_legalize_ops_nn.py | 52 -------------
 2 files changed, 45 insertions(+), 96 deletions(-)

diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py
index 6116a41e76..d68426f02a 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -164,23 +164,24 @@ def _nn_conv1d_transpose(bb: BlockBuilder, call: Call) -> Expr:
             "and kernel layout other than IOW, so cannot be legalized by TOPI"
         )
         return call
-    strides = [int(s) for s in call.attrs.strides]
-    padding = [int(p) for p in call.attrs.padding]
-    output_padding = [int(o) for o in call.attrs.output_padding]
-    groups = int(call.attrs.groups)
-    out_dtype = call.ty.dtype
-    dilation = [int(d) for d in call.attrs.dilation]
-
-    def te_conv1d_transpose(data, kernel):
-        # Dilated transposed conv == transposed conv with a spatially dilated 
(zero-filled) kernel.
-        if any(d != 1 for d in dilation):
-            kernel = topi.nn.dilate(kernel, [1, 1, dilation[0]], 
name="kernel_dilate")
-        return topi.nn.group_conv1d_transpose_ncw(
-            data, kernel, strides, padding, out_dtype, output_padding, groups
+    dilation = call.attrs.dilation
+    if len(dilation) != 1 or dilation[0] != 1:
+        logging.info(
+            "TOPI conv1d_transpose does not support dilations other than 1, "
+            "and thus cannot be legalized by TOPI"
         )
+        return call
 
     return bb.call_te(
-        te_conv1d_transpose, call.args[0], call.args[1], 
primfunc_name_hint="conv1d_transpose"
+        topi.nn.group_conv1d_transpose_ncw,
+        call.args[0],
+        call.args[1],
+        stride=call.attrs.strides,
+        padding=call.attrs.padding,
+        out_dtype=call.ty.dtype,
+        output_padding=call.attrs.output_padding,
+        groups=call.attrs.groups,
+        primfunc_name_hint="conv1d_transpose",
     )
 
 
@@ -198,23 +199,24 @@ def _nn_conv2d_transpose(bb: BlockBuilder, call: Call) -> Expr:
             "and kernel layout other than IOHW, so cannot be legalized by TOPI"
         )
         return call
-    strides = [int(s) for s in call.attrs.strides]
-    padding = [int(p) for p in call.attrs.padding]
-    output_padding = [int(o) for o in call.attrs.output_padding]
-    groups = int(call.attrs.groups)
-    out_dtype = call.ty.dtype
-    dilation = [int(d) for d in call.attrs.dilation]
-
-    def te_conv2d_transpose(data, kernel):
-        # Dilated transposed conv == transposed conv with a spatially dilated 
(zero-filled) kernel.
-        if any(d != 1 for d in dilation):
-            kernel = topi.nn.dilate(kernel, [1, 1, dilation[0], dilation[1]], 
name="kernel_dilate")
-        return topi.nn.group_conv2d_transpose_nchw(
-            data, kernel, strides, padding, out_dtype, output_padding, groups
+    dilation = call.attrs.dilation
+    if len(dilation) != 2 or any(d != 1 for d in dilation):
+        logging.info(
+            "TOPI conv2d_transpose does not support dilations other than 1, "
+            "and thus cannot be legalized by TOPI"
         )
+        return call
 
     return bb.call_te(
-        te_conv2d_transpose, call.args[0], call.args[1], 
primfunc_name_hint="conv2d_transpose"
+        topi.nn.group_conv2d_transpose_nchw,
+        call.args[0],
+        call.args[1],
+        stride=call.attrs.strides,
+        padding=call.attrs.padding,
+        out_dtype=call.ty.dtype,
+        output_padding=call.attrs.output_padding,
+        groups=call.attrs.groups,
+        primfunc_name_hint="conv2d_transpose",
     )
 
 
@@ -234,25 +236,24 @@ def _nn_conv3d_transpose(bb: BlockBuilder, call: Call) -> Expr:
             "and kernel layout other than IODHW, so cannot be legalized by 
TOPI"
         )
         return call
-    strides = [int(s) for s in call.attrs.strides]
-    padding = [int(p) for p in call.attrs.padding]
-    output_padding = [int(o) for o in call.attrs.output_padding]
-    groups = int(call.attrs.groups)
-    out_dtype = call.ty.dtype
-    dilation = [int(d) for d in call.attrs.dilation]
-
-    def te_conv3d_transpose(data, kernel):
-        # Dilated transposed conv == transposed conv with a spatially dilated 
(zero-filled) kernel.
-        if any(d != 1 for d in dilation):
-            kernel = topi.nn.dilate(
-                kernel, [1, 1, dilation[0], dilation[1], dilation[2]], 
name="kernel_dilate"
-            )
-        return topi.nn.group_conv3d_transpose_ncdhw(
-            data, kernel, strides, padding, out_dtype, output_padding, groups
+    dilation = call.attrs.dilation
+    if len(dilation) != 3 or any(d != 1 for d in dilation):
+        logging.info(
+            "TOPI conv3d_transpose does not support dilations other than 1, "
+            "and thus cannot be legalized by TOPI"
         )
+        return call
 
     return bb.call_te(
-        te_conv3d_transpose, call.args[0], call.args[1], 
primfunc_name_hint="conv3d_transpose"
+        topi.nn.group_conv3d_transpose_ncdhw,
+        call.args[0],
+        call.args[1],
+        strides=call.attrs.strides,
+        padding=call.attrs.padding,
+        out_dtype=call.ty.dtype,
+        output_padding=call.attrs.output_padding,
+        groups=call.attrs.groups,
+        primfunc_name_hint="conv3d_transpose",
     )
 
 
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index 88621b9067..601985f7be 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -725,58 +725,6 @@ def test_conv2d_transpose_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
-def test_conv2d_transpose_dilation():
-    # fmt: off
-    @tvm.script.ir_module
-    class Conv2dTranspose:
-        @R.function
-        def main(x: R.Tensor((1, 1, 3, 3), "float32"), w: R.Tensor((1, 1, 2, 
2), "float32")):
-            gv = R.nn.conv2d_transpose(x, w, dilation=(2, 2))
-            return gv
-
-    @I.ir_module(s_tir=True)
-    class Expected:
-        @T.prim_func(private=True, s_tir=True)
-        def conv2d_transpose(x: T.Buffer((T.int64(1), T.int64(1), T.int64(3), 
T.int64(3)), "float32"), w: T.Buffer((T.int64(1), T.int64(1), T.int64(2), 
T.int64(2)), "float32"), compute: T.Buffer((T.int64(1), T.int64(1), T.int64(5), 
T.int64(5)), "float32")):
-            T.func_attr({"tirx.noalias": True})
-            data_dilate = T.sblock_alloc_buffer((T.int64(1), T.int64(1), 
T.int64(3), T.int64(3)))
-            data_pad = T.sblock_alloc_buffer((T.int64(1), T.int64(1), 
T.int64(7), T.int64(7)))
-            kernel_dilate = T.sblock_alloc_buffer((T.int64(1), T.int64(1), 
T.int64(3), T.int64(3)))
-            kernel_transform = T.sblock_alloc_buffer((T.int64(1), T.int64(1), 
T.int64(3), T.int64(3)))
-            for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(3), 
T.int64(3)):
-                with T.sblock("data_dilate"):
-                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, 
i3])
-                    data_dilate[v_i0, v_i1, v_i2, v_i3] = x[v_i0, v_i1, v_i2, 
v_i3]
-            for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(7), 
T.int64(7)):
-                with T.sblock("data_pad"):
-                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, 
i3])
-                    data_pad[v_i0, v_i1, v_i2, v_i3] = 
T.if_then_else(T.int64(2) <= v_i2 and v_i2 < T.int64(5) and T.int64(2) <= v_i3 
and v_i3 < T.int64(5), data_dilate[v_i0, v_i1, v_i2 - T.int64(2), v_i3 - 
T.int64(2)], T.float32(0.0))
-            for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(3), 
T.int64(3)):
-                with T.sblock("kernel_dilate"):
-                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, 
i3])
-                    kernel_dilate[v_i0, v_i1, v_i2, v_i3] = 
T.if_then_else(v_i2 % T.int64(2) == T.int64(0) and v_i3 % T.int64(2) == 
T.int64(0), w[v_i0, v_i1, v_i2 // T.int64(2), v_i3 // T.int64(2)], 
T.float32(0.0))
-            for o, i, h, w_1 in T.grid(T.int64(1), T.int64(1), T.int64(3), 
T.int64(3)):
-                with T.sblock("kernel_transform"):
-                    v_o, v_i, v_h, v_w = T.axis.remap("SSSS", [o, i, h, w_1])
-                    kernel_transform[v_o, v_i, v_h, v_w] = kernel_dilate[v_i, 
v_o, T.int64(2) - v_h, T.int64(2) - v_w]
-            for b, c, h, w_1, dc, dh, dw in T.grid(T.int64(1), T.int64(1), 
T.int64(5), T.int64(5), T.int64(1), T.int64(3), T.int64(3)):
-                with T.sblock("compute"):
-                    v_b, v_c, v_h, v_w, v_dc, v_dh, v_dw = 
T.axis.remap("SSSSRRR", [b, c, h, w_1, dc, dh, dw])
-                    with T.init():
-                        compute[v_b, v_c, v_h, v_w] = T.float32(0.0)
-                    compute[v_b, v_c, v_h, v_w] = compute[v_b, v_c, v_h, v_w] 
+ data_pad[v_b, v_dc, v_h + v_dh, v_w + v_dw] * kernel_transform[v_c, v_dc, 
v_dh, v_dw]
-
-        @R.function
-        def main(x: R.Tensor((1, 1, 3, 3), dtype="float32"), w: R.Tensor((1, 
1, 2, 2), dtype="float32")) -> R.Tensor((1, 1, 5, 5), dtype="float32"):
-            cls = Expected
-            gv = R.call_tir(cls.conv2d_transpose, (x, w), out_ty=R.Tensor((1, 
1, 5, 5), dtype="float32"))
-            return gv
-    # fmt: on
-
-    mod = LegalizeOps()(Conv2dTranspose)
-    tvm.ir.assert_structural_equal(mod, Expected)
-
-
 def test_max_pool2d():
     # fmt: off
     @tvm.script.ir_module

Reply via email to