masahi commented on a change in pull request #10423:
URL: https://github.com/apache/tvm/pull/10423#discussion_r817131237
##########
File path: python/tvm/topi/cuda/conv2d_transpose.py
##########
@@ -171,124 +182,361 @@ def _fallback_schedule(N, F, Y, X):
cfg["unroll_explicit"] = OtherOptionEntity(True)
cfg["auto_unroll_max_step"] = OtherOptionEntity(1500)
+ pad_data = op.input_tensors[0]
+ kernel = op.input_tensors[1]
+ conv = op.output(0)
+
+ ##### space definition begin #####
+ n, f, y, x = s[conv].op.axis
+ rc = s[conv].op.reduce_axis[0]
+ # TODO(@kevinthesun): Support tuning/optimization for dynamic shape.
+ bs = pad_data.shape[0]
+ n_tuning_axis = n if isinstance(bs, tvm.tir.IntImm) else 1
+ cfg.define_split("tile_n", cfg.axis(n_tuning_axis), num_outputs=4)
+ cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
+ cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
+ cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)
+ cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3)
+ cfg.define_knob("auto_unroll_max_step", [64, 512, 1500])
+
+ target = tvm.target.Target.current()
+ if target.kind.name in ["nvptx", "rocm"]:
+ cfg.define_knob("unroll_explicit", [1])
+ else:
+ cfg.define_knob("unroll_explicit", [0, 1])
+
+ if cfg.is_fallback:
+ N, F, Y, X = get_const_tuple(conv.shape)
+ if not isinstance(N, int):
+ N = 1
+ _fallback_schedule(N, F, Y, X)
+
+ ##### space definition end #####
+
+ if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
+ s[kernel].compute_inline()
+
+ if conv.op in s.outputs:
+ output = conv
+ OL = s.cache_write(conv, "local")
+ else:
+ output = s.outputs[0].output(0)
+ s[conv].set_scope("local")
+ OL = conv
+
+ # create cache stage
+ s[pad_data].set_scope("shared")
+ AA = pad_data
+ WW = s.cache_read(kernel, "shared", [OL])
+
+ # tile and bind spatial axes
+ n, f, y, x = s[output].op.axis
+ kernel_scope, n = s[output].split(n, nparts=1)
+ bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
+ bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
+ by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
+ bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
+
+ s[output].reorder(bn, bf, by, bx, vn, vf, vy, vx, tn, tf, ty, tx, ni, fi, yi, xi)
+ s[output].bind(bn, te.thread_axis("blockIdx.z"))
+ s[output].bind(bf, te.thread_axis("blockIdx.y"))
+ s[output].bind(s[output].fuse(by, bx), te.thread_axis("blockIdx.x"))
+ s[output].bind(vn, te.thread_axis("vthread"))
+ s[output].bind(vf, te.thread_axis("vthread"))
+ s[output].bind(vy, te.thread_axis("vthread"))
+ s[output].bind(vx, te.thread_axis("vthread"))
+
+ cfg.define_knob("fuse_yx", [0, 1]) # fuse ty,tx or tn,tf
+
+ if cfg["fuse_yx"].val:
+ s[output].bind(tn, te.thread_axis("threadIdx.z"))
+ s[output].bind(tf, te.thread_axis("threadIdx.y"))
+ tyx = s[output].fuse(ty, tx)
+ s[output].bind(s[output].fuse(ty, tx), te.thread_axis("threadIdx.x"))
+ s[OL].compute_at(s[output], tyx)
+
+ # number of threads
+ n_tz = cfg["tile_n"].size[2]
+ n_ty = cfg["tile_f"].size[2]
+ n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
+ else:
+ s[output].bind(s[output].fuse(tn, tf), te.thread_axis("threadIdx.z"))
+ s[output].bind(ty, te.thread_axis("threadIdx.y"))
+ s[output].bind(tx, te.thread_axis("threadIdx.x"))
+ s[OL].compute_at(s[output], tx)
+
+ # number of threads
+ n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2]
+ n_ty = cfg["tile_y"].size[2]
+ n_tx = cfg["tile_x"].size[2]
+
+ # tile reduction axes
+ n, f, y, x = s[OL].op.axis
+ rc, ry, rx = s[OL].op.reduce_axis
+ rco, rcm, rci = cfg["tile_rc"].apply(s, OL, rc)
+ s[OL].reorder(rco, rcm, ry, rx, rci, n, f, y, x)
+
+ s[AA].compute_at(s[OL], rx)
+ s[WW].compute_at(s[OL], rx)
+
+ # cooperative fetching
+ for load in [AA, WW]:
+ n, f, y, x = s[load].op.axis
+ fused = s[load].fuse(f, y, x)
+ tz, fused = s[load].split(fused, nparts=n_tz)
+ ty, fused = s[load].split(fused, nparts=n_ty)
+ tx, fused = s[load].split(fused, nparts=n_tx)
+ s[load].bind(tz, te.thread_axis("threadIdx.z"))
+ s[load].bind(ty, te.thread_axis("threadIdx.y"))
+ s[load].bind(tx, te.thread_axis("threadIdx.x"))
+
+ s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+ s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)
+
+
+@autotvm.register_topi_compute("group_conv2d_transpose_nchw.cuda")
+def group_conv2d_transpose_nchw(cfg, data, kernel, stride, padding, out_dtype, output_padding, groups):
+ """Transposed 2D convolution nchw forward operator.
+
+ Parameters
+ ----------
+ cfg: ConfigEntity
+ The config for this template
+ Input : tvm.te.Tensor
+ 4-D with shape [batch, in_channel, in_height, in_width]
+ Filter : tvm.te.Tensor
+ 4-D with shape [in_channel, num_filter, filter_height, filter_width]
+ strides : tuple of two ints
+ The spatial stride along height and width
+ padding : int or str
+ Padding size, or ['VALID', 'SAME']
+ out_dtype: str
+ The output type. This is used in mixed precision
+ output_padding : tuple of two ints
+ Used to disambiguate output shape.
+ groups : int
+ number of groups
+
+ Returns
+ -------
+ Output : tvm.te.Tensor
+ 4-D with shape [batch, out_channel, out_height, out_width]
+ """
+ if groups == 1:
+ return conv2d_transpose_nchw(data, kernel, stride, padding, out_dtype, output_padding)
+
+ batch, inp_channels, inp_height, inp_width = get_const_tuple(data.shape)
+ assert inp_channels % groups == 0, f"input channels {inp_channels} must divide group size {groups}"
+ _, out_channels, kernel_height, kernel_width = get_const_tuple(kernel.shape)
+ stride_height, stride_width = stride
+ outpad_height, outpad_width = output_padding
+ assert outpad_height < stride_height and outpad_width < stride_width
+ cfg.stride = stride
+ pad_top, pad_left, pad_bottom, pad_right = nn.get_pad_tuple(
+ padding, (kernel_height, kernel_width)
+ )
+
+ out_width = (inp_width - 1) * stride_width + kernel_width - pad_left - pad_right + outpad_width
+ pad_left = kernel_width - 1 - pad_left
+ pad_right = kernel_width - 1 - pad_right + outpad_width
+ dilated_width = stride_width * (inp_width - 1) + 1
+
+ out_height = (
+ (inp_height - 1) * stride_height + kernel_height - pad_top - pad_bottom + outpad_height
+ )
+ pad_top = kernel_height - 1 - pad_top
+ pad_bottom = kernel_height - 1 - pad_bottom + outpad_height
+ dilated_height = stride_height * (inp_height - 1) + 1
+
+ # compute pad
+ data = te.compute(
+ (
+ batch,
+ inp_channels,
+ pad_top + dilated_height + pad_bottom,
+ pad_left + dilated_width + pad_right,
+ ),
+ lambda n, c, y, x: tvm.tir.if_then_else(
+ tvm.tir.all(
+ x >= pad_left,
+ x < pad_left + dilated_width,
+ tvm.tir.indexmod(x - pad_left, stride_width).equal(0),
+ y >= pad_top,
+ y < pad_top + dilated_height,
+ tvm.tir.indexmod(y - pad_top, stride_height).equal(0),
+ ),
+ data[
+ n,
+ c,
+ tvm.tir.indexdiv(y - pad_top, stride_height),
+ tvm.tir.indexdiv(x - pad_left, stride_width),
+ ],
+ tvm.tir.const(0.0, data.dtype),
+ ),
+ name="data_pad",
+ )
+
+ # transform kernel layout from IOHW to OIHW, and rotate kernel by 180 degrees
+ kernel_transform = te.compute(
+ (out_channels, inp_channels, kernel_height, kernel_width),
+ lambda i, o, h, w: kernel[o][i][kernel_height - 1 - h][kernel_width - 1 - w],
+ name="kernel_transform",
+ )
+
+ dc = te.reduce_axis((0, inp_channels // groups), name="dc")
+ dh = te.reduce_axis((0, kernel_height), name="dh")
+ dw = te.reduce_axis((0, kernel_width), name="dw")
+ data_out = te.compute(
+ (batch, out_channels * groups, out_height, out_width),
+ lambda b, c, h, w: te.sum(
+ data[
+ b, c // out_channels * (inp_channels // groups) + dc, h + dh, w + dw
+ ].astype(out_dtype)
+ * kernel_transform[
+ c % out_channels,
+ c // out_channels * (inp_channels // groups) + dc,
+ dh,
+ dw
+ ].astype(out_dtype),
+ axis=[dc, dh, dw],
+ ),
+ tag="group_conv2d_transpose_nchw",
+ )
+
+ return data_out
+
+
+@autotvm.register_topi_schedule("group_conv2d_transpose_nchw.cuda")
+def schedule_group_conv2d_transpose_nchw(cfg, outs):
+ outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+ s = te.create_schedule([x.op for x in outs])
+
def _callback(op):
- if op.tag == "conv2d_transpose_nchw":
- pad_data = op.input_tensors[0]
- kernel = op.input_tensors[1]
- conv = op.output(0)
-
- ##### space definition begin #####
- n, f, y, x = s[conv].op.axis
- rc = s[conv].op.reduce_axis[0]
- # TODO(@kevinthesun): Support tuning/optimization for dynamic shape.
- bs = pad_data.shape[0]
- n_tuning_axis = n if isinstance(bs, tvm.tir.IntImm) else 1
- cfg.define_split("tile_n", cfg.axis(n_tuning_axis), num_outputs=4)
- cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
- cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
- cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)
- cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3)
- cfg.define_knob("auto_unroll_max_step", [64, 512, 1500])
-
- target = tvm.target.Target.current()
- if target.kind.name in ["nvptx", "rocm"]:
- cfg.define_knob("unroll_explicit", [1])
- else:
- cfg.define_knob("unroll_explicit", [0, 1])
-
- if cfg.is_fallback:
- N, F, Y, X = get_const_tuple(conv.shape)
- if not isinstance(N, int):
- N = 1
- _fallback_schedule(N, F, Y, X)
-
- ##### space definition end #####
-
- if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
- s[kernel].compute_inline()
-
- if conv.op in s.outputs:
- output = conv
- OL = s.cache_write(conv, "local")
- else:
- output = s.outputs[0].output(0)
- s[conv].set_scope("local")
- OL = conv
-
- # create cache stage
- s[pad_data].set_scope("shared")
- AA = pad_data
- WW = s.cache_read(kernel, "shared", [OL])
-
- # tile and bind spatial axes
- n, f, y, x = s[output].op.axis
- kernel_scope, n = s[output].split(n, nparts=1)
- bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
- bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
- by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
- bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
-
- s[output].reorder(bn, bf, by, bx, vn, vf, vy, vx, tn, tf, ty, tx, ni, fi, yi, xi)
- s[output].bind(bn, te.thread_axis("blockIdx.z"))
- s[output].bind(bf, te.thread_axis("blockIdx.y"))
- s[output].bind(s[output].fuse(by, bx), te.thread_axis("blockIdx.x"))
- s[output].bind(vn, te.thread_axis("vthread"))
- s[output].bind(vf, te.thread_axis("vthread"))
- s[output].bind(vy, te.thread_axis("vthread"))
- s[output].bind(vx, te.thread_axis("vthread"))
-
- cfg.define_knob("fuse_yx", [0, 1]) # fuse ty,tx or tn,tf
-
- if cfg["fuse_yx"].val:
- s[output].bind(tn, te.thread_axis("threadIdx.z"))
- s[output].bind(tf, te.thread_axis("threadIdx.y"))
- tyx = s[output].fuse(ty, tx)
- s[output].bind(s[output].fuse(ty, tx), te.thread_axis("threadIdx.x"))
- s[OL].compute_at(s[output], tyx)
-
- # number of threads
- n_tz = cfg["tile_n"].size[2]
- n_ty = cfg["tile_f"].size[2]
- n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
- else:
- s[output].bind(s[output].fuse(tn, tf), te.thread_axis("threadIdx.z"))
- s[output].bind(ty, te.thread_axis("threadIdx.y"))
- s[output].bind(tx, te.thread_axis("threadIdx.x"))
- s[OL].compute_at(s[output], tx)
-
- # number of threads
- n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2]
- n_ty = cfg["tile_y"].size[2]
- n_tx = cfg["tile_x"].size[2]
-
- # tile reduction axes
- n, f, y, x = s[OL].op.axis
- rc, ry, rx = s[OL].op.reduce_axis
- rco, rcm, rci = cfg["tile_rc"].apply(s, OL, rc)
- s[OL].reorder(rco, rcm, ry, rx, rci, n, f, y, x)
-
- s[AA].compute_at(s[OL], rx)
- s[WW].compute_at(s[OL], rx)
-
- # cooperative fetching
- for load in [AA, WW]:
- n, f, y, x = s[load].op.axis
- fused = s[load].fuse(f, y, x)
- tz, fused = s[load].split(fused, nparts=n_tz)
- ty, fused = s[load].split(fused, nparts=n_ty)
- tx, fused = s[load].split(fused, nparts=n_tx)
- s[load].bind(tz, te.thread_axis("threadIdx.z"))
- s[load].bind(ty, te.thread_axis("threadIdx.y"))
- s[load].bind(tx, te.thread_axis("threadIdx.x"))
-
- s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
- s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)
+ if op.tag == "group_conv2d_transpose_nchw":
+ _schedule_group_conv2d_transpose_nchw(cfg, s, op)
+
+ elif op.tag == "conv2d_transpose_nchw": # groups == 1 cases delegate to regular conv2d_transpose
+ _schedule_conv2d_transpose_nchw(cfg, s, op)
traverse_inline(s, outs[0].op, _callback)
return s
+def _schedule_group_conv2d_transpose_nchw(cfg, s, op):
Review comment:
> although it won't tile over groups like group_conv2d
You can add that to the existing schedule instead. When groups > 1, change the tag name set at
https://github.com/apache/tvm/blob/66f23e2afb31e88c7616b280423e0aeab5bce0e4/python/tvm/topi/cuda/conv2d_transpose.py#L126
Then, inside the existing schedule, check the tag name (`op.tag`) to decide whether to enable tiling over groups.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]