masahi commented on a change in pull request #10423:
URL: https://github.com/apache/tvm/pull/10423#discussion_r816686353



##########
File path: python/tvm/topi/cuda/conv2d_transpose.py
##########
@@ -171,124 +182,361 @@ def _fallback_schedule(N, F, Y, X):
         cfg["unroll_explicit"] = OtherOptionEntity(True)
         cfg["auto_unroll_max_step"] = OtherOptionEntity(1500)
 
+    pad_data = op.input_tensors[0]
+    kernel = op.input_tensors[1]
+    conv = op.output(0)
+
+    ##### space definition begin #####
+    n, f, y, x = s[conv].op.axis
+    rc = s[conv].op.reduce_axis[0]
+    # TODO(@kevinthesun): Support tuning/optimization for dynamic shape.
+    bs = pad_data.shape[0]
+    n_tuning_axis = n if isinstance(bs, tvm.tir.IntImm) else 1
+    cfg.define_split("tile_n", cfg.axis(n_tuning_axis), num_outputs=4)
+    cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
+    cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
+    cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)
+    cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3)
+    cfg.define_knob("auto_unroll_max_step", [64, 512, 1500])
+
+    target = tvm.target.Target.current()
+    if target.kind.name in ["nvptx", "rocm"]:
+        cfg.define_knob("unroll_explicit", [1])
+    else:
+        cfg.define_knob("unroll_explicit", [0, 1])
+
+    if cfg.is_fallback:
+        N, F, Y, X = get_const_tuple(conv.shape)
+        if not isinstance(N, int):
+            N = 1
+        _fallback_schedule(N, F, Y, X)
+
+    ##### space definition end #####
+
+    if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
+        s[kernel].compute_inline()
+
+    if conv.op in s.outputs:
+        output = conv
+        OL = s.cache_write(conv, "local")
+    else:
+        output = s.outputs[0].output(0)
+        s[conv].set_scope("local")
+        OL = conv
+
+    # create cache stage
+    s[pad_data].set_scope("shared")
+    AA = pad_data
+    WW = s.cache_read(kernel, "shared", [OL])
+
+    # tile and bind spatial axes
+    n, f, y, x = s[output].op.axis
+    kernel_scope, n = s[output].split(n, nparts=1)
+    bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
+    bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
+    by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
+    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
+
+    s[output].reorder(bn, bf, by, bx, vn, vf, vy, vx, tn, tf, ty, tx, ni, fi, yi, xi)
+    s[output].bind(bn, te.thread_axis("blockIdx.z"))
+    s[output].bind(bf, te.thread_axis("blockIdx.y"))
+    s[output].bind(s[output].fuse(by, bx), te.thread_axis("blockIdx.x"))
+    s[output].bind(vn, te.thread_axis("vthread"))
+    s[output].bind(vf, te.thread_axis("vthread"))
+    s[output].bind(vy, te.thread_axis("vthread"))
+    s[output].bind(vx, te.thread_axis("vthread"))
+
+    cfg.define_knob("fuse_yx", [0, 1])  # fuse ty,tx or tn,tf
+
+    if cfg["fuse_yx"].val:
+        s[output].bind(tn, te.thread_axis("threadIdx.z"))
+        s[output].bind(tf, te.thread_axis("threadIdx.y"))
+        tyx = s[output].fuse(ty, tx)
+        s[output].bind(s[output].fuse(ty, tx), te.thread_axis("threadIdx.x"))
+        s[OL].compute_at(s[output], tyx)
+
+        # number of threads
+        n_tz = cfg["tile_n"].size[2]
+        n_ty = cfg["tile_f"].size[2]
+        n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
+    else:
+        s[output].bind(s[output].fuse(tn, tf), te.thread_axis("threadIdx.z"))
+        s[output].bind(ty, te.thread_axis("threadIdx.y"))
+        s[output].bind(tx, te.thread_axis("threadIdx.x"))
+        s[OL].compute_at(s[output], tx)
+
+        # number of threads
+        n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2]
+        n_ty = cfg["tile_y"].size[2]
+        n_tx = cfg["tile_x"].size[2]
+
+    # tile reduction axes
+    n, f, y, x = s[OL].op.axis
+    rc, ry, rx = s[OL].op.reduce_axis
+    rco, rcm, rci = cfg["tile_rc"].apply(s, OL, rc)
+    s[OL].reorder(rco, rcm, ry, rx, rci, n, f, y, x)
+
+    s[AA].compute_at(s[OL], rx)
+    s[WW].compute_at(s[OL], rx)
+
+    # cooperative fetching
+    for load in [AA, WW]:
+        n, f, y, x = s[load].op.axis
+        fused = s[load].fuse(f, y, x)
+        tz, fused = s[load].split(fused, nparts=n_tz)
+        ty, fused = s[load].split(fused, nparts=n_ty)
+        tx, fused = s[load].split(fused, nparts=n_tx)
+        s[load].bind(tz, te.thread_axis("threadIdx.z"))
+        s[load].bind(ty, te.thread_axis("threadIdx.y"))
+        s[load].bind(tx, te.thread_axis("threadIdx.x"))
+
+    s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+    s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)
+
+
+@autotvm.register_topi_compute("group_conv2d_transpose_nchw.cuda")
+def group_conv2d_transpose_nchw(cfg, data, kernel, stride, padding, out_dtype, output_padding, groups):

Review comment:
       I think you followed the existing CPU implementation, but could you instead add group support to the existing `conv2d_transpose_nchw`, rather than rolling a separate implementation?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to