JCBrouwer commented on a change in pull request #10423:
URL: https://github.com/apache/tvm/pull/10423#discussion_r816752900



##########
File path: python/tvm/topi/cuda/conv2d_transpose.py
##########
@@ -171,124 +182,361 @@ def _fallback_schedule(N, F, Y, X):
         cfg["unroll_explicit"] = OtherOptionEntity(True)
         cfg["auto_unroll_max_step"] = OtherOptionEntity(1500)
 
+    pad_data = op.input_tensors[0]
+    kernel = op.input_tensors[1]
+    conv = op.output(0)
+
+    ##### space definition begin #####
+    n, f, y, x = s[conv].op.axis
+    rc = s[conv].op.reduce_axis[0]
+    # TODO(@kevinthesun): Support tuning/optimization for dynamic shape.
+    bs = pad_data.shape[0]
+    n_tuning_axis = n if isinstance(bs, tvm.tir.IntImm) else 1
+    cfg.define_split("tile_n", cfg.axis(n_tuning_axis), num_outputs=4)
+    cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
+    cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
+    cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)
+    cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3)
+    cfg.define_knob("auto_unroll_max_step", [64, 512, 1500])
+
+    target = tvm.target.Target.current()
+    if target.kind.name in ["nvptx", "rocm"]:
+        cfg.define_knob("unroll_explicit", [1])
+    else:
+        cfg.define_knob("unroll_explicit", [0, 1])
+
+    if cfg.is_fallback:
+        N, F, Y, X = get_const_tuple(conv.shape)
+        if not isinstance(N, int):
+            N = 1
+        _fallback_schedule(N, F, Y, X)
+
+    ##### space definition end #####
+
+    if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
+        s[kernel].compute_inline()
+
+    if conv.op in s.outputs:
+        output = conv
+        OL = s.cache_write(conv, "local")
+    else:
+        output = s.outputs[0].output(0)
+        s[conv].set_scope("local")
+        OL = conv
+
+    # create cache stage
+    s[pad_data].set_scope("shared")
+    AA = pad_data
+    WW = s.cache_read(kernel, "shared", [OL])
+
+    # tile and bind spatial axes
+    n, f, y, x = s[output].op.axis
+    kernel_scope, n = s[output].split(n, nparts=1)
+    bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
+    bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
+    by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
+    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
+
+    s[output].reorder(bn, bf, by, bx, vn, vf, vy, vx, tn, tf, ty, tx, ni, fi, 
yi, xi)
+    s[output].bind(bn, te.thread_axis("blockIdx.z"))
+    s[output].bind(bf, te.thread_axis("blockIdx.y"))
+    s[output].bind(s[output].fuse(by, bx), te.thread_axis("blockIdx.x"))
+    s[output].bind(vn, te.thread_axis("vthread"))
+    s[output].bind(vf, te.thread_axis("vthread"))
+    s[output].bind(vy, te.thread_axis("vthread"))
+    s[output].bind(vx, te.thread_axis("vthread"))
+
+    cfg.define_knob("fuse_yx", [0, 1])  # fuse ty,tx or tn,tf
+
+    if cfg["fuse_yx"].val:
+        s[output].bind(tn, te.thread_axis("threadIdx.z"))
+        s[output].bind(tf, te.thread_axis("threadIdx.y"))
+        tyx = s[output].fuse(ty, tx)
+        s[output].bind(s[output].fuse(ty, tx), te.thread_axis("threadIdx.x"))
+        s[OL].compute_at(s[output], tyx)
+
+        # number of threads
+        n_tz = cfg["tile_n"].size[2]
+        n_ty = cfg["tile_f"].size[2]
+        n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
+    else:
+        s[output].bind(s[output].fuse(tn, tf), te.thread_axis("threadIdx.z"))
+        s[output].bind(ty, te.thread_axis("threadIdx.y"))
+        s[output].bind(tx, te.thread_axis("threadIdx.x"))
+        s[OL].compute_at(s[output], tx)
+
+        # number of threads
+        n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2]
+        n_ty = cfg["tile_y"].size[2]
+        n_tx = cfg["tile_x"].size[2]
+
+    # tile reduction axes
+    n, f, y, x = s[OL].op.axis
+    rc, ry, rx = s[OL].op.reduce_axis
+    rco, rcm, rci = cfg["tile_rc"].apply(s, OL, rc)
+    s[OL].reorder(rco, rcm, ry, rx, rci, n, f, y, x)
+
+    s[AA].compute_at(s[OL], rx)
+    s[WW].compute_at(s[OL], rx)
+
+    # cooperative fetching
+    for load in [AA, WW]:
+        n, f, y, x = s[load].op.axis
+        fused = s[load].fuse(f, y, x)
+        tz, fused = s[load].split(fused, nparts=n_tz)
+        ty, fused = s[load].split(fused, nparts=n_ty)
+        tx, fused = s[load].split(fused, nparts=n_tx)
+        s[load].bind(tz, te.thread_axis("threadIdx.z"))
+        s[load].bind(ty, te.thread_axis("threadIdx.y"))
+        s[load].bind(tx, te.thread_axis("threadIdx.x"))
+
+    s[output].pragma(kernel_scope, "auto_unroll_max_step", 
cfg["auto_unroll_max_step"].val)
+    s[output].pragma(kernel_scope, "unroll_explicit", 
cfg["unroll_explicit"].val)
+
+
+@autotvm.register_topi_compute("group_conv2d_transpose_nchw.cuda")
+def group_conv2d_transpose_nchw(cfg, data, kernel, stride, padding, out_dtype, 
output_padding, groups):
+    """Transposed 2D convolution nchw forward operator.
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+    Input : tvm.te.Tensor
+        4-D with shape [batch, in_channel, in_height, in_width]
+    Filter : tvm.te.Tensor
+        4-D with shape [in_channel, num_filter, filter_height, filter_width]
+    strides : tuple of two ints
+        The spatial stride along height and width
+    padding : int or str
+        Padding size, or ['VALID', 'SAME']
+    out_dtype: str
+        The output type. This is used in mixed precision
+    output_padding : tuple of two ints
+        Used to disambiguate output shape.
+    groups : int
+        number of groups
+
+    Returns
+    -------
+    Output : tvm.te.Tensor
+        4-D with shape [batch, out_channel, out_height, out_width]
+    """
+    if groups == 1:
+        return conv2d_transpose_nchw(data, kernel, stride, padding, out_dtype, 
output_padding)
+
+    batch, inp_channels, inp_height, inp_width = get_const_tuple(data.shape)
+    assert inp_channels % groups == 0, f"input channels {inp_channels} must 
divide group size {groups}"
+    _, out_channels, kernel_height, kernel_width = 
get_const_tuple(kernel.shape)
+    stride_height, stride_width = stride
+    outpad_height, outpad_width = output_padding
+    assert outpad_height < stride_height and outpad_width < stride_width
+    cfg.stride = stride
+    pad_top, pad_left, pad_bottom, pad_right = nn.get_pad_tuple(
+        padding, (kernel_height, kernel_width)
+    )
+
+    out_width = (inp_width - 1) * stride_width + kernel_width - pad_left - 
pad_right + outpad_width
+    pad_left = kernel_width - 1 - pad_left
+    pad_right = kernel_width - 1 - pad_right + outpad_width
+    dilated_width = stride_width * (inp_width - 1) + 1
+
+    out_height = (
+        (inp_height - 1) * stride_height + kernel_height - pad_top - 
pad_bottom + outpad_height
+    )
+    pad_top = kernel_height - 1 - pad_top
+    pad_bottom = kernel_height - 1 - pad_bottom + outpad_height
+    dilated_height = stride_height * (inp_height - 1) + 1
+
+    # compute pad
+    data = te.compute(
+        (
+            batch,
+            inp_channels,
+            pad_top + dilated_height + pad_bottom,
+            pad_left + dilated_width + pad_right,
+        ),
+        lambda n, c, y, x: tvm.tir.if_then_else(
+            tvm.tir.all(
+                x >= pad_left,
+                x < pad_left + dilated_width,
+                tvm.tir.indexmod(x - pad_left, stride_width).equal(0),
+                y >= pad_top,
+                y < pad_top + dilated_height,
+                tvm.tir.indexmod(y - pad_top, stride_height).equal(0),
+            ),
+            data[
+                n,
+                c,
+                tvm.tir.indexdiv(y - pad_top, stride_height),
+                tvm.tir.indexdiv(x - pad_left, stride_width),
+            ],
+            tvm.tir.const(0.0, data.dtype),
+        ),
+        name="data_pad",
+    )
+
+    # transform kernel layout from IOHW to OIHW, and rotate kernel by 180 
degrees
+    kernel_transform = te.compute(
+        (out_channels, inp_channels, kernel_height, kernel_width),
+        lambda i, o, h, w: kernel[o][i][kernel_height - 1 - h][kernel_width - 
1 - w],
+        name="kernel_transform",
+    )
+    
+    dc = te.reduce_axis((0, inp_channels // groups), name="dc")
+    dh = te.reduce_axis((0, kernel_height), name="dh")
+    dw = te.reduce_axis((0, kernel_width), name="dw")
+    data_out = te.compute(
+        (batch, out_channels * groups, out_height, out_width),
+        lambda b, c, h, w: te.sum(
+            data[
+                b, c // out_channels * (inp_channels // groups) + dc, h + dh, 
w + dw
+            ].astype(out_dtype)
+            * kernel_transform[
+                c % out_channels,
+                c // out_channels * (inp_channels // groups) + dc,
+                dh,
+                dw
+            ].astype(out_dtype),
+            axis=[dc, dh, dw],
+        ),
+        tag="group_conv2d_transpose_nchw",
+    )
+
+    return data_out
+
+
+@autotvm.register_topi_schedule("group_conv2d_transpose_nchw.cuda")
+def schedule_group_conv2d_transpose_nchw(cfg, outs):
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+
     def _callback(op):
-        if op.tag == "conv2d_transpose_nchw":
-            pad_data = op.input_tensors[0]
-            kernel = op.input_tensors[1]
-            conv = op.output(0)
-
-            ##### space definition begin #####
-            n, f, y, x = s[conv].op.axis
-            rc = s[conv].op.reduce_axis[0]
-            # TODO(@kevinthesun): Support tuning/optimization for dynamic 
shape.
-            bs = pad_data.shape[0]
-            n_tuning_axis = n if isinstance(bs, tvm.tir.IntImm) else 1
-            cfg.define_split("tile_n", cfg.axis(n_tuning_axis), num_outputs=4)
-            cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
-            cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
-            cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)
-            cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3)
-            cfg.define_knob("auto_unroll_max_step", [64, 512, 1500])
-
-            target = tvm.target.Target.current()
-            if target.kind.name in ["nvptx", "rocm"]:
-                cfg.define_knob("unroll_explicit", [1])
-            else:
-                cfg.define_knob("unroll_explicit", [0, 1])
-
-            if cfg.is_fallback:
-                N, F, Y, X = get_const_tuple(conv.shape)
-                if not isinstance(N, int):
-                    N = 1
-                _fallback_schedule(N, F, Y, X)
-
-            ##### space definition end #####
-
-            if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in 
kernel.op.tag:
-                s[kernel].compute_inline()
-
-            if conv.op in s.outputs:
-                output = conv
-                OL = s.cache_write(conv, "local")
-            else:
-                output = s.outputs[0].output(0)
-                s[conv].set_scope("local")
-                OL = conv
-
-            # create cache stage
-            s[pad_data].set_scope("shared")
-            AA = pad_data
-            WW = s.cache_read(kernel, "shared", [OL])
-
-            # tile and bind spatial axes
-            n, f, y, x = s[output].op.axis
-            kernel_scope, n = s[output].split(n, nparts=1)
-            bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
-            bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
-            by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
-            bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
-
-            s[output].reorder(bn, bf, by, bx, vn, vf, vy, vx, tn, tf, ty, tx, 
ni, fi, yi, xi)
-            s[output].bind(bn, te.thread_axis("blockIdx.z"))
-            s[output].bind(bf, te.thread_axis("blockIdx.y"))
-            s[output].bind(s[output].fuse(by, bx), 
te.thread_axis("blockIdx.x"))
-            s[output].bind(vn, te.thread_axis("vthread"))
-            s[output].bind(vf, te.thread_axis("vthread"))
-            s[output].bind(vy, te.thread_axis("vthread"))
-            s[output].bind(vx, te.thread_axis("vthread"))
-
-            cfg.define_knob("fuse_yx", [0, 1])  # fuse ty,tx or tn,tf
-
-            if cfg["fuse_yx"].val:
-                s[output].bind(tn, te.thread_axis("threadIdx.z"))
-                s[output].bind(tf, te.thread_axis("threadIdx.y"))
-                tyx = s[output].fuse(ty, tx)
-                s[output].bind(s[output].fuse(ty, tx), 
te.thread_axis("threadIdx.x"))
-                s[OL].compute_at(s[output], tyx)
-
-                # number of threads
-                n_tz = cfg["tile_n"].size[2]
-                n_ty = cfg["tile_f"].size[2]
-                n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
-            else:
-                s[output].bind(s[output].fuse(tn, tf), 
te.thread_axis("threadIdx.z"))
-                s[output].bind(ty, te.thread_axis("threadIdx.y"))
-                s[output].bind(tx, te.thread_axis("threadIdx.x"))
-                s[OL].compute_at(s[output], tx)
-
-                # number of threads
-                n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2]
-                n_ty = cfg["tile_y"].size[2]
-                n_tx = cfg["tile_x"].size[2]
-
-            # tile reduction axes
-            n, f, y, x = s[OL].op.axis
-            rc, ry, rx = s[OL].op.reduce_axis
-            rco, rcm, rci = cfg["tile_rc"].apply(s, OL, rc)
-            s[OL].reorder(rco, rcm, ry, rx, rci, n, f, y, x)
-
-            s[AA].compute_at(s[OL], rx)
-            s[WW].compute_at(s[OL], rx)
-
-            # cooperative fetching
-            for load in [AA, WW]:
-                n, f, y, x = s[load].op.axis
-                fused = s[load].fuse(f, y, x)
-                tz, fused = s[load].split(fused, nparts=n_tz)
-                ty, fused = s[load].split(fused, nparts=n_ty)
-                tx, fused = s[load].split(fused, nparts=n_tx)
-                s[load].bind(tz, te.thread_axis("threadIdx.z"))
-                s[load].bind(ty, te.thread_axis("threadIdx.y"))
-                s[load].bind(tx, te.thread_axis("threadIdx.x"))
-
-            s[output].pragma(kernel_scope, "auto_unroll_max_step", 
cfg["auto_unroll_max_step"].val)
-            s[output].pragma(kernel_scope, "unroll_explicit", 
cfg["unroll_explicit"].val)
+        if op.tag == "group_conv2d_transpose_nchw":
+            _schedule_group_conv2d_transpose_nchw(cfg, s, op)
+
+        elif op.tag == "conv2d_transpose_nchw":  # groups == 1 cases delegate 
to regular conv2d_transpose
+            _schedule_conv2d_transpose_nchw(cfg, s, op)
 
     traverse_inline(s, outs[0].op, _callback)
 
     return s
 
 
