masahi commented on code in PR #15057:
URL: https://github.com/apache/tvm/pull/15057#discussion_r1223459120


##########
src/relay/qnn/op/avg_pool2d.cc:
##########
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/qnn/op/Avg_pool2d.cc
+ * \brief Property def of qnn Avg_pool2d operator.

Review Comment:
   Odd capital `A` in two lines.



##########
src/relay/qnn/op/avg_pool2d.cc:
##########
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/qnn/op/Avg_pool2d.cc
+ * \brief Property def of qnn Avg_pool2d operator.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/base.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/qnn/attrs.h>
+#include <tvm/tir/expr.h>
+
+#include "../../op/nn/nn.h"
+#include "../../op/nn/pooling.h"
+#include "../../op/nn/pooling_common.h"
+#include "../../op/tensor/transform.h"
+#include "../../transforms/infer_layout_utils.h"
+#include "../../transforms/pattern_utils.h"
+#include "../utils.h"
+#include "op_common.h"
+
+namespace tvm {
+namespace relay {
+namespace qnn {
+
+// relay.op.qnn.avg_pool2d
+bool QnnAvgPool2DRel(const Array<Type>& types, int num_inputs, const Attrs& 
attrs,
+                     const TypeReporter& reporter) {
+  // Expected Types: data, input_zero_point, input_scale, output_zero_point, 
output_scale
+  // out_type
+
+  ICHECK_EQ(types.size(), 6);
+
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+  ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8))
+      << "Expected quantized avg_pool2d type(int8, uint8) for input but was " 
<< data->dtype;
+
+  const auto* param = attrs.as<AvgPool2DAttrs>();
+  ICHECK(param != nullptr) << "AvgPool2DAttrs cannot be nullptr.";
+
+  // Check the types of scale and zero points.
+  for (size_t i = 1; i < 5; ++i) {
+    if (types[i].as<IncompleteTypeNode>()) {
+      return false;
+    }
+  }
+
+  ICHECK(IsScalarType(types[1], DataType::Float(32)));  // input_scale
+  ICHECK(IsScalarType(types[2], DataType::Int(32)));    // input_zero_point
+  ICHECK(IsScalarType(types[3], DataType::Float(32)));  // output_scale
+  ICHECK(IsScalarType(types[4], DataType::Int(32)));    // output_zero_point
+
+  // Find the output shape and data type
+  const auto dshape = data->shape;
+  ICHECK_GE(dshape.size(), 2U)
+      << "Pool2D only support input >= 2-D: input must have height and width";
+
+  // Check input and output layout
+  Layout layout(param->layout);
+  // The Layout is always NHWC
+  ICHECK(layout.Contains(LayoutAxis::Get('H')) && 
layout.Contains(LayoutAxis::Get('W')) &&
+         !layout.Contains(LayoutAxis::Get('h')) && 
!layout.Contains(LayoutAxis::Get('w')))
+      << "Invalid input layout " << layout
+      << ". qnn_avg_pool2d inut layout must have H and W, which cannot be 
split";
+
+  // Find the output shape and data type
+  const auto hidx = layout.IndexOf(LayoutAxis::Get('H'));
+  const auto widx = layout.IndexOf(LayoutAxis::Get('W'));
+
+  IndexExpr pad_h, pad_w;
+  if (param->padding.size() == 1) {
+    pad_h = param->padding[0] * 2;
+    pad_w = param->padding[0] * 2;
+  } else if (param->padding.size() == 2) {
+    // (top, left)
+    pad_h = param->padding[0] * 2;
+    pad_w = param->padding[1] * 2;
+  } else if (param->padding.size() == 4) {
+    // (top, left, bottom, right)
+    pad_h = param->padding[0] + param->padding[2];
+    pad_w = param->padding[1] + param->padding[3];
+  } else {
+    return false;
+  }
+
+  std::vector<IndexExpr> oshape(dshape.begin(), dshape.end());
+  if (dshape[hidx].as<tir::AnyNode>()) {
+    oshape[hidx] = dshape[hidx];
+  } else {
+    oshape[hidx] =
+        calculate_pool_dimension(dshape[hidx], pad_h, param->pool_size[0], 
param->dilation[0],
+                                 param->strides[0], param->ceil_mode);
+  }
+  if (dshape[widx].as<tir::AnyNode>()) {
+    oshape[widx] = dshape[widx];
+  } else {
+    oshape[widx] =
+        calculate_pool_dimension(dshape[widx], pad_w, param->pool_size[1], 
param->dilation[1],
+                                 param->strides[1], param->ceil_mode);
+  }
+
+  // assign output type
+  reporter->Assign(types[5], TensorType(oshape, data->dtype));
+  return true;
+}
+
+InferCorrectLayoutOutput QnnAvgPoolInferCorrectLayout(const Attrs& attrs,
+                                                      const Array<Layout>& 
new_in_layouts,
+                                                      const Array<Layout>& 
old_in_layouts,
+                                                      const 
Array<tvm::relay::Type>& old_in_types) {
+  // Use Relay AvgPool2D Infer correct layout.
+  auto avgpool_new_layouts =
+      PoolInferCorrectLayout<AvgPool2DAttrs>(attrs, new_in_layouts, 
old_in_layouts, old_in_types);
+
+  // Fill the layouts of remaining input tensors - scales and zero points. The 
layouts of these
+  // tensors can be treated as channel layout.

Review Comment:
   I think zero point should be a scalar - what happens if we set its layout to 
"undef"?



##########
src/relay/op/nn/pooling_common.h:
##########
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/op/nn/pooling_common.h
+ * \brief Properties def of pooling operator for sharing.

Review Comment:
   Suggested wording: "Common functions for pooling operator definition."



##########
src/relay/qnn/op/avg_pool2d.cc:
##########
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/qnn/op/Avg_pool2d.cc
+ * \brief Property def of qnn Avg_pool2d operator.
+ */

Review Comment:
   I don't like "Property def". I see it is used in other files but let's not 
repeat that.



##########
src/relay/qnn/op/avg_pool2d.cc:
##########
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/qnn/op/Avg_pool2d.cc
+ * \brief Property def of qnn Avg_pool2d operator.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/base.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/qnn/attrs.h>
+#include <tvm/tir/expr.h>
+
+#include "../../op/nn/nn.h"
+#include "../../op/nn/pooling.h"
+#include "../../op/nn/pooling_common.h"
+#include "../../op/tensor/transform.h"
+#include "../../transforms/infer_layout_utils.h"
+#include "../../transforms/pattern_utils.h"
+#include "../utils.h"
+#include "op_common.h"
+
+namespace tvm {
+namespace relay {
+namespace qnn {
+
+// relay.op.qnn.avg_pool2d
+bool QnnAvgPool2DRel(const Array<Type>& types, int num_inputs, const Attrs& 
attrs,
+                     const TypeReporter& reporter) {
+  // Expected Types: data, input_zero_point, input_scale, output_zero_point, 
output_scale
+  // out_type
+
+  ICHECK_EQ(types.size(), 6);
+
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+  ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8))
+      << "Expected quantized avg_pool2d type(int8, uint8) for input but was " 
<< data->dtype;
+
+  const auto* param = attrs.as<AvgPool2DAttrs>();
+  ICHECK(param != nullptr) << "AvgPool2DAttrs cannot be nullptr.";
+
+  // Check the types of scale and zero points.
+  for (size_t i = 1; i < 5; ++i) {
+    if (types[i].as<IncompleteTypeNode>()) {
+      return false;
+    }
+  }
+
+  ICHECK(IsScalarType(types[1], DataType::Float(32)));  // input_scale
+  ICHECK(IsScalarType(types[2], DataType::Int(32)));    // input_zero_point
+  ICHECK(IsScalarType(types[3], DataType::Float(32)));  // output_scale
+  ICHECK(IsScalarType(types[4], DataType::Int(32)));    // output_zero_point
+
+  // Find the output shape and data type
+  const auto dshape = data->shape;
+  ICHECK_GE(dshape.size(), 2U)
+      << "Pool2D only support input >= 2-D: input must have height and width";
+
+  // Check input and output layout
+  Layout layout(param->layout);
+  // The Layout is always NHWC
+  ICHECK(layout.Contains(LayoutAxis::Get('H')) && 
layout.Contains(LayoutAxis::Get('W')) &&
+         !layout.Contains(LayoutAxis::Get('h')) && 
!layout.Contains(LayoutAxis::Get('w')))
+      << "Invalid input layout " << layout
+      << ". qnn_avg_pool2d inut layout must have H and W, which cannot be 
split";
+
+  // Find the output shape and data type
+  const auto hidx = layout.IndexOf(LayoutAxis::Get('H'));
+  const auto widx = layout.IndexOf(LayoutAxis::Get('W'));
+
+  IndexExpr pad_h, pad_w;
+  if (param->padding.size() == 1) {
+    pad_h = param->padding[0] * 2;
+    pad_w = param->padding[0] * 2;
+  } else if (param->padding.size() == 2) {
+    // (top, left)
+    pad_h = param->padding[0] * 2;
+    pad_w = param->padding[1] * 2;
+  } else if (param->padding.size() == 4) {
+    // (top, left, bottom, right)
+    pad_h = param->padding[0] + param->padding[2];
+    pad_w = param->padding[1] + param->padding[3];
+  } else {
+    return false;
+  }
+
+  std::vector<IndexExpr> oshape(dshape.begin(), dshape.end());
+  if (dshape[hidx].as<tir::AnyNode>()) {
+    oshape[hidx] = dshape[hidx];
+  } else {
+    oshape[hidx] =
+        calculate_pool_dimension(dshape[hidx], pad_h, param->pool_size[0], 
param->dilation[0],
+                                 param->strides[0], param->ceil_mode);
+  }
+  if (dshape[widx].as<tir::AnyNode>()) {
+    oshape[widx] = dshape[widx];
+  } else {
+    oshape[widx] =
+        calculate_pool_dimension(dshape[widx], pad_w, param->pool_size[1], 
param->dilation[1],
+                                 param->strides[1], param->ceil_mode);
+  }
+
+  // assign output type
+  reporter->Assign(types[5], TensorType(oshape, data->dtype));
+  return true;
+}
+
+InferCorrectLayoutOutput QnnAvgPoolInferCorrectLayout(const Attrs& attrs,
+                                                      const Array<Layout>& 
new_in_layouts,
+                                                      const Array<Layout>& 
old_in_layouts,
+                                                      const 
Array<tvm::relay::Type>& old_in_types) {
+  // Use Relay AvgPool2D Infer correct layout.
+  auto avgpool_new_layouts =
+      PoolInferCorrectLayout<AvgPool2DAttrs>(attrs, new_in_layouts, 
old_in_layouts, old_in_types);
+
+  // Fill the layouts of remaining input tensors - scales and zero points. The 
layouts of these
+  // tensors can be treated as channel layout.
+  Layout channel_layout = Layout("C");
+  Array<Layout> input_layouts = {avgpool_new_layouts->input_layouts[0], 
channel_layout,
+                                 channel_layout, channel_layout, 
channel_layout};
+  Array<Layout> output_layouts = avgpool_new_layouts->output_layouts;
+  return InferCorrectLayoutOutput(input_layouts, output_layouts, attrs);
+}
+
+/*
+ * \brief Forward rewrite the qnn Avg_pool2d op.
+ * \param attrs The QNN Avg_pool2d attrs.

Review Comment:
   Odd capital `A` in two lines.



##########
src/relay/qnn/op/avg_pool2d.cc:
##########
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/qnn/op/Avg_pool2d.cc
+ * \brief Property def of qnn Avg_pool2d operator.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/base.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/qnn/attrs.h>
+#include <tvm/tir/expr.h>
+
+#include "../../op/nn/nn.h"
+#include "../../op/nn/pooling.h"
+#include "../../op/nn/pooling_common.h"
+#include "../../op/tensor/transform.h"
+#include "../../transforms/infer_layout_utils.h"
+#include "../../transforms/pattern_utils.h"
+#include "../utils.h"
+#include "op_common.h"
+
+namespace tvm {
+namespace relay {
+namespace qnn {
+
+// relay.op.qnn.avg_pool2d
+bool QnnAvgPool2DRel(const Array<Type>& types, int num_inputs, const Attrs& 
attrs,
+                     const TypeReporter& reporter) {
+  // Expected Types: data, input_zero_point, input_scale, output_zero_point, 
output_scale
+  // out_type
+
+  ICHECK_EQ(types.size(), 6);
+
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+  ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8))
+      << "Expected quantized avg_pool2d type(int8, uint8) for input but was " 
<< data->dtype;
+
+  const auto* param = attrs.as<AvgPool2DAttrs>();
+  ICHECK(param != nullptr) << "AvgPool2DAttrs cannot be nullptr.";
+
+  // Check the types of scale and zero points.
+  for (size_t i = 1; i < 5; ++i) {
+    if (types[i].as<IncompleteTypeNode>()) {
+      return false;
+    }
+  }
+
+  ICHECK(IsScalarType(types[1], DataType::Float(32)));  // input_scale
+  ICHECK(IsScalarType(types[2], DataType::Int(32)));    // input_zero_point
+  ICHECK(IsScalarType(types[3], DataType::Float(32)));  // output_scale
+  ICHECK(IsScalarType(types[4], DataType::Int(32)));    // output_zero_point
+
+  // Find the output shape and data type
+  const auto dshape = data->shape;
+  ICHECK_GE(dshape.size(), 2U)
+      << "Pool2D only support input >= 2-D: input must have height and width";
+
+  // Check input and output layout
+  Layout layout(param->layout);
+  // The Layout is always NHWC
+  ICHECK(layout.Contains(LayoutAxis::Get('H')) && 
layout.Contains(LayoutAxis::Get('W')) &&
+         !layout.Contains(LayoutAxis::Get('h')) && 
!layout.Contains(LayoutAxis::Get('w')))
+      << "Invalid input layout " << layout
+      << ". qnn_avg_pool2d inut layout must have H and W, which cannot be 
split";
+
+  // Find the output shape and data type
+  const auto hidx = layout.IndexOf(LayoutAxis::Get('H'));
+  const auto widx = layout.IndexOf(LayoutAxis::Get('W'));
+
+  IndexExpr pad_h, pad_w;
+  if (param->padding.size() == 1) {
+    pad_h = param->padding[0] * 2;
+    pad_w = param->padding[0] * 2;
+  } else if (param->padding.size() == 2) {
+    // (top, left)
+    pad_h = param->padding[0] * 2;
+    pad_w = param->padding[1] * 2;
+  } else if (param->padding.size() == 4) {
+    // (top, left, bottom, right)
+    pad_h = param->padding[0] + param->padding[2];
+    pad_w = param->padding[1] + param->padding[3];
+  } else {
+    return false;
+  }
+
+  std::vector<IndexExpr> oshape(dshape.begin(), dshape.end());
+  if (dshape[hidx].as<tir::AnyNode>()) {
+    oshape[hidx] = dshape[hidx];
+  } else {
+    oshape[hidx] =
+        calculate_pool_dimension(dshape[hidx], pad_h, param->pool_size[0], 
param->dilation[0],
+                                 param->strides[0], param->ceil_mode);
+  }
+  if (dshape[widx].as<tir::AnyNode>()) {
+    oshape[widx] = dshape[widx];
+  } else {
+    oshape[widx] =
+        calculate_pool_dimension(dshape[widx], pad_w, param->pool_size[1], 
param->dilation[1],
+                                 param->strides[1], param->ceil_mode);
+  }
+
+  // assign output type
+  reporter->Assign(types[5], TensorType(oshape, data->dtype));
+  return true;
+}
+
+InferCorrectLayoutOutput QnnAvgPoolInferCorrectLayout(const Attrs& attrs,
+                                                      const Array<Layout>& 
new_in_layouts,
+                                                      const Array<Layout>& 
old_in_layouts,
+                                                      const 
Array<tvm::relay::Type>& old_in_types) {
+  // Use Relay AvgPool2D Infer correct layout.
+  auto avgpool_new_layouts =
+      PoolInferCorrectLayout<AvgPool2DAttrs>(attrs, new_in_layouts, 
old_in_layouts, old_in_types);
+
+  // Fill the layouts of remaining input tensors - scales and zero points. The 
layouts of these
+  // tensors can be treated as channel layout.
+  Layout channel_layout = Layout("C");
+  Array<Layout> input_layouts = {avgpool_new_layouts->input_layouts[0], 
channel_layout,
+                                 channel_layout, channel_layout, 
channel_layout};
+  Array<Layout> output_layouts = avgpool_new_layouts->output_layouts;
+  return InferCorrectLayoutOutput(input_layouts, output_layouts, attrs);
+}
+
+/*
+ * \brief Forward rewrite the qnn Avg_pool2d op.
+ * \param attrs The QNN Avg_pool2d attrs.
+ * \param new_args The new mutated args to the call node.
+ * \param arg_types The types of input and output.
+ * \return The sequence of Relay ops for qnn Avg_pool2d op.
+ * \note Lowering of the qnn.Avg_pool2d operator

Review Comment:
   Odd capital `A` in two lines.



##########
python/tvm/topi/hexagon/slice_ops/avg_pool2d.py:
##########
@@ -16,118 +16,206 @@
 # under the License.
 # pylint: disable=invalid-name, unused-variable, unused-argument, 
too-many-locals, pointless-exception-statement
 
-""" Compute and schedule for avg_pool2d slice op
-
-Please note the following assumptions made by the implementation:
-
-1) The input must be padded in advance to account for 'padding'. In addition,
-   both input and output must be padded as per the physical buffer layout.
-2) The current implementation assumes 'count_include_pad' to be 'True'. It can 
be
-   modified to support 'False' case but the element count for the pooling 
window
-   must be pre-computed and provided as an input to reduce the run-time 
overhead.
-3) 'padding' is ignored. It must be handled outside of the sliced op.
-4) Please note that this implementation will not work if the output includes 
any
-   physical layout related padding as it can result into out-of-bound access
-   for the input.
-"""
+""" Compute and schedule for avg_pool2d slice op """
 
 from tvm import te
 from tvm import tir
 from ..utils import get_layout_transform_fn
+from ...utils import get_const_tuple
+from ...nn.utils import get_pad_tuple
+from ...nn.pad import pad
+from ..compute_poolarea import compute_PoolArea
 
 
-def validate_out_shape(out_shape, in_shape, kernel, stride, dilation):
-    """Validate output shape"""
-    _, oh, ow, _ = out_shape
-    _, ih, iw, _ = in_shape
+def avg_pool2d_NCHW(
+    data, kernel, stride, padding, dilation, count_include_pad, oshape, 
odtype="float16"
+):
+    """avg_pool2d compute"""
+    if odtype != "float16":
+        raise RuntimeError(f"Unsupported output dtype '{odtype}'")
     kh, kw = kernel
+    rh = te.reduce_axis((0, kh), name="rh")
+    rw = te.reduce_axis((0, kw), name="rw")
     sh, sw = stride
     dh, dw = dilation
-    if ih < (oh - 1) * sh + dh * (kh - 1) + 1:
-        raise RuntimeError("Output height is too large")
-    if iw < (ow - 1) * sw + dw * (kw - 1) + 1:
-        raise RuntimeError("Output width is too large")
 
+    dilated_kh = (kh - 1) * dh + 1
+    dilated_kw = (kw - 1) * dw + 1
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        get_const_tuple(padding), (dilated_kh, dilated_kw)
+    )
+
+    # DOPAD
 
-def avg_pool2d_compute(A, kernel, stride, dilation, oshape, odtype="float16"):
+    if pad_top != 0 or pad_down != 0 or pad_left != 0 or pad_right != 0:
+        pad_before = (0, 0, pad_top, pad_left)
+        pad_after = (0, 0, pad_down, pad_right)
+        data_pad = pad(data, pad_before, pad_after, name="data_pad")
+    else:
+        # By definition when True, zero-padding will be included in the 
averaging calculation
+        # This is equivalent to PoolArea = (kh * kw)
+        count_include_pad = True
+        data_pad = data
+
+    Sum = te.compute(
+        oshape,
+        lambda b, c, h, w: te.sum(
+            data_pad[b, c, h * sh + dh * rh, w * sw + dw * 
rw].astype("float32"), axis=[rh, rw]
+        ),
+        name="pool_sum",
+    )
+
+    if not count_include_pad:
+        # Compute PoolArea using unpadded input tensor
+        _, _, oh, ow = oshape
+        _, _, ih, iw = data.shape
+
+        PoolArea = te.compute(
+            (oh, ow),
+            lambda i, j: compute_PoolArea(i, j, ih, iw, kh, kw, sh, sw, dh, 
dw, pad_top, pad_left),
+            name="pool_area",
+        )
+
+        InvArea = te.compute(
+            (oh, ow),
+            lambda i, j: tir.if_then_else(
+                tir.all(PoolArea[i, j] > 0), (float(1) / PoolArea[i, j]), 0
+            ),
+            name="inverse_area",
+        )
+
+        Avg = te.compute(
+            oshape,
+            lambda b, c, h, w: (Sum[b, c, h, w] * InvArea[h, 
w]).astype(odtype),
+            name="pool_avg",
+        )
+    else:
+        InvArea = float(1) / (kh * kw)
+        Avg = te.compute(
+            oshape, lambda b, c, h, w: (Sum[b, c, h, w] * 
InvArea).astype(odtype), name="pool_avg"
+        )
+
+    return Avg
+
+
+def avg_pool2d_NHWC(
+    data, kernel, stride, padding, dilation, count_include_pad, oshape, 
odtype="float16"
+):
     """avg_pool2d compute"""
     if odtype != "float16":
-        RuntimeError(f"Unsupported output dtype '{odtype}'")
+        raise RuntimeError(f"Unsupported output dtype '{odtype}'")
     kh, kw = kernel
     rh = te.reduce_axis((0, kh), name="rh")
     rw = te.reduce_axis((0, kw), name="rw")
-    ob, oh, ow, oc = oshape
-    if isinstance(ob, int):
-        validate_out_shape(oshape, A.shape, kernel, stride, dilation)
 
     sh, sw = stride
     dh, dw = dilation
     InvArea = float(1) / (kh * kw)
 
+    dilated_kh = (kh - 1) * dh + 1
+    dilated_kw = (kw - 1) * dw + 1
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        get_const_tuple(padding), (dilated_kh, dilated_kw)
+    )
+
+    # DOPAD
+    if pad_top != 0 or pad_down != 0 or pad_left != 0 or pad_right != 0:
+        pad_before = (0, pad_top, pad_left, 0)
+        pad_after = (0, pad_down, pad_right, 0)
+        data_pad = pad(data, pad_before, pad_after, name="data_pad")
+    else:
+        # By definition when True, zero-padding will be included in the 
averaging calculation
+        # This is equivalent to PoolArea = (kh * kw)
+        count_include_pad = True
+        data_pad = data
+
     Sum = te.compute(
         oshape,
         lambda b, h, w, c: te.sum(
-            A[b, h * sh + dh * rh, w * sw + dw * rw, c].astype("float32"), 
axis=[rh, rw]
+            data_pad[b, h * sh + dh * rh, w * sw + dw * rw, 
c].astype("float32"), axis=[rh, rw]
         ),
-        name="sum",
-    )
-    Avg = te.compute(
-        oshape, lambda b, h, w, c: (Sum[b, h, w, c] * 
InvArea).astype(A.dtype), name="avg"
+        name="pool_sum",
     )
+
+    if not count_include_pad:
+        # Compute PoolArea using unpadded input tensor
+        _, oh, ow, _ = oshape
+        _, ih, iw, _ = data.shape
+
+        PoolArea = te.compute(
+            (oh, ow),
+            lambda i, j: compute_PoolArea(i, j, ih, iw, kh, kw, sh, sw, dh, 
dw, pad_top, pad_left),
+            name="pool_area",
+        )
+
+        InvArea = te.compute(
+            (oh, ow),
+            lambda i, j: tir.if_then_else(
+                tir.all(PoolArea[i, j] > 0), (float(1) / PoolArea[i, j]), 0
+            ),
+            name="inverse_area",
+        )
+
+        Avg = te.compute(
+            oshape,
+            lambda b, h, w, c: (Sum[b, h, w, c] * InvArea[h, 
w]).astype(odtype),
+            name="pool_avg",
+        )
+    else:
+        InvArea = float(1) / (kh * kw)
+        Avg = te.compute(
+            oshape, lambda b, h, w, c: (Sum[b, h, w, c] * 
InvArea).astype(odtype), name="pool_avg"
+        )
+
     return Avg
 
 
-def schedule_nhwc_8h2w32c2w(outs, ins, output_layout: str, input_layout: str):
-    """Schedule for input and output layout nhwc-8h2w32c2w"""
+def schedule_8h2w32c2w(outs, ins, output_layout: str, input_layout: str):
+    """Schedule for input and output layout 8h2w32c2w"""
     func = te.create_prim_func([ins, outs])
+    print(func)
     s = tir.Schedule(func)
-    Sum = s.get_block("sum")
-    Avg = s.get_block("avg")
+    Sum = s.get_block("pool_sum")
+    Avg = s.get_block("pool_avg")
 
+    mem_scope = "global.vtcm"
+    sum_read = s.cache_read(Sum, 0, mem_scope)
+    avg_write = s.cache_write(Avg, 0, mem_scope)
     input_transform_fn = get_layout_transform_fn(input_layout)
     output_transform_fn = get_layout_transform_fn(output_layout)
-    s.transform_layout(Sum, ("read", 0), input_transform_fn)
-    s.transform_layout(Avg, ("write", 0), output_transform_fn)
-
-    # Schedule 'Avg'
-    n, h, w, c = s.get_loops(Avg)
-    ho, hi = s.split(h, [None, 8])
-    wo, wi = s.split(w, [None, 4])
-    wio, wii = s.split(wi, [None, 2])
-    co, ci = s.split(c, [None, 32])
-    s.reorder(n, ho, wo, co, hi, wio, ci, wii)
-    ci_wii = s.fuse(ci, wii)
-    s.vectorize(ci_wii)
-
-    # Schedule 'Sum'
-    s.compute_at(Sum, wio)
-    Sum_axis = s.get_loops(Sum)
-    s.reorder(Sum_axis[-2], Sum_axis[-1], Sum_axis[-4], Sum_axis[-3])
-    ci_wii = s.fuse(Sum_axis[-4], Sum_axis[-3])
-    # s.vectorize(ci_wii) # Doesn't work
+    s.transform_layout(Sum, ("read", 0), input_transform_fn, pad_value=0.0)
+    s.transform_layout(Avg, ("write", 0), output_transform_fn, pad_value=0.0)
     return s
 
 
-def schedule_n11c_1024c(outs, ins, output_layout: str, input_layout: str):
-    """Schedule for output layout: n11c-1024c, input layout: nhwc-8h2w32c2w"""
+def schedule_1024c(outs, ins, output_layout: str, input_layout: str):
+    """Schedule for output layout: 1024c, input layout: 8h2w32c2w"""
     func = te.create_prim_func([ins, outs])
     s = tir.Schedule(func)
-    Sum = s.get_block("sum")
-    Avg = s.get_block("avg")
+    Sum = s.get_block("pool_sum")
+    Avg = s.get_block("pool_avg")
 
+    mem_scope = "global.vtcm"
+    sum_read = s.cache_read(Sum, 0, mem_scope)
+    avg_write = s.cache_write(Avg, 0, mem_scope)
     input_transform_fn = get_layout_transform_fn(input_layout)
     output_transform_fn = get_layout_transform_fn(output_layout)
-    s.transform_layout(Sum, ("read", 0), input_transform_fn)
-    s.transform_layout(Avg, ("write", 0), output_transform_fn)
+    s.transform_layout(Sum, ("read", 0), input_transform_fn, pad_value=0.0)
+    s.transform_layout(Avg, ("write", 0), output_transform_fn, pad_value=0.0)
 
     # Schedule 'Avg'
-    n, h, w, c = s.get_loops(Avg)
-    co, ci = s.split(c, [None, 1024])
+    if output_layout == "n11c-1024c-2d":
+        n, h, w, c = s.get_loops(Avg)
+    else:
+        n, c, h, w = s.get_loops(Avg)
+    _, ci = s.split(c, [None, 1024])
     cio, cii = s.split(ci, [None, 64])
     s.vectorize(cii)
 
     # Schedule 'Sum'
-    s.compute_at(Sum, cio)
+    # s.compute_at(Sum, cio)

Review Comment:
   Please remove this commented-out line.



##########
tests/python/contrib/test_hexagon/test_qnn_op_integration.py:
##########
@@ -0,0 +1,455 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,missing-function-docstring,redefined-outer-name
+
+""" Test Relay integrated qnn ops
+There are two types of tests for qnn ops in this file. One to verify the
+correctness of the relay integration and the other one to verify
+the fake quantization to integer implemented for picking up the qnn op.
+The former is only executed when qnn canonicalization is disabled.
+The latter is executed both with and without canonicalization.
+"""
+# TODO: We might want to distribute these test cases into other test cases 
such as
+# test_wo_qnn_canonicalization and test_pass_fake_quantization_to_integer in 
the future.

Review Comment:
   Yes, please do this now (otherwise it will probably not happen). I think the
   FQ2I tests for avgpool should go into `test_pass_fake_quantization_to_integer`.
   The tests involving Hexagon compilation can stay in this file, and the tests in
   `test_wo_qnn_canonicalization` can migrate to this file instead.



##########
src/relay/qnn/op/avg_pool2d.cc:
##########
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/qnn/op/Avg_pool2d.cc
+ * \brief Property def of qnn Avg_pool2d operator.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/base.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/qnn/attrs.h>
+#include <tvm/tir/expr.h>
+
+#include "../../op/nn/nn.h"
+#include "../../op/nn/pooling.h"
+#include "../../op/nn/pooling_common.h"
+#include "../../op/tensor/transform.h"
+#include "../../transforms/infer_layout_utils.h"
+#include "../../transforms/pattern_utils.h"
+#include "../utils.h"
+#include "op_common.h"
+
+namespace tvm {
+namespace relay {
+namespace qnn {
+
+// relay.op.qnn.avg_pool2d
+bool QnnAvgPool2DRel(const Array<Type>& types, int num_inputs, const Attrs& 
attrs,
+                     const TypeReporter& reporter) {
+  // Expected Types: data, input_zero_point, input_scale, output_zero_point, 
output_scale
+  // out_type
+
+  ICHECK_EQ(types.size(), 6);
+
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+  ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8))
+      << "Expected quantized avg_pool2d type(int8, uint8) for input but was " 
<< data->dtype;
+
+  const auto* param = attrs.as<AvgPool2DAttrs>();
+  ICHECK(param != nullptr) << "AvgPool2DAttrs cannot be nullptr.";
+
+  // Check the types of scale and zero points.
+  for (size_t i = 1; i < 5; ++i) {
+    if (types[i].as<IncompleteTypeNode>()) {
+      return false;
+    }
+  }
+
+  ICHECK(IsScalarType(types[1], DataType::Float(32)));  // input_scale
+  ICHECK(IsScalarType(types[2], DataType::Int(32)));    // input_zero_point
+  ICHECK(IsScalarType(types[3], DataType::Float(32)));  // output_scale
+  ICHECK(IsScalarType(types[4], DataType::Int(32)));    // output_zero_point
+
+  // Find the output shape and data type
+  const auto dshape = data->shape;
+  ICHECK_GE(dshape.size(), 2U)
+      << "Pool2D only support input >= 2-D: input must have height and width";
+
+  // Check input and output layout
+  Layout layout(param->layout);
+  // The Layout is always NHWC
+  ICHECK(layout.Contains(LayoutAxis::Get('H')) && 
layout.Contains(LayoutAxis::Get('W')) &&
+         !layout.Contains(LayoutAxis::Get('h')) && 
!layout.Contains(LayoutAxis::Get('w')))
+      << "Invalid input layout " << layout
+      << ". qnn_avg_pool2d inut layout must have H and W, which cannot be 
split";
+
+  // Find the output shape and data type
+  const auto hidx = layout.IndexOf(LayoutAxis::Get('H'));
+  const auto widx = layout.IndexOf(LayoutAxis::Get('W'));
+
+  IndexExpr pad_h, pad_w;
+  if (param->padding.size() == 1) {
+    pad_h = param->padding[0] * 2;
+    pad_w = param->padding[0] * 2;
+  } else if (param->padding.size() == 2) {
+    // (top, left)
+    pad_h = param->padding[0] * 2;
+    pad_w = param->padding[1] * 2;
+  } else if (param->padding.size() == 4) {
+    // (top, left, bottom, right)
+    pad_h = param->padding[0] + param->padding[2];
+    pad_w = param->padding[1] + param->padding[3];
+  } else {
+    return false;
+  }
+
+  std::vector<IndexExpr> oshape(dshape.begin(), dshape.end());
+  if (dshape[hidx].as<tir::AnyNode>()) {
+    oshape[hidx] = dshape[hidx];
+  } else {
+    oshape[hidx] =
+        calculate_pool_dimension(dshape[hidx], pad_h, param->pool_size[0], 
param->dilation[0],
+                                 param->strides[0], param->ceil_mode);
+  }
+  if (dshape[widx].as<tir::AnyNode>()) {
+    oshape[widx] = dshape[widx];
+  } else {
+    oshape[widx] =
+        calculate_pool_dimension(dshape[widx], pad_w, param->pool_size[1], 
param->dilation[1],
+                                 param->strides[1], param->ceil_mode);
+  }
+
+  // assign output type
+  reporter->Assign(types[5], TensorType(oshape, data->dtype));
+  return true;
+}
+
+InferCorrectLayoutOutput QnnAvgPoolInferCorrectLayout(const Attrs& attrs,
+                                                      const Array<Layout>& 
new_in_layouts,
+                                                      const Array<Layout>& 
old_in_layouts,
+                                                      const 
Array<tvm::relay::Type>& old_in_types) {
+  // Use Relay AvgPool2D Infer correct layout.
+  auto avgpool_new_layouts =
+      PoolInferCorrectLayout<AvgPool2DAttrs>(attrs, new_in_layouts, 
old_in_layouts, old_in_types);
+
+  // Fill the layouts of remaining input tensors - scales and zero points. The 
layouts of these
+  // tensors can be treated as channel layout.
+  Layout channel_layout = Layout("C");
+  Array<Layout> input_layouts = {avgpool_new_layouts->input_layouts[0], 
channel_layout,
+                                 channel_layout, channel_layout, 
channel_layout};
+  Array<Layout> output_layouts = avgpool_new_layouts->output_layouts;
+  return InferCorrectLayoutOutput(input_layouts, output_layouts, attrs);
+}
+
+/*
+ * \brief Forward rewrite the qnn Avg_pool2d op.
+ * \param attrs The QNN Avg_pool2d attrs.
+ * \param new_args The new mutated args to the call node.
+ * \param arg_types The types of input and output.
+ * \return The sequence of Relay ops for qnn Avg_pool2d op.
+ * \note Lowering of the qnn.Avg_pool2d operator
+
+ *  Quantized Avg_pool2d will take one quantized input tensor and returns 
another

Review Comment:
   Odd capital `A` in `Avg_pool2d` throughout this comment block.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to