mbaret commented on a change in pull request #8584:
URL: https://github.com/apache/tvm/pull/8584#discussion_r680801775
##########
File path: python/tvm/topi/mali/depthwise_conv2d.py
##########
@@ -51,86 +51,151 @@ def schedule_depthwise_conv2d_nchw(cfg, outs):
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])
- def _schedule(pad_data, kernel, conv):
- """schedule depthwise_conv2d"""
- max_unroll = 16
- vec_size = [1, 2, 4, 8, 16]
-
- ##### space definition begin #####
- n, c, y, x = s[conv].op.axis
- bc, tc, ci = cfg.define_split("tile_c", c, num_outputs=3)
- by, ty, yi = cfg.define_split("tile_y", y, num_outputs=3)
- bx, tx, xi = cfg.define_split("tile_x", x, num_outputs=3)
- cfg.define_annotate("ann_spatial", [ci, yi, xi],
policy="try_unroll_vec")
-
- # fallback support
- if cfg.is_fallback:
- ref_log = autotvm.tophub.load_reference_log(
- "mali", "rk3399", "depthwise_conv2d_nchw.mali"
- )
- cfg.fallback_with_reference_log(ref_log)
- ###### space definition end ######
-
- # schedule padding
- n, c, y, x = s[pad_data].op.axis
- tile_and_bind3d(s, pad_data, c, y, x, cfg["tile_c"].size[1], 1, 1)
-
- # schedule dilation
- if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in
kernel.op.tag:
- s[kernel].compute_inline()
-
- # schedule conv
- if conv.op not in s.outputs:
- s[conv].set_scope("local")
- OL = conv
- output = s.outputs[0].output(0)
- else:
- OL = s.cache_write(conv, "local")
- output = conv
-
- n, c, y, x = s[output].op.axis
- bc, tc, ci = cfg["tile_c"].apply(s, output, c)
- by, ty, yi = cfg["tile_y"].apply(s, output, y)
- bx, tx, xi = cfg["tile_x"].apply(s, output, x)
-
- bc = s[output].fuse(n, bc)
- s[output].bind(bc, te.thread_axis("blockIdx.z"))
- s[output].bind(tc, te.thread_axis("threadIdx.z"))
- s[output].bind(by, te.thread_axis("blockIdx.y"))
- s[output].bind(ty, te.thread_axis("threadIdx.y"))
- s[output].bind(bx, te.thread_axis("blockIdx.x"))
- s[output].bind(tx, te.thread_axis("threadIdx.x"))
-
- di, dj = s[OL].op.reduce_axis
- s[OL].unroll(di)
- s[OL].unroll(dj)
-
- s[OL].compute_at(s[output], tx)
- n, ci, yi, xi = s[OL].op.axis
-
- cfg["ann_spatial"].apply(
- s,
- OL,
- [ci, yi, xi],
- axis_lens=[cfg["tile_c"].size[2], cfg["tile_y"].size[2],
cfg["tile_x"].size[2]],
- max_unroll=max_unroll,
- vec_size=vec_size,
- cfg=cfg,
- )
-
def _callback(op):
"""traverse to find op to schedule"""
# schedule depthwise_conv2d
if op.tag == "depthwise_conv2d_nchw":
pad_data = op.input_tensors[0]
kernel = op.input_tensors[1]
conv = op.output(0)
- _schedule(pad_data, kernel, conv)
+ _schedule(cfg, s, pad_data, kernel, conv, "NCHW")
traverse_inline(s, outs[0].op, _callback)
return s
+# register original implementation of depthwise_conv2d_nhwc since we don't
need to change this part
[email protected]_topi_compute("depthwise_conv2d_nhwc.mali")
+def depthwise_conv2d_nhwc(cfg, data, kernel, strides, padding, dilation,
out_dtype):
+ return nn.depthwise_conv2d_nhwc(data, kernel, strides, padding, dilation,
out_dtype)
+
+
+# register customized schedule for Mali GPU.
[email protected]_topi_schedule("depthwise_conv2d_nhwc.mali")
+def schedule_depthwise_conv2d_nhwc(cfg, outs):
+ """Schedule depthwise conv2d
+
+ Parameters
+ ----------
+ cfg: ConfigEntity
+ The configuration of this template
+ outs: Array of Tensor
+ The computation graph description of depthwise convolution2d
+ in the format of an array of tensors.
+
+ Returns
+ -------
+ s: Schedule
+        The computation schedule for depthwise_conv2d nhwc.
+ """
+ outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+ s = te.create_schedule([x.op for x in outs])
+
+ def _callback(op):
+ """traverse to find op to schedule"""
+ # schedule depthwise_conv2d
+ if op.tag == "depthwise_conv2d_nhwc":
+ pad_data = op.input_tensors[0]
+ kernel = op.input_tensors[1]
+ conv = op.output(0)
+ _schedule(cfg, s, pad_data, kernel, conv, "NHWC")
+
+ traverse_inline(s, outs[0].op, _callback)
+ return s
+
+
+def _schedule(cfg, s, pad_data, kernel, conv, layout):
+ """schedule depthwise_conv2d"""
+ assert layout in ("NCHW", "NHWC")
+
+ max_unroll = 16
+ vec_size = [1, 2, 4, 8, 16]
+
+ ##### space definition begin #####
+ if layout == "NCHW":
+ n, c, h, w = s[conv].op.axis
+ else:
+ n, h, w, c = s[conv].op.axis
+
+ bc, tc, ci = cfg.define_split("tile_c", c, num_outputs=3)
+ bh, th, hi = cfg.define_split("tile_y", h, num_outputs=3)
+ bw, tw, wi = cfg.define_split("tile_x", w, num_outputs=3)
+ cfg.define_annotate("ann_spatial", [ci, hi, wi], policy="try_unroll_vec")
+
+ # fallback support
+ if cfg.is_fallback:
+ ref_log = autotvm.tophub.load_reference_log("mali", "rk3399",
"depthwise_conv2d_nchw.mali")
Review comment:
I think I'd rather we set some sane but statically defined defaults
here. What do you think @jcf94?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]