ekalda commented on code in PR #13669:
URL: https://github.com/apache/tvm/pull/13669#discussion_r1058904002


##########
tests/python/topi/python/test_topi_conv2d_int8.py:
##########
@@ -298,378 +175,462 @@ def get_ref_data():
 
         a_np, w_np, b_np, c_np = get_ref_data()
 
-        with tvm.target.Target(target):
-            C = compute(
-                A,
-                W,
-                (stride, stride),
-                padding,
-                (dilation, dilation),
-                "NCHW",
-                "NCHW",
-                out_dtype,
-            )
+        dev = tvm.device(target, 0)
+        with tvm.target.Target(target) as tvm_target:
+            C = compute(A, W, (stride, stride), padding, (dilation, dilation), 
dtype)
             if add_bias:
                 C = topi.add(C, bias)
             if add_relu:
                 C = topi.nn.relu(C)
             s = schedule([C])
 
-        a = tvm.nd.array(a_np.astype(dtype), dev)
-        w = tvm.nd.array(w_np.astype(dtype), dev)
-        b = tvm.nd.array(b_np.astype(out_dtype), dev)
-        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), 
dev)
-
-        if add_bias:
-            compile_args = [A, W, bias, C]
-            run_args = [a, w, b, c]
-        else:
-            compile_args = [A, W, C]
-            run_args = [a, w, c]
-
-        func = tvm.build(
-            s,
-            compile_args,
-            target,
-            name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-            % (batch, in_channel, in_size, num_filter, kernel, stride, 
padding_sum, dilation),
-        )
+            a = tvm.nd.array(a_np, dev)
+            w = tvm.nd.array(w_np, dev)
+            b = tvm.nd.array(b_np, dev)
+            c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), 
dtype=C.dtype), dev)
 
-        if build_only:
-            return
+            build_inputs = [A, W, bias, C] if add_bias else [A, W, C]
+            inference_inputs = (a, w, b, c) if add_bias else (a, w, c)
+
+            func = tvm.build(
+                s,
+                build_inputs,
+                target,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (
+                    batch,
+                    in_channel,
+                    in_size,
+                    num_filter,
+                    kernel,
+                    stride,
+                    padding_sum,
+                    dilation,
+                ),
+            )
 
-        print("Running on target: %s" % target)
+            build_only = tvm_target.features.is_aarch64 and 
(platform.machine() != "aarch64")
 