+def _schedule_group_conv2d_transpose_nchw(cfg, s, op):

Review comment:
   `_schedule_conv2d_transpose_nchw` gives me the following error when `groups > 1`:
   ``` 
   RuntimeError: Memory verification failed with the following errors:
   Did you forget to bind?
       Variable `W` is directly accessed by host memory (it is not contained in 
a thread environment or in the function arguments.
   ```
   
   `_schedule_group_conv2d_transpose_nchw` has the following lines, which I believe fix that:
   ```python
       ...
       s[kernel_transform].compute_inline()
   
       s[kernel_transform].set_scope("shared")
       WW = kernel_transform
       ...
   ```
   
   Also, the tiling in `_schedule_group_conv2d_transpose_nchw` is different; it mirrors the tiling of the grouped conv2d schedule:
   ```python
       cfg.define_split("tile_n", cfg.axis(n), num_outputs=4)
       cfg.define_split("tile_g", cfg.axis(groups), num_outputs=2)
       cfg.define_split("tile_f", cfg.axis(num_filters // groups), num_outputs=4)
       cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
       cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)
       cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3)
   ```
   vs `_schedule_conv2d_transpose_nchw`
   ```python
       cfg.define_split("tile_n", cfg.axis(n_tuning_axis), num_outputs=4)
       cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
       cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
       cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)
       cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3)
   ```
   This in turn requires further changes to the reorder/bind steps later on.
   
   If you like, I could try merging them into one function, with those changes conditioned on `groups`.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to