This is an automated email from the ASF dual-hosted git repository.

lukhut pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f719151486 [Bugfix][Strategy] Fix `arm_cpu` int8 conv2d strategy for 
dotprod and i8mm targets (#15711)
f719151486 is described below

commit f719151486c452d0c5c44339c3578bf64711ef29
Author: Andrei Hutu <[email protected]>
AuthorDate: Mon Sep 11 09:28:02 2023 +0100

    [Bugfix][Strategy] Fix `arm_cpu` int8 conv2d strategy for dotprod and i8mm 
targets (#15711)
    
    Whenever both dotprod and i8mm were available together on a target (e.g. 
`"llvm --device=arm_cpu --mtriple=aarch64-linux-gnu 
-mattr=+v8.2a,+dotprod,+i8mm"`), the native int8 conv2d implementation 
corresponding to the `+dotprod` attribute would be selected, but the compute 
definition of the conv2d operation would be constructed for the `+i8mm` 
attribute and its related interleaved schedule instead. The reason for this was 
a different order of conditional statements being used in two separate files:
     - `arm_cpu.py`: when selecting the conv2d implementation, the program 
first checked for `dotprod` support and, if present, chose the native schedule.
     - `conv2d_gemm.py`: when constructing the compute definition, `i8mm` 
support is checked first, then `dotprod`.
    To fix this, I modified the int8 conv2d strategies to also prioritize `i8mm` 
over `dotprod` when both are available, matching the order used in `conv2d_gemm.py`.
---
 python/tvm/relay/op/strategy/arm_cpu.py            | 49 ++++++++++++++++------
 .../relay/strategy/test_select_implementation.py   |  8 ++++
 2 files changed, 45 insertions(+), 12 deletions(-)

diff --git a/python/tvm/relay/op/strategy/arm_cpu.py 
b/python/tvm/relay/op/strategy/arm_cpu.py
index b64c541863..a23ccf8f69 100644
--- a/python/tvm/relay/op/strategy/arm_cpu.py
+++ b/python/tvm/relay/op/strategy/arm_cpu.py
@@ -213,19 +213,35 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, 
target):
                 is_aarch64 = target.features.is_aarch64
                 has_asimd = target.features.has_asimd
                 has_dot_prod = target.features.has_dotprod
+                has_matmul_i8 = target.features.has_matmul_i8
 
-                if has_dot_prod and data.dtype in ["int8", "uint8"]:
-                    strategy.add_implementation(
-                        
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_native),
-                        
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_native),
-                        name="conv2d_NHWC_quantized_native.arm_cpu",
-                    )
-                if is_aarch64 and has_asimd and data.dtype in ["int8", 
"uint8"]:
-                    strategy.add_implementation(
-                        
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved),
-                        
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved),
-                        name="conv2d_NHWC_quantized_interleaved.arm_cpu",
-                    )
+                if data.dtype in ["int8", "uint8"]:
+                    if has_matmul_i8:
+                        strategy.add_implementation(
+                            wrap_compute_conv2d(
+                                
topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved
+                            ),
+                            wrap_topi_schedule(
+                                
topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved
+                            ),
+                            name="conv2d_NHWC_quantized_interleaved.arm_cpu",
+                        )
+                    if has_dot_prod:
+                        strategy.add_implementation(
+                            
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_native),
+                            
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_native),
+                            name="conv2d_NHWC_quantized_native.arm_cpu",
+                        )
+                    if is_aarch64 and has_asimd:
+                        strategy.add_implementation(
+                            wrap_compute_conv2d(
+                                
topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved
+                            ),
+                            wrap_topi_schedule(
+                                
topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved
+                            ),
+                            name="conv2d_NHWC_quantized_interleaved.arm_cpu",
+                        )
                 if (not is_aarch64) or (data.dtype not in ["int8", "uint8"]):
                     # TODO(@giuseros)
                     # This strategy errors out for quantized data types when 
tuning.
@@ -471,10 +487,19 @@ def 
conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_typ
     is_aarch64 = target.features.is_aarch64
     has_asimd = target.features.has_asimd
     has_dot_prod = target.features.has_dotprod
+    has_matmul_i8 = target.features.has_matmul_i8
 
     interleaved_compute = 
topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved_without_transform
     native_compute = 
topi.arm_cpu.compute_conv2d_NHWC_quantized_native_without_transform
     if layout == "NHWC" and data.dtype in ["int8", "uint8"]:
+        if has_matmul_i8:
+            strategy.add_implementation(
+                wrap_compute_conv2d_gemm(interleaved_compute),
+                wrap_topi_schedule(
+                    
topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved_without_transform
+                ),
+                
name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
+            )
         if has_dot_prod:
             strategy.add_implementation(
                 wrap_compute_conv2d_gemm(native_compute),
diff --git a/tests/python/relay/strategy/test_select_implementation.py 
b/tests/python/relay/strategy/test_select_implementation.py
index 906ef2d161..d7dd0abbc4 100644
--- a/tests/python/relay/strategy/test_select_implementation.py
+++ b/tests/python/relay/strategy/test_select_implementation.py
@@ -81,6 +81,14 @@ def test_concatenate(target, expected_implementation):
             "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu 
-mattr=+v8.2a,+i8mm",
             "conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
         ),
+        (
+            "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu 
-mattr=+v8.2a,+dotprod,+i8mm",
+            "conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
+        ),
+        (
+            "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v9a",
+            "conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
+        ),
     ],
 )
 def test_int8_conv2d(target, expected_impl):

Reply via email to