manupa-arm commented on a change in pull request #8795: URL: https://github.com/apache/tvm/pull/8795#discussion_r695571625
########## File path: python/tvm/relay/op/contrib/ethosu.py ########## @@ -0,0 +1,251 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Arm(R) Ethos(TM)-U NPU supported operators.""" +import numpy as np + +from tvm.relay.expr import Constant +from tvm.relay.op.contrib.register import register_pattern_table +from tvm.relay.dataflow_pattern import wildcard, is_op, is_constant +from tvm.relay.backend.contrib.ethosu.util import QConv2DArgs +from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs +from tvm.relay.backend.contrib.ethosu.util import RequantArgs +from tvm.relay.backend.contrib.ethosu.util import get_dim_value +from ethosu.vela import api as vapi + + +def check_strides(strides): + """Checks whether strides are within the limits supported by the hardware""" + stride_range = (1, 3) + smin, smax = stride_range + if not smax >= strides[0] >= smin: + return False + if not smax >= strides[1] >= smin: + return False + return True + + +def check_valid_dtypes(tensor_params): + """Check whether dtypes are supported by the hardware""" + supported_dtypes = (np.uint8, np.int8) + for tep in tensor_params: + # Check for dtypes + if np.dtype(tep.dtype) not in supported_dtypes: + return False + # Check for shape sizes + if any(dimlen > 65536 for dimlen in tep.shape): + return False + return True + + +def check_weights(weights, dilation): + """Checks whether weight tensor is compatible with HW""" + dilated_height_range = (1, 64) + dilated_hxw_range = (1, 64 * 64) + weights_limit = 127 * 65536 + dilated_width = (weights.shape[get_dim_value(weights.layout, "W")] - 1) * dilation[0] + 1 + dilated_height = (weights.shape[get_dim_value(weights.layout, "H")] - 1) * dilation[1] + 1 + dh_min, dh_max = dilated_height_range + if not dh_min <= dilated_height <= dh_max: + return False + dilated_hxw = dilated_height * dilated_width + dhxw_min, dhxw_max = dilated_hxw_range + if not dhxw_min <= dilated_hxw <= dhxw_max: + return False + # A saturation upper bound check for accumulators + weights.values = weights.values - weights.q_params.zero_point + axis = ( + get_dim_value(weights.layout, "H"), + get_dim_value(weights.layout, "W"), + get_dim_value(weights.layout, "I"), + ) + sum_weights = np.amax(np.sum(np.absolute(weights.values), axis=axis)) + if not sum_weights <= weights_limit: + return False + return True + + +def check_bias(bias): + """Check whether the bias values fit in 40 bits""" + if bias and bias.dtype == np.dtype("int64"): + valid = all(len(bin(bias_value)[2:]) <= 40 for bias_value in bias.values) + return valid + return True + + +def check_batch_size(ifm): + """Checks for the number of batches vela currently supports""" + if ifm.shape[0] != 1: + return False + return True + + +def check_dilation(dilation): + """Checks whether dilation is within the limits supported by the hardware""" + dilation_range = (1, 2) + dmin, dmax = dilation_range + if not dmin <= dilation[0] <= dmax: + return False + if not dmin <= dilation[1] <= dmax: + return False + return True + + +def check_padding(padding, bounds): + """Checks whether padding is within the limits supported by the hardware""" + if len(padding) != 4 or len(bounds) != 4: + return False + top, left, bottom, right = padding + topb, leftb, bottomb, rightb = bounds + if top > topb or left > leftb or bottom > bottomb or right > rightb: + return False + return True + + +class TensorParams: + """ + This class will parse a tvm Expr along with quantization scale + and zero point to populate parameters that are required + for the creation of tensors in Vela. + """ + + def __init__(self, tensor, layout=None, scale=None, zero_point=None): + self.tensor = tensor + if isinstance(tensor, Constant): + self.values = tensor.data.asnumpy() + else: + self.values = None + self.dtype = tensor.checked_type.dtype + self.shape = [int(i) for i in tensor.checked_type.shape] + self.layout = layout + + if scale is not None and zero_point is not None: + self.q_params = vapi.NpuQuantization( + scale.data.asnumpy().astype("float32"), zero_point.data.asnumpy().astype(self.dtype) + ) + else: + # put default values + self.q_params = vapi.NpuQuantization(1.0, 0) + + +class QnnConv2DParams: + """ + This class will parse a Call to a ethosu.qnn_conv2d_clip composite function Review comment: Yes, and corrected -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
