ekalda commented on a change in pull request #8795:
URL: https://github.com/apache/tvm/pull/8795#discussion_r694228019



##########
File path: python/tvm/relay/backend/contrib/ethosu/te/convolution.py
##########
@@ -0,0 +1,199 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-argument
+"""Tensor Expressions for convolutions for the NPU"""
+from tvm import te
+from .dma import dma_ofm_compute, dma_ifm_compute
+
+
+def process_stride(stride):
+    """Process the striding into a common format.
+
+    Parameters
+    ----------
+    stride : Union[int, tuple, list]
+        The 2D striding.
+        int -> striding is the same in the height and width axis.
+        2D -> striding specified as (stride height, stride width).
+
+    Returns
+    -------
+    int
+        The stride in the height axis.
+    int
+        The stride in the width axis.
+
+    """
+    assert isinstance(stride, int) or len(stride) == 2
+    if isinstance(stride, int):
+        return stride, stride
+
+    return stride
+
+
+def process_dilation(dilation):
+    """Process the dilation into a common format.
+
+    Parameters
+    ----------
+    dilation : Union[int, tuple, list]
+        The 2D dilation.
+        int -> dilation is the same in the height and width axis.
+        2D -> dilation specified as (dilation height, dilation width).
+
+    Returns
+    -------
+    int
+        The dilation in the height axis.
+    int
+        The dilation in the width axis.
+
+    """
+    assert isinstance(dilation, int) or len(dilation) == 2
+    if isinstance(dilation, int):
+        return dilation, dilation
+
+    return dilation

Review comment:
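       It looks like these aren't used anywhere (conv2d_compute unpacks `strides` and `dilation` directly). If they are needed, could they be merged into one function, since they differ only in the parameter name?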

##########
File path: python/tvm/relay/backend/contrib/ethosu/te/convolution.py
##########
@@ -0,0 +1,199 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-argument
+"""Tensor Expressions for convolutions for the NPU"""
+from tvm import te
+from .dma import dma_ofm_compute, dma_ifm_compute
+
+
+def process_stride(stride):
+    """Process the striding into a common format.
+
+    Parameters
+    ----------
+    stride : Union[int, tuple, list]
+        The 2D striding.
+        int -> striding is the same in the height and width axis.
+        2D -> striding specified as (stride height, stride width).
+
+    Returns
+    -------
+    int
+        The stride in the height axis.
+    int
+        The stride in the width axis.
+
+    """
+    assert isinstance(stride, int) or len(stride) == 2
+    if isinstance(stride, int):
+        return stride, stride
+
+    return stride
+
+
+def process_dilation(dilation):
+    """Process the dilation into a common format.
+
+    Parameters
+    ----------
+    dilation : Union[int, tuple, list]
+        The 2D dilation.
+        int -> dilation is the same in the height and width axis.
+        2D -> dilation specified as (dilation height, dilation width).
+
+    Returns
+    -------
+    int
+        The dilation in the height axis.
+    int
+        The dilation in the width axis.
+
+    """
+    assert isinstance(dilation, int) or len(dilation) == 2
+    if isinstance(dilation, int):
+        return dilation, dilation
+
+    return dilation
+
+
+def conv2d_compute(
+    ifm,
+    weight,
+    scale_bias,
+    lut,
+    ifm_scale,
+    ifm_zero_point,
+    weight_zero_point,
+    ofm_scale,
+    ofm_zero_point,
+    strides,
+    padding,
+    dilation,
+    activation,
+    clip_min,
+    clip_max,
+    upscale,
+    ifm_layout,
+    ofm_layout,
+):
+    """A compute operator representing the capabilities of a 2D convolution 
for the NPU.
+
+    Parameters
+    ----------
+    ifm : te.Tensor
+        The Input Feature Map tensor (IFM).
+    weight : te.Tensor
+        The weight tensor.
+    scale_bias : te.Tensor
+        The packed per-channel weight scale and bias tensor.
+    lut : te.Tensor
+        The look-up table values to use if activation = "LUT".
+    ifm_scale : float
+        The quantization scale for the Input Feature Map tensor.
+    ifm_zero_point : int
+        The quantization zero point for the Input Feature Map tensor.
+    weight_zero_point : int
+        The quantization zero point for the weight tensor.
+    ofm_scale : float
+        The quantization scale for the Output Feature Map tensor.
+    ofm_zero_point : int
+        The quantization zero point for the Output Feature Map tensor.
+    strides : tuple
+        The 2 dimensional strides as (stride_height, stride_width).
+    padding : tuple
+        The 4 dimensional padding as (pad_top, pad_left, pad_bottom, 
pad_right).
+    dilation : Union[int, tuple, list]
+        The 2 dimensional dilation as (dilation_height, dilation_width).
+    activation : str
+        The activation function to use.
+            "NONE" - no activation function.
+            "CLIP" - clip the output between clip_min and clip_max.
+            "TANH" - tanh activation function.
+            "SIGMOID" - sigmoid activation function.
+            "LUT" - use a look-up table to perform the activation function.
+    clip_min : int
+        The minimum clipping value if activation = "CLIP".
+    clip_max : int
+        The maximum clipping value if activation = "CLIP".
+    upscale : str
+        The 2x2 upscaling mode to apply to the Input Feature Map tensor.
+            "NONE" - no upscaling.
+            "NEAREST" - upscale using nearest neighbour.
+            "ZEROS" - upscale using zeros.
+    ifm_layout : str
+        The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16".
+    ofm_layout : str
+        The layout of the Output Feature Map tensor. Can be "NHWC" or 
"NHCWB16".
+
+    Returns
+    -------
+    te.Tensor
+        The OFM tensor.
+
+    """
+    assert ifm.shape[0] == 1
+    assert ifm_layout in {"NHWC", "NHCWB16"}
+    assert ofm_layout in {"NHWC", "NHCWB16"}
+
+    stride_h, stride_w = strides
+    dilation_h, dilation_w = dilation
+    ofm_channels, kernel_h, kernel_w, ifm_channels = weight.shape
+
+    # Compute operation for the IFM DMA pipeline
+    dmaed_ifm = dma_ifm_compute(
+        ifm, ifm_layout, ifm_zero_point, ifm_scale, weight.shape[3], padding
+    )
+
+    # 2D Convolution compute operation
+    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
+    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1

Review comment:
       It seems to me that computing an output shape from the kernel size, strides and dilation is a very common operation, so would it be worth factoring it out into a separate function?

##########
File path: python/tvm/relay/backend/contrib/ethosu/util.py
##########
@@ -0,0 +1,198 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Helper utility Enums and Functions used through out codegen
+
+The enums are there to indicate which argument of each relay operator
+corresponds with which input.
+e.g., input zero point of qnn.conv2d is 4th argument(3rd index)
+
+The rest of the utility functions are misc.
+Refer to the description inside such functions
+"""
+
+from enum import Enum
+import numpy as np
+
+from tvm import relay
+from tvm.relay.build_module import bind_params_by_name
+from tvm.relay.backend.contrib.ethosu import preprocess
+
+
+class QConv2DArgs(Enum):
+    """
+    This is a helper enums to access the correct index
+    qnn conv2d arguments
+    """
+
+    ifm = 0
+    weights = 1
+    ifm_zero_point = 2
+    weights_zero_point = 3
+    ifm_scale = 4
+    weights_scale = 5
+
+
+class RequantArgs(Enum):
+    """
+    This is a helper enums to access the correct index
+    qnn requantize arguments
+    """
+
+    ifm_scale = 1
+    ifm_zero_point = 2
+    ofm_scale = 3
+    ofm_zero_point = 4
+
+
+class BiasAddArgs(Enum):
+    """
+    This is a helper enums to access the correct index
+    qnn bias_add arguments
+    """
+
+    biases = 1
+
+
+class ClipArgs(Enum):
+    """
+    This is a helper enums to access the correct index
+    qnn bias_add arguments
+    """
+
+    a_min = 1
+    a_max = 2
+
+
+class MaxPoolArgs(Enum):
+    """
+    This is a helper enums to access the correct index
+    max pool arguments
+    """
+
+    ifm = 0
+
+
+class AddArgs(Enum):
+    """This is a helper enums to access the correct index
+    max pool arguments
+    """
+
+    ifm0 = 0
+    ifm1 = 1
+    ifm0_scale = 2
+    ifm0_zero_point = 3
+    ifm1_scale = 4
+    ifm1_zero_point = 5
+    ofm_scale = 6
+    ofm_zero_point = 7

Review comment:
       It looks like this class isn't used in this commit. (Its docstring also says "max pool arguments", which looks like a copy-paste from MaxPoolArgs.)

##########
File path: python/tvm/relay/backend/contrib/ethosu/te/dma.py
##########
@@ -0,0 +1,299 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unnecessary-lambda
+"""Tensor Expressions for operations supported by the DMA engine"""
+import tvm
+from tvm import te
+from tvm.topi.utils import equal_const_int
+
+
+def _pad_tensor(tensor, pad_before, pad_after=None):
+    """Generate a padded tensor.
+
+    Parameters
+    ----------
+    tensor : te.Tensor
+        The tensor to pad.
+    pad_before : tuple of int
+        The 'before' padding on each axis.
+    pad_after : tuple of int
+        The 'after' padding on each axis.
+    Returns
+    -------
+    _pad : callable
+        The padded tensor.
+
+    """
+    pad_after = pad_after or pad_before
+    dims = len(tensor.shape)
+    assert len(pad_before) == dims
+    assert len(pad_after) == dims
+
+    def _pad(*indices):
+        not_zero = []
+        index_tuple = []

Review comment:
       I think it is worth documenting what these variables (`not_zero` and `index_tuple`) mean.

##########
File path: python/tvm/relay/backend/contrib/ethosu/util.py
##########
@@ -0,0 +1,198 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Helper utility Enums and Functions used through out codegen
+
+The enums are there to indicate which argument of each relay operator
+corresponds with which input.
+e.g., input zero point of qnn.conv2d is 4th argument(3rd index)
+
+The rest of the utility functions are misc.
+Refer to the description inside such functions
+"""
+
+from enum import Enum
+import numpy as np
+
+from tvm import relay
+from tvm.relay.build_module import bind_params_by_name
+from tvm.relay.backend.contrib.ethosu import preprocess
+
+
+class QConv2DArgs(Enum):
+    """
+    This is a helper enums to access the correct index
+    qnn conv2d arguments
+    """
+
+    ifm = 0
+    weights = 1
+    ifm_zero_point = 2
+    weights_zero_point = 3
+    ifm_scale = 4
+    weights_scale = 5
+
+
+class RequantArgs(Enum):
+    """
+    This is a helper enums to access the correct index
+    qnn requantize arguments
+    """
+
+    ifm_scale = 1
+    ifm_zero_point = 2
+    ofm_scale = 3
+    ofm_zero_point = 4
+
+
+class BiasAddArgs(Enum):
+    """
+    This is a helper enums to access the correct index
+    qnn bias_add arguments
+    """
+
+    biases = 1
+
+
+class ClipArgs(Enum):
+    """
+    This is a helper enums to access the correct index
+    qnn bias_add arguments

Review comment:
       This should say "clip" rather than "bias_add"; the docstring looks copied from BiasAddArgs.

##########
File path: python/tvm/relay/op/contrib/ethosu.py
##########
@@ -0,0 +1,251 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Arm(R) Ethos(TM)-U NPU supported operators."""
+import numpy as np
+
+from tvm.relay.expr import Constant
+from tvm.relay.op.contrib.register import register_pattern_table
+from tvm.relay.dataflow_pattern import wildcard, is_op, is_constant
+from tvm.relay.backend.contrib.ethosu.util import QConv2DArgs
+from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs
+from tvm.relay.backend.contrib.ethosu.util import RequantArgs
+from tvm.relay.backend.contrib.ethosu.util import get_dim_value
+from ethosu.vela import api as vapi
+
+
+def check_strides(strides):
+    """Checks whether strides are within the limits supported by the 
hardware"""
+    stride_range = (1, 3)
+    smin, smax = stride_range
+    if not smax >= strides[0] >= smin:
+        return False
+    if not smax >= strides[1] >= smin:
+        return False
+    return True
+
+
+def check_valid_dtypes(tensor_params):
+    """Check whether dtypes are supported by the hardware"""
+    supported_dtypes = (np.uint8, np.int8)
+    for tep in tensor_params:
+        # Check for dtypes
+        if np.dtype(tep.dtype) not in supported_dtypes:
+            return False
+        # Check for shape sizes
+        if any(dimlen > 65536 for dimlen in tep.shape):
+            return False
+    return True
+
+
+def check_weights(weights, dilation):
+    """Checks whether weight tensor is compatible with HW"""
+    dilated_height_range = (1, 64)
+    dilated_hxw_range = (1, 64 * 64)
+    weights_limit = 127 * 65536
+    dilated_width = (weights.shape[get_dim_value(weights.layout, "W")] - 1) * 
dilation[0] + 1
+    dilated_height = (weights.shape[get_dim_value(weights.layout, "H")] - 1) * 
dilation[1] + 1
+    dh_min, dh_max = dilated_height_range
+    if not dh_min <= dilated_height <= dh_max:
+        return False
+    dilated_hxw = dilated_height * dilated_width
+    dhxw_min, dhxw_max = dilated_hxw_range
+    if not dhxw_min <= dilated_hxw <= dhxw_max:
+        return False
+    # A saturation upper bound check for accumulators
+    weights.values = weights.values - weights.q_params.zero_point
+    axis = (
+        get_dim_value(weights.layout, "H"),
+        get_dim_value(weights.layout, "W"),
+        get_dim_value(weights.layout, "I"),
+    )
+    sum_weights = np.amax(np.sum(np.absolute(weights.values), axis=axis))
+    if not sum_weights <= weights_limit:
+        return False
+    return True
+
+
+def check_bias(bias):
+    """Check whether the bias values fit in 40 bits"""
+    if bias and bias.dtype == np.dtype("int64"):
+        valid = all(len(bin(bias_value)[2:]) <= 40 for bias_value in 
bias.values)
+        return valid
+    return True
+
+
+def check_batch_size(ifm):
+    """Checks for the number of batches vela currently supports"""
+    if ifm.shape[0] != 1:
+        return False
+    return True
+
+
+def check_dilation(dilation):
+    """Checks whether dilation is within the limits supported by the 
hardware"""
+    dilation_range = (1, 2)
+    dmin, dmax = dilation_range
+    if not dmin <= dilation[0] <= dmax:
+        return False
+    if not dmin <= dilation[1] <= dmax:
+        return False
+    return True
+
+
+def check_padding(padding, bounds):
+    """Checks whether padding is within the limits supported by the hardware"""
+    if len(padding) != 4 or len(bounds) != 4:
+        return False
+    top, left, bottom, right = padding
+    topb, leftb, bottomb, rightb = bounds
+    if top > topb or left > leftb or bottom > bottomb or right > rightb:
+        return False
+    return True
+
+
+class TensorParams:
+    """
+    This class will parse a tvm Expr along with quantization scale
+    and zero point to populate parameters that are required
+    for the creation of tensors in Vela.
+    """
+
+    def __init__(self, tensor, layout=None, scale=None, zero_point=None):
+        self.tensor = tensor
+        if isinstance(tensor, Constant):
+            self.values = tensor.data.asnumpy()
+        else:
+            self.values = None
+        self.dtype = tensor.checked_type.dtype
+        self.shape = [int(i) for i in tensor.checked_type.shape]
+        self.layout = layout
+
+        if scale is not None and zero_point is not None:
+            self.q_params = vapi.NpuQuantization(
+                scale.data.asnumpy().astype("float32"), 
zero_point.data.asnumpy().astype(self.dtype)
+            )
+        else:
+            # put default values
+            self.q_params = vapi.NpuQuantization(1.0, 0)
+
+
+class QnnConv2DParams:
+    """
+    This class will parse a Call to a ethosu.qnn_conv2d_clip composite function

Review comment:
       Shouldn't this just be ethosu.qnn_conv2d?

##########
File path: python/tvm/relay/backend/contrib/ethosu/te/dma.py
##########
@@ -0,0 +1,299 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unnecessary-lambda
+"""Tensor Expressions for operations supported by the DMA engine"""
+import tvm
+from tvm import te
+from tvm.topi.utils import equal_const_int
+
+
+def _pad_tensor(tensor, pad_before, pad_after=None):
+    """Generate a padded tensor.
+
+    Parameters
+    ----------
+    tensor : te.Tensor
+        The tensor to pad.
+    pad_before : tuple of int
+        The 'before' padding on each axis.
+    pad_after : tuple of int
+        The 'after' padding on each axis.
+    Returns
+    -------
+    _pad : callable
+        The padded tensor.
+
+    """
+    pad_after = pad_after or pad_before
+    dims = len(tensor.shape)
+    assert len(pad_before) == dims
+    assert len(pad_after) == dims
+
+    def _pad(*indices):
+        not_zero = []
+        index_tuple = []
+        for i in range(dims):
+            if equal_const_int(pad_before[i], 0) and 
equal_const_int(pad_after[i], 0):
+                index_tuple.append(indices[i])
+            else:
+                index_tuple.append(indices[i] - pad_before[i])
+                not_zero.append(indices[i] >= pad_before[i])
+                not_zero.append(indices[i] < tensor.shape[i] + pad_before[i])
+        if not_zero:
+            not_zero = tvm.tir.all(*not_zero)
+            return tvm.tir.if_then_else(not_zero, tensor(*index_tuple), 
tvm.tir.const(0, "uint8"))
+        return tensor(*index_tuple)
+
+    return _pad
+
+
+def read_compute(tensor, layout, zero_point, scale):
+    """A TE compute operator to represent a read.
+
+    Parameters
+    ----------
+    tensor : te.Tensor
+        The tensor to read.
+    layout : str
+        The layout of the tensor, either NHWC or NHCWB16.
+    zero_point : int
+        The zero point of the tensor.
+    scale : float
+        The scale of the tensor.
+
+    Returns
+    -------
+    te.Tensor
+        The tensor having been read.
+
+    """
+    assert layout in {"NHWC", "NHCWB16"}
+    read_attrs = {
+        "op": "ethosu_read",
+        "layout": layout,
+        "zero_point": zero_point,
+        "scale": scale,
+    }
+    return te.compute(tensor.shape, lambda *i: tensor(*i), name="ethosu_read", 
attrs=read_attrs)
+
+
+def write_compute(tensor, layout, zero_point, scale):
+    """A TE compute operator to represent a write.
+
+    Parameters
+    ----------
+    tensor : te.Tensor
+        The tensor to write.
+    layout : str
+        The layout of the tensor, either NHWC or NHCWB16.
+    zero_point : int
+        The zero point of the tensor.
+    scale : float
+        The scale of the tensor.
+
+    Returns
+    -------
+    te.Tensor
+        The tensor having been written.
+
+    """
+    assert layout in {"NHWC", "NHCWB16"}
+    write_attrs = {
+        "op": "ethosu_write",
+        "layout": layout,
+        "zero_point": zero_point,
+        "scale": scale,
+    }
+    return te.compute(
+        tensor.shape,
+        lambda *i: tensor(*i),
+        name="ethosu_write",
+        attrs=write_attrs,
+    )
+
+
+def convert_to_nhwc_compute(tensor, layout, channels):
+    """Converts a tensor into NHWC layout if it's in NHWCB16 layout.
+
+    Parameters
+    ----------
+    tensor : te.Tensor
+        The tensor to convert.
+    layout : str
+        The layout of the tensor, either NHWC or NHCWB16.
+    channels : int
+        The number of valid channels for the tensor.
+
+    Returns
+    -------
+    te.Tensor
+        The converted tensor in NHWC layout.
+
+    """
+    assert layout in {"NHWC", "NHCWB16"}
+    convert_to_nhwc_attrs = {
+        "op": "ethosu_convert_to_nhwc",
+        "layout": layout,
+    }
+    if layout == "NHCWB16":
+        return te.compute(
+            (tensor.shape[0], tensor.shape[1], tensor.shape[3], channels),
+            lambda nn, hh, ww, cc: tensor(nn, hh, te.indexdiv(cc, 16), ww, 
te.indexmod(cc, 16)),
+            name="ethosu_convert_to_nhwc",
+            attrs=convert_to_nhwc_attrs,
+        )
+
+    return te.compute(
+        tensor.shape,
+        lambda *i: tensor(*i),
+        name="ethosu_convert_to_nhwc",
+        attrs=convert_to_nhwc_attrs,
+    )
+
+
+def convert_to_nhcwb16_compute(tensor, layout, channels):
+    """Converts a tensor into NHCWB16 layout if it's in NHWC layout.
+
+    Parameters
+    ----------
+    tensor : te.Tensor
+        The tensor to convert.
+    layout : str
+        The layout of the tensor, either NHWC or NHCWB16.
+    channels : int
+        The number of valid channels for the tensor.
+
+    Returns
+    -------
+    te.Tensor
+        The converted tensor in NHCWB16 layout.
+
+    """
+    assert layout in {"NHWC", "NHCWB16"}
+    convert_to_nhcwb16_attrs = {
+        "op": "ethosu_convert_to_nhcwb16",
+        "layout": layout,
+    }
+    if layout == "NHCWB16":
+        out_channel_bricks = te.indexdiv(channels - 1, 16) + 1
+        output_shape = (1, tensor.shape[1], out_channel_bricks, 
tensor.shape[2], 16)
+        return te.compute(
+            output_shape,
+            lambda nn, hh, cc, ww, cb: tvm.tir.if_then_else(
+                cc * 16 + cb < channels,
+                tensor(nn, hh, ww, cc * 16 + cb),
+                tvm.tir.IntImm(tensor.dtype, 0),
+            ),
+            name="ethosu_convert_to_nhcwb16",
+            attrs=convert_to_nhcwb16_attrs,
+        )
+
+    return te.compute(
+        tensor.shape,
+        lambda *i: tensor(*i),
+        name="ethosu_convert_to_nhcwb16",
+        attrs=convert_to_nhcwb16_attrs,
+    )
+
+
+def pad_compute(tensor, padding):
+    """Pad an NHWC tensor in the height and width axes.
+
+    Parameters
+    ----------
+    tensor : te.Tensor
+        The tensor to pad.
+    padding : tuple
+        The 4 dimensional padding as (pad_top, pad_left, pad_bottom, 
pad_right).
+
+    Returns
+    -------
+    te.Tensor
+        The padded tensor.
+
+    """
+    pad_top, pad_left, pad_down, pad_right = padding
+    pad_before = [0, pad_top, pad_left, 0]
+    pad_after = [0, pad_down, pad_right, 0]
+    pad_attrs = {
+        "op": "ethosu_pad",
+    }
+    shape = tensor.shape
+    return te.compute(
+        (shape[0], shape[1] + pad_top + pad_down, shape[2] + pad_left + 
pad_right, shape[3]),
+        lambda nn, hh, ww, cc: _pad_tensor(tensor, pad_before, pad_after)(nn, 
hh, ww, cc),
+        name="ethosu_pad",
+        attrs=pad_attrs,
+    )
+
+
+def dma_ifm_compute(ifm, layout, zero_point, scale, channels, padding):
+    """A sequence of compute operators representing the DMA capabilities for 
an IFM.
+
+    Parameters
+    ----------
+    ifm : te.Tensor
+        The Input Feature Map (IFM) tensor.
+    layout : str
+        The layout of the data, either NHWC or NHCWB16.
+    zero_point : int
+        The zero point of the data.
+    scale : float
+        The scale of the data.
+    channels : int
+        The number of valid channels for the data.
+    padding : Union[int, tuple, list]
+        The desired padding.
+        int -> padding applied to both height and width axes.
+        2D -> padding applied equally on both sides of the (height, width) 
axes.
+        4D -> padding applied as (top, left, bottom, right)

Review comment:
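       It looks like pad_compute only handles 4D padding, though: it unpacks `padding` into exactly four values, so the int and 2D forms documented here would fail if passed straight through.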




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to