elvin-n commented on code in PR #11161: URL: https://github.com/apache/tvm/pull/11161#discussion_r865391568
########## python/tvm/topi/adreno/utils.py: ########## @@ -0,0 +1,545 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-variable,unused-argument,no-else-return +"""util functions to be reused in different compute/schedule on Qualcomm Adreno GPU""" + +import tvm +import numpy +from tvm import te +from tvm.topi.utils import simplify +from tvm.topi import nn +from ..utils import get_const_tuple + + +def getDiv(value, start): + """Returns the maximum divider for `value` starting from `start` value""" + div = 1 + for d in range(start, 0, -1): + if (value % d) == 0: + div = d + break + return div + + +def split_to_chunks(trip_count, block): + """ + Splits the trip count value to chunks and block, returns the remainder as well + the chunks and blocks covers or overlaps the origin trip_count + + If trip_count can be divisible by block: + trip_count = chunks * block + else + trip_count = (chunks - 1) * block + tail + + Parameters + ---------- + trip_count: int + tripcount for original compute + + block: int + size of the block + + Returns + ---------- + out: tuple of the (chunks, block, tail) + """ + tail = trip_count % 4 + chunks = trip_count // 4 + if tail == 0: + tail = 4 + else: + chunks += 1 + 
return chunks, block, tail + + +def pack_input( + Input, layout, batch, in_channel_chunks, in_channel_block, in_channel_tail, in_height, in_width +): + """ + Adds compute stages for packing of the data in runtime. Extends channel dimensions + to be dividable by factor 4 + + Parameters + ---------- + Input: tvm.te.Tensor + Input tensor to be repacked in runtime + + layout: string + Layout of origin 4d tensor + NCHW or NHWC are acceptable + + batch: int + Batch size + + in_channel_chunks: int + Number of channel chunks been in the final tensor + + in_channel_block: int + Number of channel blocks been in the final tensor + + in_channel_tail: int + Tail in the latest chunk diffing original number of channels vs blocked one + If in_channel_tail != in_channel_block: + original_channels = in_channel_chunks * in_channel_block - in_channel_tail + else + original_channels = in_channel_chunks * in_channel_block + + in_height: int + Height of the feature map + + in_width: int + Width of the feature map + """ + + pad_value = tvm.tir.const(0, Input.dtype) + + def _reorder_data_nchw(*indices): + condition = [] + condition.append(indices[1] == in_channel_chunks - 1) + condition.append(indices[4] >= in_channel_tail) + condition = tvm.tir.all(*condition) + return tvm.tir.if_then_else( + condition, + pad_value, + Input[indices[0], indices[1] * in_channel_block + indices[4], indices[2], indices[3]], + ) + + def _reorder_data_nhwc(*indices): + condition = [] + condition.append(indices[3] == in_channel_chunks - 1) + condition.append(indices[4] >= in_channel_tail) + condition = tvm.tir.all(*condition) + return tvm.tir.if_then_else( + condition, + pad_value, + Input[indices[0], indices[1], indices[2], indices[3] * in_channel_block + indices[4]], + ) + + # compute: + if layout == "NCHW": + reordered_data = te.compute( + [batch, in_channel_chunks, in_height, in_width, in_channel_block], + _reorder_data_nchw, + name="input_pack", + tag="input_pack", + ) + elif layout == "NHWC": + 
reordered_data = te.compute( + [batch, in_height, in_width, in_channel_chunks, in_channel_block], + _reorder_data_nhwc, + name="input_pack", + tag="input_pack", + ) + else: + assert False, "Adreno util function pack_input does not accept unknown layout" + return reordered_data + + +def pack_filter( + Filter, + layout, + out_channel_chunks, + out_channel_block, + out_channel_tail, + in_filter_channels, + in_data_channel_chunks, + in_data_channel_block, + in_data_channel_tail, + kernel_h, + kernel_w, +): + """ + Adds compute stages for packing of the filter in runtime. Extends channels dimensions + to be dividable by factor 4 + + Parameters + ---------- + Filter: tvm.te.Tensor + Filter tensor to be repacked in runtime + + layout: string + Layout of origin 4d tensor + NCHW or NHWC are acceptable + + out_channel_chunks: int + Number of chunks for filters + + out_channel_block: int + Size of the block + + out_channel_tail: int + Original size of the latest chunk of output filters + + in_filter_channels: int + Number of filter channels. 
might be different vs input channels in the + data due to groups/depthwise nature + + in_data_channel_chunks: int + Number of chunks by channels for input data + + in_data_channel_block: int + Size of the block for input data channels + + in_data_channel_tail + Original size of the latest chunk for input data channels + + kernel_h: int + Height of the conv2d kernel + + kernel_w: int + Width of the conv2d kernel + """ + pad_value = tvm.tir.const(0, Filter.dtype) + + def _reorder_weights_depthwise_oihw(*indices): + conditionA = [] + conditionA.append(indices[0] == out_channel_chunks - 1) + conditionA.append(indices[4] >= out_channel_tail) + conditionAT = tvm.tir.all(*conditionA) + + return tvm.tir.if_then_else( + conditionAT, + pad_value, + Filter[indices[0] * out_channel_block + indices[4], indices[1], indices[2], indices[3]], + ) + + def _reorder_weights_depthwise_hwoi(*indices): + conditionA = [] + conditionA.append(indices[2] == out_channel_chunks - 1) + conditionA.append(indices[4] >= out_channel_tail) + conditionAT = tvm.tir.all(*conditionA) + + return tvm.tir.if_then_else( + conditionAT, + pad_value, + Filter[indices[0], indices[1], indices[2] * out_channel_block + indices[4], indices[3]], + ) + + def _reorder_weights_oihw(*indices): + conditionA = [] + conditionA.append(indices[0] == out_channel_chunks - 1) + conditionA.append(indices[4] >= out_channel_tail) + conditionAT = tvm.tir.all(*conditionA) + + conditionO = [] + conditionO.append(conditionAT) + conditionO.append( + indices[1] >= in_data_channel_chunks * in_data_channel_block + in_data_channel_tail + ) + conditionOT = tvm.tir.any(*conditionO) + return tvm.tir.if_then_else( + conditionOT, + pad_value, + Filter[indices[0] * out_channel_block + indices[4], indices[1], indices[2], indices[3]], + ) + + def _reorder_weights_hwio(*indices): + conditionA = [] + conditionA.append(indices[3] == out_channel_chunks - 1) + conditionA.append(indices[4] >= out_channel_tail) + conditionAT = tvm.tir.all(*conditionA) + 
+ conditionO = [] + conditionO.append(conditionAT) + conditionO.append( + indices[2] >= in_data_channel_chunks * in_data_channel_block + in_data_channel_tail + ) + conditionOT = tvm.tir.any(*conditionO) + return tvm.tir.if_then_else( + conditionOT, + pad_value, + Filter[indices[0], indices[1], indices[2], indices[3] * out_channel_block + indices[4]], + ) + + if in_filter_channels == 1: + if layout == "OIHW": + reordered_filter = te.compute( + [out_channel_chunks, in_filter_channels, kernel_h, kernel_w, out_channel_block], + _reorder_weights_depthwise_oihw, + name="filter_pack", + tag="filter_pack", + ) + elif layout == "HWOI": + reordered_filter = te.compute( + [kernel_h, kernel_w, out_channel_chunks, in_filter_channels, out_channel_block], + _reorder_weights_depthwise_hwoi, + name="filter_pack", + tag="filter_pack", + ) + else: + assert False, "Adreno util function def pack_filter does not accept unknown layout" + else: + if layout == "OIHW": + reordered_filter = te.compute( + [out_channel_chunks, in_filter_channels, kernel_h, kernel_w, out_channel_block], + _reorder_weights_oihw, + name="filter_pack", + tag="filter_pack", + ) + elif layout == "HWIO": + reordered_filter = te.compute( + [kernel_h, kernel_w, in_filter_channels, out_channel_chunks, out_channel_block], + _reorder_weights_hwio, + name="filter_pack", + tag="filter_pack", + ) + else: + assert False, "Adreno util function def pack_filter does not accept unknown layout" + return reordered_filter + + +def expand_spatial_dimensions( + in_height, in_width, kernel_h, kernel_w, dilation_h, dilation_w, padding, stride_h, stride_w +): + """ + Expands spatial dimensions to be dividable by factor 4. This will allow us to do extrimely + better parallel computation on GPU. The drawback of this solution - it will be number of + useless computations. 
In fact, the speed-up of parallelism significantly overcomes the slowdown + of extra compute and eventually this is a useful approach, at least for GPU + + Parameters + ---------- + in_height: int + Height of the feature map + + in_width: int + Width of the feature map Review Comment: done -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