-        func(*run_args)
+            if not build_only:
+                print("Running on target: %s" % target)
+                func(*inference_inputs)
+                tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
-    targets = [
-        (
-            "cuda",
-            lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, 
s, p, d, l, o),
-            topi.cuda.schedule_conv2d_NCHWc_int8,
-            4,
-            False,
-        ),
-        # Disable on CI since it does not support spirv int8 dot product
-        # (
-        #     "vulkan -from_device=0",
-        #     lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, 
w, s, p, d, l, o),
-        #     topi.cuda.schedule_conv2d_NCHWc_int8,
-        #     4,
-        #     False,
-        # ),
-    ]
-
-    build_only_aarch64 = platform.machine() != "aarch64"
-
-    targets.append(
+@pytest.mark.parametrize("in_dtype", ["int8", "uint8"])
+@pytest.mark.parametrize(
+    "params",
+    [
+        # ResNet18 workloads where channels in / out are multiple of 
oc_block_factor
+        (1, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 56, 64, 1, 1, 0, 1, False, False),
+        (1, 64, 56, 128, 3, 2, 1, 1, False, False),
+        (1, 64, 56, 128, 1, 2, 0, 1, False, False),
+        (1, 128, 28, 128, 3, 1, 1, 1, False, False),
+        (1, 128, 28, 256, 3, 2, 1, 1, False, False),
+        (1, 128, 28, 256, 1, 2, 0, 1, False, False),
+        (1, 256, 14, 256, 3, 1, 1, 1, False, False),
+        (1, 256, 14, 512, 3, 2, 1, 1, False, False),
+        (1, 256, 14, 512, 1, 2, 0, 1, False, False),
+        (1, 512, 7, 512, 3, 1, 1, 1, False, False),
+        # bias, relu
+        (1, 64, 56, 64, 3, 1, 1, 1, False, True),
+        (1, 64, 56, 64, 3, 1, 1, 1, True, False),
+        (1, 64, 56, 64, 3, 1, 1, 1, True, True),
+        # dilation = 2
+        (1, 64, 56, 64, 3, 1, 1, 2, False, False),
+        # batch size
+        (4, 64, 56, 64, 3, 1, 1, 1, False, False),
+        (9, 64, 56, 64, 3, 1, 1, 1, False, False),
+        # weird workloads
+        (4, 4, 4, 8, 4, 4, 4, 1, False, False),
+        # inception v3 workloads where channels in / out are multiple of 
oc_block_factor
+        (1, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (1, 32, 147, 64, 3, 1, 1, 1, False, False),
+        (1, 64, 73, 80, 1, 1, 0, 1, False, False),
+        (1, 80, 73, 192, 3, 1, 0, 1, False, False),
+        (1, 192, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 192, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 48, 35, 64, 5, 1, 2, 1, False, False),
+        (1, 64, 35, 96, 3, 1, 1, 1, False, False),
+        (1, 96, 35, 96, 3, 1, 1, 1, False, False),
+        (1, 192, 35, 32, 1, 1, 0, 1, False, False),
+        (1, 256, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 256, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 64, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 48, 1, 1, 0, 1, False, False),
+        (1, 288, 35, 384, 3, 2, 0, 1, False, False),
+        (1, 96, 35, 96, 3, 2, 0, 1, False, False),
+        (1, 768, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 768, 17, 128, 1, 1, 0, 1, False, False),
+        (1, 128, 17, 128, 1, 1, 0, 1, False, False),
+        (1, 128, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 128, 17, 128, 7, 1, 3, 1, False, False),
+        (1, 128, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 768, 17, 160, 1, 1, 0, 1, False, False),
+        (1, 160, 17, 160, 1, 1, 0, 1, False, False),
+        (1, 160, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 160, 17, 160, 7, 1, 3, 1, False, False),
+        (1, 160, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 192, 17, 192, 1, 1, 0, 1, False, False),
+        (1, 192, 17, 192, 7, 1, 3, 1, False, False),
+        (1, 192, 17, 320, 3, 2, 0, 1, False, False),
+        (1, 192, 17, 192, 3, 2, 0, 1, False, False),
+        (1, 1280, 8, 320, 1, 1, 0, 1, False, False),
+        (1, 1280, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 384, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 384, 8, 384, 3, 1, 1, 1, False, False),
+        (1, 1280, 8, 448, 1, 1, 0, 1, False, False),
+        (1, 448, 8, 384, 3, 1, 1, 1, False, False),
+        (1, 1280, 8, 192, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 320, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 384, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 448, 1, 1, 0, 1, False, False),
+        (1, 2048, 8, 192, 1, 1, 0, 1, False, False),
+        (1, 1024, 19, 88, 3, 1, 1, 1, False, False),
+        # batch > 1
+        (7, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (8, 32, 149, 32, 3, 1, 0, 1, False, False),
+        (32, 32, 149, 32, 3, 1, 0, 1, False, False),
+        # Asymmetric padding
+        (1, 32, 35, 64, 7, 2, (0, 0, 1, 1), 1, False, False),
+        (1, 64, 8, 128, 3, 1, (3, 3, 2, 2), 1, False, False),
+        (1, 64, 8, 64, 1, 1, (1, 2, 2, 1), 1, False, False),
+        (1, 64, 17, 192, 1, 1, (1, 2), 1, False, False),
+        (1, 64, 8, 64, 3, 1, (3, 1), 1, False, False),
+        (1, 128, 8, 384, 3, 1, (0, 2), 1, False, False),
+        (1, 64, 8, 64, 1, 1, "VALID", 1, False, False),
+        (1, 392, 8, 64, 3, 1, "VALID", 1, False, False),
+        (1, 512, 19, 64, 1, 1, "SAME", 1, False, False),
+        (1, 64, 16, 32, 2, 1, "SAME", 1, False, False),
+        (1, 64, 8, 64, 3, 1, (1, 2, 2, 1), 1, False, True),
+        (1, 64, 8, 64, 5, 2, (1, 3), 1, True, False),
+        (1, 64, 56, 64, 3, 1, "VALID", 1, True, True),
+        (1, 64, 56, 64, 24, 1, "SAME", 1, True, True),
+    ],
+)
+def test_conv2d_NCHWc_int8(in_dtype, params):
+    with Int8Fallback():
         (
-            "llvm -device arm_cpu -mtriple aarch64-linux-gnu 
-mattr=+neon,+v8.2a,+dotprod",
-            topi.arm_cpu.conv2d_NCHWc_int8,
-            topi.arm_cpu.schedule_conv2d_NCHWc_int8,
-            8,
-            build_only_aarch64,
+            batch,
+            in_channel,
+            in_size,
+            num_filter,
+            kernel,
+            stride,
+            padding,
+            dilation,
+            add_bias,
+            add_relu,
+        ) = params
+        pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, 
(kernel, kernel))
+        padding_sum = pad_top + pad_left + pad_bottom + pad_right
+        print(
+            "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+            % (batch, in_channel, in_size, num_filter, kernel, stride, 
padding_sum, dilation)
         )
-    )
-
-    if in_dtype == "int8":
-        targets += [
-            (
-                "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon",
-                topi.arm_cpu.conv2d_NCHWc_int8,
-                topi.arm_cpu.schedule_conv2d_NCHWc_int8,
-                8,
-                build_only_aarch64,
-            ),
-            (
-                "rocm -mattr=+dotprod",
-                lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, 
w, s, p, d, l, o),
-                topi.cuda.schedule_conv2d_NCHWc_int8,
-                4,
-                False,
-            ),
-        ]
-
-    for target, compute, schedule, oc_block_factor, build_only in targets:
-        check_target(target, compute, schedule, oc_block_factor, build_only)
-
-
-def verify_conv2d_nchw_int8(
-    in_dtype,
-    batch,
-    in_channel,
-    in_size,
-    num_filter,
-    kernel,
-    stride,
-    padding,
-    dilation=1,
-    add_bias=False,
-    add_relu=False,
-):
-    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, 
kernel))
-    padding_sum = pad_top + pad_left + pad_bottom + pad_right
-    print(
-        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
-        % (batch, in_channel, in_size, num_filter, kernel, stride, 
padding_sum, dilation)
-    )
-
-    in_height = in_width = in_size
-
-    A = te.placeholder((batch, in_channel, in_height, in_width), name="A", 
dtype=in_dtype)
-    W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", 
dtype=in_dtype)
-    bias = te.placeholder((num_filter, 1, 1), name="bias", dtype=in_dtype)
-
-    a_shape = get_const_tuple(A.shape)
-    w_shape = get_const_tuple(W.shape)
-    bias_shape = get_const_tuple(bias.shape)
-    dtype = A.dtype
-
-    @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
-    def get_ref_data():
-        a_np = np.random.randint(low=-128, high=127, 
size=a_shape).astype(dtype)
-        w_np = np.random.randint(low=-128, high=128, 
size=w_shape).astype(dtype)
-        b_np = np.random.uniform(size=bias_shape).astype(dtype)
-        dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, 
dilation))
-        c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, 
padding).astype(dtype)
-
-        if add_bias:
-            b_np = np.random.uniform(size=bias_shape).astype(dtype)
-            c_np += b_np
-        if add_relu:
-            c_np = np.maximum(c_np, 0)
-
-        return a_np, w_np, b_np, c_np
-
-    a_np, w_np, b_np, c_np = get_ref_data()
-
-    def verify_workload_padding():
-        _, _, out_height, out_width = get_const_tuple(c_np.shape)
-        wkl = _get_workload(A, W, (stride, stride), padding, dilation, dtype)
-
-        # for testing functionality,
-        # we choose arbitrary int32_lanes and num_int8_elements can divide the 
channel,
-        # regardless of the performance.
-        int32_lanes, num_int8_elements = num_filter, in_channel
 
-        # check if tile_ow candidates are the factors of the right output 
weight.
-        cfg = autotvm.get_config()
-        fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, 
num_int8_elements)
-        ow_tile = np.prod(cfg["tile_ow"].size)
-
-        tvm.testing.assert_allclose(ow_tile, out_width)
+        in_height = in_width = in_size
+
+        A = te.placeholder((batch, in_channel, in_height, in_width), name="A", 
dtype=in_dtype)
+        W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", 
dtype=in_dtype)
+
+        a_shape = get_const_tuple(A.shape)
+        w_shape = get_const_tuple(W.shape)
+        dtype = A.dtype
+        out_dtype = "int32" if in_dtype == "int8" else "uint32"
+        lo = -128 if in_dtype == "int8" else 0
+        hi = 127 if in_dtype == "int8" else 255
+
+        def check_target(target, compute, schedule, oc_block_factor, 
build_only):
+            dev = tvm.device(target, 0)
+            if not tvm.testing.device_enabled(target):
+                print("Skip because %s is not enabled" % target)
+                return

Review Comment:
   Done



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to