lhutton1 commented on code in PR #13997:
URL: https://github.com/apache/tvm/pull/13997#discussion_r1108263890


##########
src/relay/op/contrib/ethosu/pooling.cc:
##########
@@ -46,14 +46,27 @@ bool EthosuPoolingRel(const Array<Type>& types, int num_inputs, const Attrs& att
 
   const String operator_name = "ethosu_pooling";
 
-  if (param->pooling_type != "AVG" && param->pooling_type != "MAX") {
+  if (param->pooling_type != "AVG" && param->pooling_type != "MAX" &&
+      param->pooling_type != "SUM") {
     reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
                                     << "Invalid operator: expected " << operator_name
-                                     << " type 'AVG' or 'MAX' but was " << param->pooling_type);
+                                     << " type 'AVG', 'MAX', or 'SUM' but was "
+                                     << param->pooling_type);
     return false;
   }
 
-  CheckDataType(reporter, ifm->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name, "ifm",
+  auto max_avg_pooling_ifm_dtypes = {DataType::UInt(8), DataType::Int(8)};
+  auto sum_pooling_ifm_dtypes = {DataType::UInt(8), DataType::Int(8), DataType::Int(16),

Review Comment:
   Nit: prefer types over `auto`



##########
python/tvm/relay/op/contrib/ethosu.py:
##########
@@ -1375,6 +1375,82 @@ def mean_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
     return pattern
 
 
+class SumParams:
+    """
+    This class will parse a call to ethosu.sum composite function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethos-u.sum"
+
+    def __init__(self, func_body: Call):
+        from tvm.relay.backend.contrib.ethosu.util import RequantArgs
+
+        requantize = func_body
+        sum_op = requantize.args[0]
+        attrs = sum_op.attrs
+        cast = sum_op.args[0]
+
+        layout = "NHWC"
+        self.ifm = TensorParams(
+            cast.args[0],
+            layout,
+            requantize.args[RequantArgs.IFM_SCALE.value],
+            requantize.args[RequantArgs.IFM_ZERO_POINT.value],
+        )
+        self.ofm = TensorParams(
+            requantize,
+            layout,
+            requantize.args[RequantArgs.OFM_SCALE.value],
+            requantize.args[RequantArgs.OFM_ZERO_POINT.value],
+        )
+
+        ifm_shape = self.ifm.shape
+        self.height = ifm_shape[0] if len(ifm_shape) in (2, 3) else ifm_shape[1]
+        self.width = ifm_shape[1] if len(ifm_shape) in (2, 3) else ifm_shape[2]
+        self.keepdims = attrs.keepdims
+
+        self.axis = list(sorted(attrs.axis))
+        if attrs.exclude:
+            self.axis = [i for i in range(len(self.ifm.shape)) if i not in self.axis]
+
+    def is_valid(self) -> bool:
+        """
+        Checks whether Sum has compatible attributes with HW.
+        """
+
+        ifm_shape_len = len(self.ifm.shape)
+
+        if not check_valid_dtypes([self.ifm], [np.uint8, np.int8, np.int16, np.int32]):
+            return False
+        if not check_valid_dtypes([self.ofm], [np.int8]):

Review Comment:
   Curious if it is also possible to support other output data types? Would it just involve a change to the output data-type in the legalized shift operation? (I'm guessing it's more involved than that.)



##########
python/tvm/relay/backend/contrib/ethosu/te/pooling.py:
##########
@@ -110,18 +110,21 @@ def pooling_compute(
     padding = [int(v) for v in padding]
     stride_h, stride_w = [int(v) for v in strides]
     pool_shape_h, pool_shape_w = [int(v) for v in pool_shape]
+    ifm_channels = ofm_channels if pooling_type != "SUM" else ifm.shape[-1]
     upscale_factor = 2 if upscale != "NONE" else 1
 
     # Compute operation for the IFM DMA pipeline
     dmaed_ifm = dma_ifm_compute(
-        ifm, ifm_layout, ifm_zero_point, ifm_scale, ofm_channels, padding, upscale_factor
+        ifm, ifm_layout, ifm_zero_point, ifm_scale, ifm_channels, padding, upscale_factor
     )
 
     # Pooling compute operation
     ofm_height = (dmaed_ifm.shape[1] - pool_shape_h) // stride_h + 1
     ofm_width = (dmaed_ifm.shape[2] - pool_shape_w) // stride_w + 1
     rh = te.reduce_axis((0, pool_shape_h), name="ry")
     rw = te.reduce_axis((0, pool_shape_w), name="rx")
+    rc = te.reduce_axis((0, 1 if pooling_type != "SUM" else ifm_channels), name="rc")
+    ofm_dtype = ifm.dtype if pooling_type != "SUM" else "int32"

Review Comment:
   According to the TRM, the AVG pool output datatype can be any of {uint8, int8, int16} regardless of the input datatype. It's okay to leave it like this for now; just noting it as something that might come up later.
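   For illustration, a hedged sketch of what honouring that flexibility might look like (hypothetical helper, not the PR's code; the diff above keeps `ofm_dtype = ifm.dtype` for non-SUM pooling):

   ```python
   def select_ofm_dtype(pooling_type, ifm_dtype, requested_dtype=None):
       """Sketch only: SUM accumulates into int32; per the TRM, AVG may emit
       any of {uint8, int8, int16} regardless of the input dtype; otherwise
       mirror the input dtype (the current behaviour)."""
       if pooling_type == "SUM":
           return "int32"
       if pooling_type == "AVG" and requested_dtype in ("uint8", "int8", "int16"):
           return requested_dtype
       return ifm_dtype
   ```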



##########
src/relay/op/contrib/ethosu/pooling.cc:
##########
@@ -46,14 +46,27 @@ bool EthosuPoolingRel(const Array<Type>& types, int num_inputs, const Attrs& att
 
   const String operator_name = "ethosu_pooling";
 
-  if (param->pooling_type != "AVG" && param->pooling_type != "MAX") {
+  if (param->pooling_type != "AVG" && param->pooling_type != "MAX" &&
+      param->pooling_type != "SUM") {
     reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
                                     << "Invalid operator: expected " << operator_name
-                                     << " type 'AVG' or 'MAX' but was " << param->pooling_type);
+                                     << " type 'AVG', 'MAX', or 'SUM' but was "
+                                     << param->pooling_type);
     return false;
   }
 
-  CheckDataType(reporter, ifm->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name, "ifm",

Review Comment:
   Yes, I believe they should be. It's likely just a case of not having needed `int16` previously, but I think it would be good to add :)



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
