lhutton1 commented on code in PR #13488:
URL: https://github.com/apache/tvm/pull/13488#discussion_r1033339170


##########
tests/python/contrib/test_arm_compute_lib/test_pooling.py:
##########
@@ -169,91 +224,79 @@ def test_pooling():
     device = Device()
     np.random.seed(0)
 
-    fp32_dtype = ("float32", -127, 128, 0.001, 0.001)
-    uint8_dtype = ("uint8", 0, 255, 1, 0)
-    # fmt: off
-    trials = [
-        ["nn.max_pool2d", fp32_dtype,  (3, 3), (2, 2), (1, 1), (0, 0), False, 
False, (27, 27, 512), (0, 1),],
-        ["nn.max_pool2d", fp32_dtype,  (2, 2), (2, 2), (1, 1), (0, 0), False, 
True,  (16, 16, 16),  (0, 1),],
-        ["nn.max_pool2d", fp32_dtype,  (3, 3), (2, 2), (1, 1), (1, 1), True,  
True,  (15, 15, 16),  (0, 1),],
-        ["nn.max_pool2d", fp32_dtype,  (2, 2), (2, 2), (1, 1), (0, 1), False, 
False, (16, 16, 16),  (0, 1),],
-        ["nn.max_pool2d", uint8_dtype, (3, 3), (2, 2), (1, 1), (0, 1), False, 
False, (16, 16, 16),  (0, 1),],
-        ["nn.max_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), (1, 1), True,  
True,  (15, 15, 16),  (0, 1),],
-        ["nn.max_pool2d", uint8_dtype, (2, 2), (2, 2), (3, 2), (1, 1), True,  
True,  (15, 15, 16),  (1, 0),],
-        ["nn.avg_pool2d", fp32_dtype,  (2, 2), (2, 2), (1, 1), (1, 1), False, 
False, (16, 16, 16),  (0, 1),],
-        ["nn.avg_pool2d", fp32_dtype,  (2, 2), (2, 2), (1, 1), (0, 0), False, 
True,  (16, 16, 16),  (0, 1),],
-        ["nn.avg_pool2d", fp32_dtype,  (3, 3), (2, 2), (3, 2), (0, 1), True,  
False, (15, 15, 16),  (1, 0),],
-        # 20.05: "exclude_padding equal false is not supported for AVG Pooling 
with padding on quantized types"
-        # ["nn.avg_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), False, True, 
(16, 16, 16)],
-        ["nn.avg_pool2d", uint8_dtype, (3, 3), (2, 2), (1, 1), (0, 1), False, 
False, (16, 16, 16),  (0, 1),],
-        ["nn.l2_pool2d",  fp32_dtype,  (2, 2), (2, 2), (1, 1), (0, 1), True,  
False, (16, 16, 16),  (0, 1),],
-        ["nn.l2_pool2d",  fp32_dtype,  (3, 3), (2, 2), (1, 1), (0, 0), False, 
False, (16, 16, 16),  (0, 1),],
-        ["nn.l2_pool2d",  fp32_dtype,  (2, 2), (2, 2), (1, 1), (1, 1), False, 
True,  (15, 15, 16),  (0, 1),],
-    ]
-    # fmt: on
-    for (
+    low, high, atol, rtol = _get_low_high_atol_rtol(dtype)
+    tvm_ops, acl_partitions = expected_ops
+
+    shape = (1, *input_shape)
+    outputs = []
+    inputs = {
+        "a": tvm.nd.array(np.random.uniform(low, high, shape).astype(dtype)),
+    }
+
+    func = _get_pooling_model(
+        shape,
+        dtype,
         typef,
-        (dtype, low, high, atol, rtol),
         size,
         stride,
         dilation,
         pad,
         ceil_mode,
         count_include_pad,
-        input_shape,
-        (tvm_ops, acl_partitions),
-    ) in trials:
-        shape = (1, *input_shape)
-        outputs = []
-        inputs = {
-            "a": tvm.nd.array(np.random.uniform(low, high, 
shape).astype(dtype)),
-        }
-
-        func = _get_pooling_model(
-            shape,
-            dtype,
-            typef,
-            size,
-            stride,
-            dilation,
-            pad,
-            ceil_mode,
-            count_include_pad,
-            iter(inputs),
+        iter(inputs),
+    )
+
+    config = {
+        "size": size,
+        "stride": stride,
+        "shape": shape,
+        "pooling type": typef,
+        "dtype": dtype,
+        "padding": pad,
+        "dilation": dilation,
+        "ceil_mode": ceil_mode,
+        "count_include_pad": count_include_pad,
+        "inputs": inputs,
+    }
+    verify_saturation = True if dtype == "uint8" else False
+    for acl in [False, True]:
+        outputs.append(
+            build_and_run(
+                func,
+                inputs,
+                1,
+                None,
+                device,
+                enable_acl=acl,
+                tvm_ops=tvm_ops,
+                acl_partitions=acl_partitions,
+                config=config,
+            )[0]
         )
 
-        config = {
-            "size": size,
-            "stride": stride,
-            "shape": shape,
-            "pooling type": typef,
-            "dtype": dtype,
-            "padding": pad,
-            "dilation": dilation,
-            "ceil_mode": ceil_mode,
-            "count_include_pad": count_include_pad,
-            "inputs": inputs,
-        }
-        verify_saturation = True if dtype == "uint8" else False
-        for acl in [False, True]:
-            outputs.append(
-                build_and_run(
-                    func,
-                    inputs,
-                    1,
-                    None,
-                    device,
-                    enable_acl=acl,
-                    tvm_ops=tvm_ops,
-                    acl_partitions=acl_partitions,
-                    config=config,
-                )[0]
-            )
-
-        verify(outputs, atol=atol, rtol=rtol, config=config, 
verify_saturation=verify_saturation)
-
-
-def test_global_pooling():
+    verify(outputs, atol=atol, rtol=rtol, config=config, 
verify_saturation=verify_saturation)
+
+
+@pytest.mark.parametrize(
+    "typef,dtype,input_shape",
+    [
+        ["nn.global_max_pool2d", "float32", (8, 8, 16)],
+        ["nn.global_max_pool2d", "float32", (9, 9, 16)],
+        ["nn.global_max_pool2d", "float32", (8, 8, 16)],
+        ["nn.global_max_pool2d", "uint8", (8, 8, 16)],
+        ["nn.global_max_pool2d", "uint8", (9, 9, 16)],
+        ["nn.global_max_pool2d", "int8", (8, 8, 16)],
+        ["nn.global_max_pool2d", "int8", (9, 9, 16)],
+        ["nn.global_avg_pool2d", "float32", (8, 8, 16)],
+        ["nn.global_avg_pool2d", "float32", (8, 8, 16)],
+        ["nn.global_avg_pool2d", "float32", (9, 9, 16)],
+        ["nn.global_avg_pool2d", "uint8", (8, 8, 16)],
+        ["nn.global_avg_pool2d", "uint8", (9, 9, 16)],
+        ["nn.global_avg_pool2d", "int8", (8, 8, 16)],
+        ["nn.global_avg_pool2d", "int8", (9, 9, 16)],

Review Comment:
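   You missed a couple of duplicate float32 cases: `["nn.global_max_pool2d", "float32", (8, 8, 16)]` and `["nn.global_avg_pool2d", "float32", (8, 8, 16)]` each appear twice ;)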
   ```suggestion
           ["nn.global_max_pool2d", "float32", (8, 8, 16)],
           ["nn.global_max_pool2d", "float32", (9, 9, 16)],
           ["nn.global_max_pool2d", "uint8", (8, 8, 16)],
           ["nn.global_max_pool2d", "uint8", (9, 9, 16)],
           ["nn.global_max_pool2d", "int8", (8, 8, 16)],
           ["nn.global_max_pool2d", "int8", (9, 9, 16)],
           ["nn.global_avg_pool2d", "float32", (8, 8, 16)],
           ["nn.global_avg_pool2d", "float32", (9, 9, 16)],
           ["nn.global_avg_pool2d", "uint8", (8, 8, 16)],
           ["nn.global_avg_pool2d", "uint8", (9, 9, 16)],
           ["nn.global_avg_pool2d", "int8", (8, 8, 16)],
           ["nn.global_avg_pool2d", "int8", (9, 9, 16)],
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to