This is an automated email from the ASF dual-hosted git repository.
lukhut pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new f719151486 [Bugfix][Strategy] Fix `arm_cpu` int8 conv2d strategy for
dotprod and i8mm targets (#15711)
f719151486 is described below
commit f719151486c452d0c5c44339c3578bf64711ef29
Author: Andrei Hutu <[email protected]>
AuthorDate: Mon Sep 11 09:28:02 2023 +0100
[Bugfix][Strategy] Fix `arm_cpu` int8 conv2d strategy for dotprod and i8mm
targets (#15711)
Whenever both dotprod and i8mm were available together on a target (e.g.
`"llvm --device=arm_cpu --mtriple=aarch64-linux-gnu
-mattr=+v8.2a,+dotprod,+i8mm"`), the native int8 conv2d implementation
corresponding to the `+dotprod` attribute would be selected, but the compute
definition of the conv2d operation would be constructed for the `+i8mm`
attribute and its related interleaved schedule instead. The reason for this was
that the conditional statements were ordered differently in two separate files:
- `arm_cpu.py`: when selecting the conv2d implementation, the program
first checked for `dotprod` support and, if present, chose the native schedule.
- `conv2d_gemm.py`: when constructing the compute definition, `i8mm`
support is checked first, then `dotprod`.
To fix this, I modified the int8 conv2d strategy so that it also prioritizes
`i8mm` over `dotprod` when both are available.
---
python/tvm/relay/op/strategy/arm_cpu.py | 49 ++++++++++++++++------
.../relay/strategy/test_select_implementation.py | 8 ++++
2 files changed, 45 insertions(+), 12 deletions(-)
diff --git a/python/tvm/relay/op/strategy/arm_cpu.py
b/python/tvm/relay/op/strategy/arm_cpu.py
index b64c541863..a23ccf8f69 100644
--- a/python/tvm/relay/op/strategy/arm_cpu.py
+++ b/python/tvm/relay/op/strategy/arm_cpu.py
@@ -213,19 +213,35 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type,
target):
is_aarch64 = target.features.is_aarch64
has_asimd = target.features.has_asimd
has_dot_prod = target.features.has_dotprod
+ has_matmul_i8 = target.features.has_matmul_i8
- if has_dot_prod and data.dtype in ["int8", "uint8"]:
- strategy.add_implementation(
-
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_native),
-
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_native),
- name="conv2d_NHWC_quantized_native.arm_cpu",
- )
- if is_aarch64 and has_asimd and data.dtype in ["int8",
"uint8"]:
- strategy.add_implementation(
-
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved),
-
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved),
- name="conv2d_NHWC_quantized_interleaved.arm_cpu",
- )
+ if data.dtype in ["int8", "uint8"]:
+ if has_matmul_i8:
+ strategy.add_implementation(
+ wrap_compute_conv2d(
+
topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved
+ ),
+ wrap_topi_schedule(
+
topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved
+ ),
+ name="conv2d_NHWC_quantized_interleaved.arm_cpu",
+ )
+ if has_dot_prod:
+ strategy.add_implementation(
+
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_native),
+
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_native),
+ name="conv2d_NHWC_quantized_native.arm_cpu",
+ )
+ if is_aarch64 and has_asimd:
+ strategy.add_implementation(
+ wrap_compute_conv2d(
+
topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved
+ ),
+ wrap_topi_schedule(
+
topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved
+ ),
+ name="conv2d_NHWC_quantized_interleaved.arm_cpu",
+ )
if (not is_aarch64) or (data.dtype not in ["int8", "uint8"]):
# TODO(@giuseros)
# This strategy errors out for quantized data types when
tuning.
@@ -471,10 +487,19 @@ def
conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_typ
is_aarch64 = target.features.is_aarch64
has_asimd = target.features.has_asimd
has_dot_prod = target.features.has_dotprod
+ has_matmul_i8 = target.features.has_matmul_i8
interleaved_compute =
topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved_without_transform
native_compute =
topi.arm_cpu.compute_conv2d_NHWC_quantized_native_without_transform
if layout == "NHWC" and data.dtype in ["int8", "uint8"]:
+ if has_matmul_i8:
+ strategy.add_implementation(
+ wrap_compute_conv2d_gemm(interleaved_compute),
+ wrap_topi_schedule(
+
topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved_without_transform
+ ),
+
name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
+ )
if has_dot_prod:
strategy.add_implementation(
wrap_compute_conv2d_gemm(native_compute),
diff --git a/tests/python/relay/strategy/test_select_implementation.py
b/tests/python/relay/strategy/test_select_implementation.py
index 906ef2d161..d7dd0abbc4 100644
--- a/tests/python/relay/strategy/test_select_implementation.py
+++ b/tests/python/relay/strategy/test_select_implementation.py
@@ -81,6 +81,14 @@ def test_concatenate(target, expected_implementation):
"llvm --device=arm_cpu --mtriple=aarch64-linux-gnu
-mattr=+v8.2a,+i8mm",
"conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
),
+ (
+ "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu
-mattr=+v8.2a,+dotprod,+i8mm",
+ "conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
+ ),
+ (
+ "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v9a",
+ "conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
+ ),
],
)
def test_int8_conv2d(target, expected_impl):