This is an automated email from the ASF dual-hosted git repository.

ekalda pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ac9a943c4d [TOPI][Testing] Enable conv2d NHWC fp16 topi testing for `arm_cpu` (#17007)
ac9a943c4d is described below

commit ac9a943c4dd45cb98c5801631450fd9bb44e7804
Author: Andrei Hutu <[email protected]>
AuthorDate: Wed May 22 11:01:02 2024 +0100

    [TOPI][Testing] Enable conv2d NHWC fp16 topi testing for `arm_cpu` (#17007)
    
    This commit adds fp16 test cases to the conv2d NHWC TOPI schedules for `arm_cpu`.
    Following the example of #8529, the numpy reference conv2d output is computed in fp32 instead of fp16, while the absolute tolerance varies for each test case according to the size of the summed axis and the output's largest element.
---
 python/tvm/testing/utils.py                |  7 +++++
 tests/python/topi/test_topi_conv2d_nhwc.py | 49 ++++++++++++++++++++++++------
 2 files changed, 47 insertions(+), 9 deletions(-)

diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index 38b39b5fc2..84b631cf38 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -1057,6 +1057,13 @@ requires_arm_dot = Feature(
 )
 
 
+requires_arm_fp16 = Feature(
+    "arm_fp16",
+    "Arm(R) Neon(TM) instructions for FP16",
+    run_time_check=lambda: _has_cpu_feat("fullfp16"),
+)
+
+
 requires_aarch64_sve = Feature(
     "arm_sve",
     "AArch64 SVE",
diff --git a/tests/python/topi/test_topi_conv2d_nhwc.py b/tests/python/topi/test_topi_conv2d_nhwc.py
index 6ff844de08..b5c9518d34 100644
--- a/tests/python/topi/test_topi_conv2d_nhwc.py
+++ b/tests/python/topi/test_topi_conv2d_nhwc.py
@@ -53,7 +53,7 @@ device = tvm.testing.parameter(
         topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack,
     ),
     (
-        "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a",
+        "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+fullfp16",
         topi.arm_cpu.compute_conv2d_NHWC_hybrid,
         topi.arm_cpu.schedule_conv2d_NHWC_hybrid,
     ),
@@ -64,7 +64,7 @@ device = tvm.testing.parameter(
     ),
 )
 
-dtype = tvm.testing.parameter("float32")
+dtype = tvm.testing.parameter("float16", "float32")
 
batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation = tvm.testing.parameters(
     # Pad M, N, K
@@ -104,14 +104,36 @@ def ref_data(dtype, batch, in_channel, in_size, num_filter, kernel, stride, padd
     a_shape = (batch, in_height, in_width, in_channel)
     w_shape = (kernel, kernel, in_channel, num_filter)
 
+    np.random.seed(0)
     a_np = np.random.uniform(size=a_shape).astype(dtype)
     w_np = np.random.uniform(size=w_shape).astype(dtype)
     dw_np = tvm.topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1))
-    b_np = tvm.topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding)
+
+    # scipy.signal.convolve2d does not support float16 data types,
+    # and the python fallback would be too slow for general use.
+    conv_dtype = "float32" if dtype == "float16" else dtype
+    b_np = tvm.topi.testing.conv2d_nhwc_python(
+        a_np.astype(conv_dtype), dw_np.astype(conv_dtype), stride, padding
+    ).astype(dtype)
     return a_np, w_np, b_np
 
 
-def test_conv2d_nhwc_gemm_fp32(device, ref_data, dtype, stride, padding, dilation):
+def get_tolerance(dtype, w_np, b_np):
+    if dtype == "float16":
+        # A summation in float16 with a single accumulator very
+        # quickly runs into large rounding errors.
+        # This tolerance is necessary to ensure no false negatives,
+        # but it may introduce false positives, depending on schedule behaviour.
+        num_values_summed = w_np.shape[0] * w_np.shape[1] * w_np.shape[2]
+        next_float_gap_size = np.nextafter(b_np.max(), np.inf, dtype=b_np.dtype) - b_np.max()
+        tol = {"rtol": 1e-5, "atol": num_values_summed * next_float_gap_size / 2}
+    else:
+        tol = {"rtol": 1e-5, "atol": 1e-7}
+
+    return tol
+
+
+def test_conv2d_nhwc_gemm(device, ref_data, dtype, stride, padding, dilation):
     a_np, w_np, b_np = ref_data
 
     A = te.placeholder(a_np.shape, name="A", dtype=dtype)
@@ -130,14 +152,21 @@ def test_conv2d_nhwc_gemm_fp32(device, ref_data, dtype, stride, padding, dilatio
 
         # Run only on AArch64 devices
         # Do not run SVE schedules on non-SVE devices
-        build_only = platform.machine() != "aarch64" or (
-            target.features.has_sve and not tvm.testing.requires_aarch64_sve.run_time_check()
+        build_only = (
+            platform.machine() != "aarch64"
+            or (target.features.has_sve and not tvm.testing.requires_aarch64_sve.run_time_check())
+            or (
+                dtype == "float16"
+                and target.features.has_fp16_simd
+                and not tvm.testing.requires_arm_fp16.run_time_check()
+            )
         )
         if build_only:
             return
 
         func(a, w, b)
-    tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5)
+    tol = get_tolerance(dtype, w_np, b_np)
+    tvm.testing.assert_allclose(b.numpy(), b_np, rtol=tol["rtol"], atol=tol["atol"])
 
 
 def test_conv2d_nhwc_hwio(target, dev, ref_data, dtype, stride, padding, dilation):
@@ -155,7 +184,8 @@ def test_conv2d_nhwc_hwio(target, dev, ref_data, dtype, stride, padding, dilatio
     b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev)
     func = tvm.build(s, [A, W, B], target)
     func(a, w, b)
-    tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5)
+    tol = get_tolerance(dtype, w_np, b_np)
+    tvm.testing.assert_allclose(b.numpy(), b_np, rtol=tol["rtol"], atol=tol["atol"])
 
 
 def test_conv2d_nhwc_ohwi(ref_data, dtype, stride, padding, dilation):
@@ -184,7 +214,8 @@ def test_conv2d_nhwc_ohwi(ref_data, dtype, stride, padding, dilation):
     b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev)
     func = tvm.build(s, [A, W, B], target)
     func(a, w, b)
-    tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5)
+    tol = get_tolerance(dtype, w_np_hwio, b_np)
+    tvm.testing.assert_allclose(b.numpy(), b_np, rtol=tol["rtol"], atol=tol["atol"])
 
 
 if __name__ == "__main__":

Reply via email to