csullivan commented on code in PR #11161: URL: https://github.com/apache/tvm/pull/11161#discussion_r870483136
########## python/tvm/topi/adreno/conv2d_nhwc.py: ########## @@ -0,0 +1,339 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-variable,unused-argument,no-else-return +"""conv2d nhwc schedule on Qualcomm Adreno GPU""" +import tvm +from tvm import te +from tvm import autotvm + +from ..utils import get_const_tuple, traverse_inline +from .utils import ( + split_to_chunks, + pack_input, + pack_filter, + expand_spatial_dimensions, + add_pad, + bind_data_copy, + get_texture_storage, +) + + +@autotvm.register_topi_compute("conv2d_nhwc.image2d") +def conv2d_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype="float16"): + """Compute conv2d with NCHWc layout""" + args = {"shared": False, "accumulator": "float16"} + return compute_conv2d_NHWC_HWIO(data, kernel, strides, padding, dilation, out_dtype, args=args) + + +@autotvm.register_topi_compute("conv2d_nhwc_acc32.image2d") +def conv2d_nhwc_acc32(cfg, data, kernel, strides, padding, dilation, out_dtype="float16"): + """Compute conv2d with NCHWc layout""" + args = {"shared": False, "accumulator": "float32"} + return compute_conv2d_NHWC_HWIO(data, kernel, strides, padding, dilation, out_dtype, args=args) + +
+@autotvm.register_topi_schedule("conv2d_nhwc.image2d") +def schedule_conv2d_nhwc(cfg, outs): + return schedule_conv2d_nhwc_impl(cfg, outs, tag="cast_from_acc16") + + +@autotvm.register_topi_schedule("conv2d_nhwc_acc32.image2d") +def schedule_conv2d_nhwc_acc32(cfg, outs): + return schedule_conv2d_nhwc_impl(cfg, outs, tag="cast_from_acc32") + + +def schedule_conv2d_nhwc_impl(cfg, outs, tag): + """Create the schedule for conv2d_nhwc""" + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if op.tag == tag: + schedule_conv2d_NHWC(cfg, s, op.output(0)) + + traverse_inline(s, outs[0].op, _callback) + return s + + +def compute_conv2d_NHWC_HWIO(Input, Filter, stride, padding, dilation, out_dtype, args): + """ + Convolution operator in NHWC layout. + Algo: + 1. Convert into blocked format if we have 4d original tensor. + In case of AutoTVM we override the convert by just tensors since such conversion + will be absent for real blocked convolution, no sense to include into tuning + 2. Expand spatial dimensions to have width and height be dividable by factor 4 + This leads to slightly bigger amount of compute but allow utilize GPU much better + 3. Add paddings. This happens even if we do not need pad originaly. This is useful + due to work arounding of the gaps of texture annotation between Primary Functions + and limited support of textures in schedules. Later on this pad will be executed + separately and will produce texture + 4. 5d Convolution compute with accumulating into out_dtype + 5. Cast to the origin output data type + 6.
For case of 4d convolution: convert of output from 5d to 4d + """ + + if out_dtype is None: + out_dtype = Input.dtype + assert isinstance(stride, int) or len(stride) == 2 + assert isinstance(dilation, int) or len(dilation) == 2 + if isinstance(stride, int): + stride_h = stride_w = stride + else: + stride_h, stride_w = stride + + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + convert_from4d = False + if len(Input.shape) == 4: + batch, in_height, in_width, in_channels = Input.shape + kernel_h, kernel_w, in_filter_channels, out_channles = Filter.shape + + in_channel_chunks, in_channel_block, in_channel_tail = split_to_chunks(in_channels, 4) + out_channel_chunks, out_channel_block, out_channel_tail = split_to_chunks(out_channles, 4) + + if autotvm.GLOBAL_SCOPE.in_tuning: + dshape = (batch, in_height, in_width, in_channel_chunks, in_channel_block) + Input = tvm.te.placeholder(dshape, Input.dtype, name="data_placeholder") + kshape = (kernel_h, kernel_w, in_filter_channels, out_channel_chunks, out_channel_block) + Filter = tvm.te.placeholder(kshape, Filter.dtype, name="kernel_placeholder") + else: + convert_from4d = True + Input = pack_input( + Input, + "NHWC", + batch, + in_channel_chunks, + in_channel_block, + in_channel_tail, + in_height, + in_width, + ) + Filter = pack_filter( + Filter, + "HWIO", + out_channel_chunks, + out_channel_block, + out_channel_tail, + in_filter_channels, + in_channel_chunks, + in_channel_block, + in_channel_tail, + kernel_h, + kernel_w, + ) + + else: + batch, in_height, in_width, in_channel_chunks, in_channel_block = Input.shape + kernel_h, kernel_w, in_filter_channels, out_channel_chunks, out_channel_block = Filter.shape + + out_height_orig, out_height, out_width_orig, out_width = expand_spatial_dimensions( + in_height, in_width, kernel_h, kernel_w, dilation_h, dilation_w, padding, stride_h, stride_w + ) + + temp = add_pad( + Input, + "NHWC", + out_height_orig, + 
out_width_orig, + kernel_h, + kernel_w, + dilation_h, + dilation_w, + padding, + stride_h, + stride_w, + ) + + rcc = te.reduce_axis((0, in_channel_chunks), name="rcc") + rcb = te.reduce_axis((0, in_channel_block), name="rcb") + ry = te.reduce_axis((0, kernel_h), name="ry") + rx = te.reduce_axis((0, kernel_w), name="rx") + conv = te.compute( + (batch, out_height, out_width, out_channel_chunks, out_channel_block), + lambda nn, yy, xx, fc, fb: te.sum( + ( + temp[nn, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rcc, rcb] + * Filter[ry, rx, rcc * in_channel_block + rcb, fc, fb] + ).astype(args["accumulator"]), + axis=[ry, rx, rcc, rcb], + ), + tag="conv2d_nhwc", + ) + + if convert_from4d and not autotvm.GLOBAL_SCOPE.in_tuning: + dummy_cast = te.compute( + (batch, out_height_orig, out_width_orig, out_channel_chunks, out_channel_block), + lambda n, y, x, fc, fb: conv[n, y, x, fc, fb].astype(out_dtype), + tag="dummy_cast", + ) + return te.compute( + (batch, out_height_orig, out_width_orig, out_channles), + lambda n, y, x, c: dummy_cast[n, y, x, c // out_channel_block, c % out_channel_block], + tag="cast_from_acc" + args["accumulator"][-2:], + ) + else: + return te.compute( + (batch, out_height_orig, out_width_orig, out_channel_chunks, out_channel_block), + lambda n, y, x, ffc, ffb: conv[n, y, x, ffc, ffb].astype(out_dtype), + tag="cast_from_acc" + args["accumulator"][-2:], + ) + + +def schedule_conv2d_NHWC(cfg, s, output): + """ + schedule optimized for batch size = 1 + + Algo: + 1. Split output axis to three parts: global work size, vthread, local worksize. + The limitations for tuning includes heuristics from some tuned networks to limit + search space and not pay much time for useles configurations. + 2. In case of 4d convolution schedule copying of the input (and filter) into + 5d tensors + 4. pad should be scheduled separately to create independent opencl kernel. If pad is + inlined into convolution, this gives 1.5x performance drop + 5. 
We are using cache_read to produce texture and guarantee the best performance + on the next stage. + 6. For 5d convolution we schedule the latest op with binding 5d axis and vectorize + for textures + For 4d tensor we are doing the same for the latest blocked stage, i.e. conversion + of data type + 7. In case of 4d conv we need to schedule postops as well + """ + latest = s.outputs[0].output(0) + if len(latest.op.axis) == 4: + latest_blocked = dummy = output.op.input_tensors[0] + conv = dummy.op.input_tensors[0] + else: + conv = output.op.input_tensors[0] + latest_blocked = latest + + ##### space definition begin ##### + n, y, x, fc, fb = s[conv].op.axis + ry, rx, rcc, rcb = s[conv].op.reduce_axis + + if conv.shape[3] % 2 == 0: + min_threads_div = 2 + else: + min_threads_div = 1 + + cfg.define_split( + "tile_fc", + fc, + num_outputs=3, + filter=lambda entity: entity.size[1] <= 8 + and entity.size[2] >= min_threads_div + and entity.size[2] < 256, + ) + cfg.define_split( + "tile_y", + y, + num_outputs=3, + filter=lambda entity: entity.size[1] <= 8 and entity.size[2] <= 16, + ) + cfg.define_split( + "tile_x", + x, + num_outputs=3, + filter=lambda entity: entity.size[1] <= 8 and entity.size[2] <= 16, + ) + + cfg.define_split("tile_rcc", rcc, num_outputs=2) + cfg.define_split("tile_ry", ry, num_outputs=2) + cfg.define_split("tile_rx", rx, num_outputs=2) + cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) + cfg.define_knob("unroll_explicit", [0, 1]) + + pad_data, kernel = s[conv].op.input_tensors + if ( + isinstance(kernel.op, tvm.te.ComputeOp) and "filter_pack" in kernel.op.tag + ): # len(latest.op.axis) == 4: + # manage scheduling of datacopy + pad_data, kernel = s[conv].op.input_tensors + pack_data = pad_data.op.input_tensors[0] + bind_data_copy(s[pack_data]) + bind_data_copy(s[kernel]) + + pad_data, kernel = s[conv].op.input_tensors + + s[pad_data].compute_inline() Review Comment: Are you meaning to inline padding here? 
Your comment [above](https://github.com/apache/tvm/pull/11161/files#diff-b3830c92b31f9a929f0eccbb82bde0013f49737d9f98424641d91a5435a86bfcR210) implies that you intend to do otherwise. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
