This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9808108e48 [Relax] Legalize dilated conv_transpose (#19842)
9808108e48 is described below

commit 9808108e48af413a03ec35e512939a522132176b
Author: Guan-Ming Chiu <[email protected]>
AuthorDate: Tue Jun 23 12:28:20 2026 +0800

    [Relax] Legalize dilated conv_transpose (#19842)
    
    ## Why
    
    relax.nn.conv{1,2,3}d_transpose with dilation > 1 was silently skipped by
    legalization and then crashed in VM codegen with an opaque error.
    
    ## How
    
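    - Lower dilation > 1 by zero-filling (dilating) the kernel, then reusing
    the existing TOPI transposed-conv compute (1D/2D/3D). A transposed conv
    with dilation d is equivalent to one whose kernel has d - 1 zeros
    inserted between adjacent taps (effective size d * (k - 1) + 1).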
    - Non-NCHW layouts and calls with out_layout != data_layout are still
    passed through unlegalized, as before (left for downstream/BYOC codegen
    such as CLML).
    - Add a 2D-dilation structural test.
    
    Signed-off-by: Guan-Ming (Wesley) Chiu <[email protected]>
---
 python/tvm/relax/transform/legalize_ops/nn.py      | 89 +++++++++++-----------
 .../python/relax/test_transform_legalize_ops_nn.py | 52 +++++++++++++
 2 files changed, 96 insertions(+), 45 deletions(-)

diff --git a/python/tvm/relax/transform/legalize_ops/nn.py 
b/python/tvm/relax/transform/legalize_ops/nn.py
index d68426f02a..6116a41e76 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -164,24 +164,23 @@ def _nn_conv1d_transpose(bb: BlockBuilder, call: Call) -> 
Expr:
             "and kernel layout other than IOW, so cannot be legalized by TOPI"
         )
         return call
-    dilation = call.attrs.dilation
-    if len(dilation) != 1 or dilation[0] != 1:
-        logging.info(
-            "TOPI conv1d_transpose does not support dilations other than 1, "
-            "and thus cannot be legalized by TOPI"
+    strides = [int(s) for s in call.attrs.strides]
+    padding = [int(p) for p in call.attrs.padding]
+    output_padding = [int(o) for o in call.attrs.output_padding]
+    groups = int(call.attrs.groups)
+    out_dtype = call.ty.dtype
+    dilation = [int(d) for d in call.attrs.dilation]
+
+    def te_conv1d_transpose(data, kernel):
+        # Dilated transposed conv == transposed conv with a spatially dilated 
(zero-filled) kernel.
+        if any(d != 1 for d in dilation):
+            kernel = topi.nn.dilate(kernel, [1, 1, dilation[0]], 
name="kernel_dilate")
+        return topi.nn.group_conv1d_transpose_ncw(
+            data, kernel, strides, padding, out_dtype, output_padding, groups
         )
-        return call
 
     return bb.call_te(
-        topi.nn.group_conv1d_transpose_ncw,
-        call.args[0],
-        call.args[1],
-        stride=call.attrs.strides,
-        padding=call.attrs.padding,
-        out_dtype=call.ty.dtype,
-        output_padding=call.attrs.output_padding,
-        groups=call.attrs.groups,
-        primfunc_name_hint="conv1d_transpose",
+        te_conv1d_transpose, call.args[0], call.args[1], 
primfunc_name_hint="conv1d_transpose"
     )
 
 
@@ -199,24 +198,23 @@ def _nn_conv2d_transpose(bb: BlockBuilder, call: Call) -> 
Expr:
             "and kernel layout other than IOHW, so cannot be legalized by TOPI"
         )
         return call
-    dilation = call.attrs.dilation
-    if len(dilation) != 2 or any(d != 1 for d in dilation):
-        logging.info(
-            "TOPI conv2d_transpose does not support dilations other than 1, "
-            "and thus cannot be legalized by TOPI"
+    strides = [int(s) for s in call.attrs.strides]
+    padding = [int(p) for p in call.attrs.padding]
+    output_padding = [int(o) for o in call.attrs.output_padding]
+    groups = int(call.attrs.groups)
+    out_dtype = call.ty.dtype
+    dilation = [int(d) for d in call.attrs.dilation]
+
+    def te_conv2d_transpose(data, kernel):
+        # Dilated transposed conv == transposed conv with a spatially dilated 
(zero-filled) kernel.
+        if any(d != 1 for d in dilation):
+            kernel = topi.nn.dilate(kernel, [1, 1, dilation[0], dilation[1]], 
name="kernel_dilate")
+        return topi.nn.group_conv2d_transpose_nchw(
+            data, kernel, strides, padding, out_dtype, output_padding, groups
         )
-        return call
 
     return bb.call_te(
-        topi.nn.group_conv2d_transpose_nchw,
-        call.args[0],
-        call.args[1],
-        stride=call.attrs.strides,
-        padding=call.attrs.padding,
-        out_dtype=call.ty.dtype,
-        output_padding=call.attrs.output_padding,
-        groups=call.attrs.groups,
-        primfunc_name_hint="conv2d_transpose",
+        te_conv2d_transpose, call.args[0], call.args[1], 
primfunc_name_hint="conv2d_transpose"
     )
 
 
@@ -236,24 +234,25 @@ def _nn_conv3d_transpose(bb: BlockBuilder, call: Call) -> 
Expr:
             "and kernel layout other than IODHW, so cannot be legalized by 
TOPI"
         )
         return call
-    dilation = call.attrs.dilation
-    if len(dilation) != 3 or any(d != 1 for d in dilation):
-        logging.info(
-            "TOPI conv3d_transpose does not support dilations other than 1, "
-            "and thus cannot be legalized by TOPI"
+    strides = [int(s) for s in call.attrs.strides]
+    padding = [int(p) for p in call.attrs.padding]
+    output_padding = [int(o) for o in call.attrs.output_padding]
+    groups = int(call.attrs.groups)
+    out_dtype = call.ty.dtype
+    dilation = [int(d) for d in call.attrs.dilation]
+
+    def te_conv3d_transpose(data, kernel):
+        # Dilated transposed conv == transposed conv with a spatially dilated 
(zero-filled) kernel.
+        if any(d != 1 for d in dilation):
+            kernel = topi.nn.dilate(
+                kernel, [1, 1, dilation[0], dilation[1], dilation[2]], 
name="kernel_dilate"
+            )
+        return topi.nn.group_conv3d_transpose_ncdhw(
+            data, kernel, strides, padding, out_dtype, output_padding, groups
         )
-        return call
 
     return bb.call_te(
-        topi.nn.group_conv3d_transpose_ncdhw,
-        call.args[0],
-        call.args[1],
-        strides=call.attrs.strides,
-        padding=call.attrs.padding,
-        out_dtype=call.ty.dtype,
-        output_padding=call.attrs.output_padding,
-        groups=call.attrs.groups,
-        primfunc_name_hint="conv3d_transpose",
+        te_conv3d_transpose, call.args[0], call.args[1], 
primfunc_name_hint="conv3d_transpose"
     )
 
 
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py 
b/tests/python/relax/test_transform_legalize_ops_nn.py
index 601985f7be..88621b9067 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -725,6 +725,58 @@ def test_conv2d_transpose_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_conv2d_transpose_dilation():
+    # fmt: off
+    @tvm.script.ir_module
+    class Conv2dTranspose:
+        @R.function
+        def main(x: R.Tensor((1, 1, 3, 3), "float32"), w: R.Tensor((1, 1, 2, 
2), "float32")):
+            gv = R.nn.conv2d_transpose(x, w, dilation=(2, 2))
+            return gv
+
+    @I.ir_module(s_tir=True)
+    class Expected:
+        @T.prim_func(private=True, s_tir=True)
+        def conv2d_transpose(x: T.Buffer((T.int64(1), T.int64(1), T.int64(3), 
T.int64(3)), "float32"), w: T.Buffer((T.int64(1), T.int64(1), T.int64(2), 
T.int64(2)), "float32"), compute: T.Buffer((T.int64(1), T.int64(1), T.int64(5), 
T.int64(5)), "float32")):
+            T.func_attr({"tirx.noalias": True})
+            data_dilate = T.sblock_alloc_buffer((T.int64(1), T.int64(1), 
T.int64(3), T.int64(3)))
+            data_pad = T.sblock_alloc_buffer((T.int64(1), T.int64(1), 
T.int64(7), T.int64(7)))
+            kernel_dilate = T.sblock_alloc_buffer((T.int64(1), T.int64(1), 
T.int64(3), T.int64(3)))
+            kernel_transform = T.sblock_alloc_buffer((T.int64(1), T.int64(1), 
T.int64(3), T.int64(3)))
+            for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(3), 
T.int64(3)):
+                with T.sblock("data_dilate"):
+                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, 
i3])
+                    data_dilate[v_i0, v_i1, v_i2, v_i3] = x[v_i0, v_i1, v_i2, 
v_i3]
+            for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(7), 
T.int64(7)):
+                with T.sblock("data_pad"):
+                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, 
i3])
+                    data_pad[v_i0, v_i1, v_i2, v_i3] = 
T.if_then_else(T.int64(2) <= v_i2 and v_i2 < T.int64(5) and T.int64(2) <= v_i3 
and v_i3 < T.int64(5), data_dilate[v_i0, v_i1, v_i2 - T.int64(2), v_i3 - 
T.int64(2)], T.float32(0.0))
+            for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(3), 
T.int64(3)):
+                with T.sblock("kernel_dilate"):
+                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, 
i3])
+                    kernel_dilate[v_i0, v_i1, v_i2, v_i3] = 
T.if_then_else(v_i2 % T.int64(2) == T.int64(0) and v_i3 % T.int64(2) == 
T.int64(0), w[v_i0, v_i1, v_i2 // T.int64(2), v_i3 // T.int64(2)], 
T.float32(0.0))
+            for o, i, h, w_1 in T.grid(T.int64(1), T.int64(1), T.int64(3), 
T.int64(3)):
+                with T.sblock("kernel_transform"):
+                    v_o, v_i, v_h, v_w = T.axis.remap("SSSS", [o, i, h, w_1])
+                    kernel_transform[v_o, v_i, v_h, v_w] = kernel_dilate[v_i, 
v_o, T.int64(2) - v_h, T.int64(2) - v_w]
+            for b, c, h, w_1, dc, dh, dw in T.grid(T.int64(1), T.int64(1), 
T.int64(5), T.int64(5), T.int64(1), T.int64(3), T.int64(3)):
+                with T.sblock("compute"):
+                    v_b, v_c, v_h, v_w, v_dc, v_dh, v_dw = 
T.axis.remap("SSSSRRR", [b, c, h, w_1, dc, dh, dw])
+                    with T.init():
+                        compute[v_b, v_c, v_h, v_w] = T.float32(0.0)
+                    compute[v_b, v_c, v_h, v_w] = compute[v_b, v_c, v_h, v_w] 
+ data_pad[v_b, v_dc, v_h + v_dh, v_w + v_dw] * kernel_transform[v_c, v_dc, 
v_dh, v_dw]
+
+        @R.function
+        def main(x: R.Tensor((1, 1, 3, 3), dtype="float32"), w: R.Tensor((1, 
1, 2, 2), dtype="float32")) -> R.Tensor((1, 1, 5, 5), dtype="float32"):
+            cls = Expected
+            gv = R.call_tir(cls.conv2d_transpose, (x, w), out_ty=R.Tensor((1, 
1, 5, 5), dtype="float32"))
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Conv2dTranspose)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_max_pool2d():
     # fmt: off
     @tvm.script.ir_module

Reply via email to