ekalda commented on code in PR #13669:
URL: https://github.com/apache/tvm/pull/13669#discussion_r1058903893
##########
tests/python/topi/python/test_topi_conv2d_int8.py:
##########
@@ -35,261 +35,138 @@
import platform
-def compile_conv2d_NHWC_gemm_int8_arm(
- batch,
- in_channel,
- in_size,
- num_filter,
- kernel,
- stride,
- padding,
- dilation=1,
- add_bias=False,
- add_relu=False,
-):
- pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel,
kernel))
- padding_sum = pad_top + pad_left + pad_bottom + pad_right
- print(
- "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
- % (batch, in_channel, in_size, num_filter, kernel, stride,
padding_sum, dilation)
- )
-
- in_height = in_width = in_size
- A = te.placeholder((batch, in_height, in_width, in_channel), name="A",
dtype="int8")
- W = te.placeholder((kernel, kernel, in_channel, num_filter), name="W",
dtype="int8")
- bias = te.placeholder((num_filter,), name="bias", dtype="int8")
- dtype = "int32"
- devices = [
- (
- "llvm --device arm_cpu --mtriple aarch64-linux-gnu",
- topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved,
- topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved,
- ),
- (
- "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+dotprod",
- topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved,
- topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved,
- ),
- (
- "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+dotprod",
- topi.arm_cpu.compute_conv2d_NHWC_quantized_native,
- topi.arm_cpu.schedule_conv2d_NHWC_quantized_native,
- ),
- # TODO(giuseros) Need LLVM-11 in order to compile with +i8mm extension
- # (
- # "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+i8mm",
- # topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved,
- # topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved,
- # ),
- ]
-
- for device_tuple in devices:
- target = device_tuple[0]
- compute = device_tuple[1]
- schedule = device_tuple[2]
-
- dev = tvm.device(target, 0)
- if not tvm.testing.device_enabled(target):
- print("Skip because %s is not enabled" % target)
- return
- print("Compiling on arm AArch64 target: %s" % target)
- with tvm.target.Target(target) as tvm_target:
- assert tvm_target.features.is_aarch64, "AArch64 target not
recognized"
+devices = [
+ (
+ "llvm",
+ topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved,
+ topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved,
+ ),
+ (
+ "llvm --device arm_cpu --mtriple aarch64-linux-gnu",
+ topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved,
+ topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved,
+ ),
+ (
+ "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+dotprod",
+ topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved,
+ topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved,
+ ),
+ (
+ "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+dotprod",
+ topi.arm_cpu.compute_conv2d_NHWC_quantized_native,
+ topi.arm_cpu.schedule_conv2d_NHWC_quantized_native,
+ ),
+ # TODO(giuseros) We need LLVM-11 in order to compile with +i8mm extension
+ # (
+ # "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+i8mm",
+ # topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved,
+ # topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved,
+ # ),
+]
+
+
+@tvm.testing.requires_llvm
+@pytest.mark.parametrize("device", devices)
+@pytest.mark.parametrize(
+ "params",
+ [
+ # Subset of inception v3 expanded (dilation > 1, batch > 1, 'VALID' padding)
+ (1, 3, 299, 32, 3, 2, "SAME", 1, False, False),
+ (1, 32, 149, 32, 3, 1, "SAME", 2, False, False),
+ (4, 32, 147, 64, 3, 1, "SAME", 1, False, False),
+ (1, 64, 73, 80, 1, 1, "SAME", 1, False, False),
+ (1, 80, 73, 192, 3, 1, "SAME", 1, False, False),
+ (1, 192, 35, 48, 1, 1, "SAME", 1, False, False),
+ (1, 192, 35, 64, 1, 1, "VALID", 1, False, False),
+ (1, 192, 35, 32, 1, 1, "SAME", 1, False, False),
+ (1, 48, 35, 64, 5, 1, "SAME", 1, False, False),
+ (1, 96, 35, 96, 3, 1, "SAME", 1, False, False),
+ (1, 256, 35, 48, 1, 1, "SAME", 1, False, False),
+ (1, 256, 35, 64, 1, 1, "SAME", 1, False, False),
+ (1, 288, 35, 64, 1, 1, "SAME", 1, False, False),
+ (1, 288, 35, 48, 1, 1, "SAME", 1, False, False),
+ (1, 96, 35, 96, 3, 2, "SAME", 1, False, False),
+ (1, 128, 17, 192, 7, 1, "SAME", 2, False, False),
+ (1, 160, 17, 160, 7, 1, "SAME", 1, False, False),
+ (1, 160, 17, 192, 1, 1, "VALID", 1, False, False),
+ (1, 192, 17, 192, 1, 1, "SAME", 1, False, False),
+ (1, 768, 5, 128, 1, 1, "SAME", 1, False, False),
+ (1, 192, 17, 320, 3, 2, "SAME", 1, False, False),
+ (1, 192, 17, 192, 3, 2, "SAME", 1, False, False),
+ (1, 1280, 8, 192, 1, 1, "SAME", 1, False, False),
+ (1, 1280, 8, 384, 1, 1, "SAME", 1, False, False),
+ (1, 1280, 8, 320, 1, 1, "SAME", 1, False, False),
+ (1, 1280, 8, 448, 1, 1, "SAME", 1, False, False),
+ (1, 384, 8, 384, 1, 1, "SAME", 1, False, False),
+ (1, 384, 8, 384, 3, 1, "SAME", 1, False, False),
+ (1, 448, 8, 384, 3, 1, "VALID", 1, False, False),
+ (1, 2048, 8, 320, 1, 1, "SAME", 1, False, False),
+ (1, 2048, 8, 448, 1, 1, "SAME", 1, True, True),
+ (1, 2048, 8, 192, 1, 1, "SAME", 1, True, False),
+ # A trouble case for native schedule
+ (1, 8, 1, 24, 1, 1, "SAME", 1, False, False),
+ ],
+)
+def test_conv2d_NHWC_gemm_int8(params, device):
- C = compute(A, W, (stride, stride), padding, (dilation, dilation),
dtype)
- if add_bias:
- C = topi.add(C, bias)
- if add_relu:
- C = topi.nn.relu(C)
- s = schedule([C])
+ with Int8Fallback():
+ target, compute, schedule = device
- if add_bias:
- tvm.build(
- s,
- [A, W, bias, C],
- target,
- name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
- % (batch, in_channel, in_size, num_filter, kernel, stride,
padding_sum, dilation),
- )
- func = tvm.build(
- s,
- [A, W, bias, C],
- target,
- name="relu_%dnnn_%d_%d_%d_%d_%d_%d_%d"
- % (batch, in_channel, in_size, num_filter, kernel, stride,
padding_sum, dilation),
- )
- else:
- func = tvm.build(
- s,
- [A, W, C],
- target,
- name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
- % (batch, in_channel, in_size, num_filter, kernel, stride,
padding_sum, dilation),
- )
+ (
+ batch,
+ in_channel,
+ in_size,
+ num_filter,
+ kernel,
+ stride,
+ padding,
+ dilation,
+ add_bias,
+ add_relu,
+ ) = params
+
+ # TODO(ekalda): These combinations hang during compilation
+ failing_cases = [
+ (devices[1], (1, 128, 17, 192, 7, 1, "SAME", 2, False, False)),
+ (devices[1], (1, 160, 17, 160, 7, 1, "SAME", 1, False, False)),
+ (
+ devices[1],
+ (1, 448, 8, 384, 3, 1, "VALID", 1, False, False),
+ ), # this one passes but is just incredibly slow
+ ]
+ if (device, params) in failing_cases:
+ return
Review Comment:
I added `pytest.skip` for these cases so that they show up as skipped
tests in the pytest log.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]