This is an automated email from the ASF dual-hosted git repository.

ekalda pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4b8297480d [SME][TOPI] Add conv2d NHWC SME fp16->fp32 schedule (#17048)
4b8297480d is described below

commit 4b8297480d52d637d762544a8d68b01b7d01ff7a
Author: Andrei Hutu <[email protected]>
AuthorDate: Wed Jun 5 17:02:59 2024 +0100

    [SME][TOPI] Add conv2d NHWC SME fp16->fp32 schedule (#17048)
    
    This commit extends the SME conv2d NHWC schedule to support convolutions 
with float16 inputs (data and kernel) and a float32 output using the tensor 
intrinsics added in #16981.
---
 python/tvm/relay/op/strategy/arm_cpu.py            | 39 ++++++++---
 python/tvm/topi/arm_cpu/arm_utils.py               | 22 +++---
 python/tvm/topi/arm_cpu/conv2d.py                  | 81 +++++++++++++++++++---
 python/tvm/topi/arm_cpu/conv2d_alter_op.py         | 28 ++++++++
 python/tvm/topi/arm_cpu/conv2d_gemm.py             | 11 +++
 python/tvm/topi/nn/conv2d.py                       |  7 +-
 tests/python/relay/strategy/arm_cpu/test_conv2d.py | 39 +++++++----
 .../relay/strategy/test_select_implementation.py   | 60 +++++++++++++---
 tests/python/topi/test_topi_conv2d_nhwc.py         | 15 ++--
 9 files changed, 244 insertions(+), 58 deletions(-)

diff --git a/python/tvm/relay/op/strategy/arm_cpu.py 
b/python/tvm/relay/op/strategy/arm_cpu.py
index 12f19462f7..35fd2b7a78 100644
--- a/python/tvm/relay/op/strategy/arm_cpu.py
+++ b/python/tvm/relay/op/strategy/arm_cpu.py
@@ -23,6 +23,7 @@ import re
 
 from tvm import relay, topi, tir
 from tvm.tir.schedule.analysis import has_block
+from tvm.dlight.gpu.matmul import auto_inline_consumers
 
 from ....auto_scheduler import is_auto_scheduler_enabled
 from ....meta_schedule import is_meta_schedule_enabled
@@ -255,9 +256,8 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, 
target):
                 if is_aarch64 and data.dtype in ["float32", "float16"]:
                     if (
                         target.features.has_sme
-                        and data.dtype in ["float32"]
-                        and kernel.dtype in ["float32"]
-                        and out_type.dtype in ["float32"]
+                        and kernel.dtype == data.dtype
+                        and out_type.dtype == "float32"
                     ):
                         strategy.add_implementation(
                             
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_hybrid_SME),
@@ -536,6 +536,7 @@ def 
conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_typ
     """conv2d_gemm_without_weight_transform arm cpu strategy"""
     layout = attrs.data_layout
     data = inputs[0]
+    kernel = inputs[1]
     strategy = _op.OpStrategy()
     is_aarch64 = target.features.is_aarch64
     has_dot_prod = target.features.has_dotprod
@@ -581,13 +582,31 @@ def 
conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_typ
                     wrap_topi_schedule(interleaved_schedule),
                     
name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
                 )
+        # Non-quantized cases
         elif data.dtype in ["float32", "float16"]:
-            # Non-quantized cases
-            strategy.add_implementation(
-                
wrap_compute_conv2d_gemm(topi.arm_cpu.compute_conv2d_NHWC_hybrid_without_transform),
-                
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_hybrid_without_transform),
-                name="conv2d_NHWC_hybrid_without_transform.arm_cpu",
-            )
+            # The SME schedule for float16->float32 prearranges the two 
matrices to be multiplied
+            # using the ARM_SME_BLOCK2_2SVLx1SVL_FP16_TRANSPOSE_INTERLEAVE 
intrinsic which expects
+            # the reduction axis K as the second dimension of the matrix (i.e. 
shape = (_, K)).
+            # This means that the flattened weights matrix B needs to be 
transposed to (N, K).
+            if (
+                target.features.has_sme
+                and kernel.dtype == "float16"
+                and data.dtype == "float16"
+                and out_type.dtype == "float32"
+            ):
+                strategy.add_implementation(
+                    
wrap_compute_conv2d_gemm(topi.arm_cpu.compute_conv2d_NHWC_SME_transposed_B),
+                    lambda: None,
+                    name="conv2d_NHWC_hybrid_SME_transposed_B.arm_cpu",
+                )
+            else:
+                strategy.add_implementation(
+                    wrap_compute_conv2d_gemm(
+                        
topi.arm_cpu.compute_conv2d_NHWC_hybrid_without_transform
+                    ),
+                    
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_hybrid_without_transform),
+                    name="conv2d_NHWC_hybrid_without_transform.arm_cpu",
+                )
     else:
         raise RuntimeError(
             f"Unsupported conv2d_NHWC_without_transform layout {layout}"
@@ -819,6 +838,8 @@ def arm_cpu_tir_strategy(sch: tir.Schedule) -> bool:
         topi.arm_cpu.matmul.tir_schedule_matmul_sme(sch)
         return True
     elif has_block(sch, "conv2d_gemm_output"):
+        conv2d_block = sch.get_block("conv2d_gemm_output")
+        auto_inline_consumers(sch, conv2d_block)
         topi.arm_cpu.schedule_conv2d_NHWC_hybrid_TIR(sch)
         return True
 
diff --git a/python/tvm/topi/arm_cpu/arm_utils.py 
b/python/tvm/topi/arm_cpu/arm_utils.py
index 5c4b3c0456..f690b22731 100644
--- a/python/tvm/topi/arm_cpu/arm_utils.py
+++ b/python/tvm/topi/arm_cpu/arm_utils.py
@@ -68,8 +68,11 @@ def get_tiling_A(interleave_A, in_dtype, use_sme=False):
             tile_M = 4
             tile_K = 16
     elif use_sme:
-        tile_M = 2 * 4 * tvm.tir.vscale()
-        tile_K = 2 * 4 * tvm.tir.vscale()
+        tile_M = 2 * tvm.tir.get_vscale_expr(in_dtype)
+        if in_dtype == "float16":
+            tile_K = tvm.tir.get_vscale_expr(in_dtype)
+        else:
+            tile_K = 2 * tvm.tir.get_vscale_expr(in_dtype)
     else:
         # In non-SME, non-quantized cases, A is not interleaved.
         # We are loading 4 rows from A.
@@ -139,17 +142,16 @@ def get_tiling_B_transformed(interleave_A, in_dtype, 
use_scalable_vectors=False,
             tile_N = 4
             tile_K = 16
     elif use_sme:
-        tile_N = 2 * 4 * tvm.tir.vscale()
-        tile_K = 2 * 4 * tvm.tir.vscale()
-    # In non-SME, non-quantized cases, A is not interleaved.
-    elif use_scalable_vectors:
+        tile_N = 2 * tvm.tir.get_vscale_expr(in_dtype)
         if in_dtype == "float16":
-            # Each load from B' contains 32 * vscale elements (i.e. 32 * 
vscale columns from B)
-            tile_N = 32 * tvm.tir.vscale()
+            tile_K = tvm.tir.get_vscale_expr(in_dtype)
         else:
-            # Each load from B' contains 16 * vscale elements (i.e. 16 * 
vscale columns from B)
-            tile_N = 16 * tvm.tir.vscale()
+            tile_K = 2 * tvm.tir.get_vscale_expr(in_dtype)
+    # In non-SME, non-quantized cases, A is not interleaved.
+    elif use_scalable_vectors:
+        # Each load from B' contains 4 * scalable vectors (i.e. 4 * SVL 
columns from B)
         # We are loading 4 rows from B', in the dimension of reduction (i.e. 4 
rows from B)
+        tile_N = 4 * tvm.tir.get_vscale_expr(in_dtype)
         tile_K = 4
     elif in_dtype == "float16" and target.features.has_fp16_simd:
         # Each load from B' contains 32 elements (i.e. 32 columns from B)
diff --git a/python/tvm/topi/arm_cpu/conv2d.py 
b/python/tvm/topi/arm_cpu/conv2d.py
index d0fe251e7e..a6c951c078 100644
--- a/python/tvm/topi/arm_cpu/conv2d.py
+++ b/python/tvm/topi/arm_cpu/conv2d.py
@@ -24,6 +24,7 @@ from tvm import autotvm
 from tvm.script import tir as T
 import tvm.contrib.nnpack
 from tvm.tir.schedule.analysis import has_block
+from tvm.topi.arm_cpu.matmul import _get_transpose_interleave_intrin_name
 
 from ..utils import traverse_inline, get_const_tuple
 from .. import nn
@@ -680,6 +681,43 @@ def compute_conv2d_NHWC_hybrid_SME(cfg, data, kernel, 
strides, padding, dilation
     )
 
 
[email protected]_topi_compute("conv2d_NHWC_hybrid_SME_transposed_B.arm_cpu")
+def compute_conv2d_NHWC_SME_transposed_B(
+    cfg,
+    data,
+    kernel,
+    strides,
+    padding,
+    dilation,
+    out_dtype,
+    kernel_size,
+    output_channels,
+):
+    """Compute conv2d NHWC hybrid SME transposed B"""
+    N, K = get_const_tuple(kernel.shape)
+    tile_N, tile_K = get_tiling_B_transformed(False, data.dtype, True, True)
+    pad_N, pad_K = tvm.topi.arm_cpu.arm_utils.get_conv2d_weights_padding(N, K, 
tile_N, tile_K)
+
+    kernel = tvm.topi.nn.pad(
+        kernel, pad_before=(0, 0), pad_after=(pad_N, pad_K), 
name="weight_padding"
+    )
+
+    return compute_conv2d_gemm_without_weight_transform(
+        cfg,
+        data,
+        kernel,
+        strides,
+        padding,
+        dilation,
+        out_dtype,
+        kernel_size,
+        output_channels,
+        interleave_A=False,
+        use_scalable_vectors=True,
+        use_sme=True,
+    )
+
+
 def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
     """
     Perform TIR scheduling for conv2d NHWC.
@@ -688,7 +726,8 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
     primfunc = sch.mod["main"]
     buffer_names = primfunc.params
     buffer_list = [primfunc.buffer_map[buf] for buf in buffer_names]
-    dtype = buffer_list[0].dtype
+    in_dtype = buffer_list[0].dtype
+    out_dtype = "float32"
 
     # Determine PrimFunc blocks
     block_list = [
@@ -698,6 +737,8 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
         "A_padded_K",
         "A_padded_M",
         "weight_flatten",
+        "weight_padding",
+        "weight_transpose",
         "C",
         "conv2d_gemm_output",
     ]
@@ -716,8 +757,8 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
     M_padded = sch.get(m).extent
     N_padded = sch.get(n).extent
     K_padded = sch.get(k).extent
-    tile_M, tile_K = get_tiling_A(False, dtype, use_sme)
-    tile_N, _ = get_tiling_B_transformed(False, dtype, use_scalable_vectors, 
use_sme)
+    tile_M, tile_K = get_tiling_A(False, in_dtype, use_sme)
+    tile_N, _ = get_tiling_B_transformed(False, in_dtype, 
use_scalable_vectors, use_sme)
     tile_M = T.cast(tile_M, M_padded.dtype)
     tile_N = T.cast(tile_N, N_padded.dtype)
     tile_K = T.cast(tile_K, K_padded.dtype)
@@ -729,12 +770,15 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: 
tvm.tir.Schedule):
         # pylint: disable=import-outside-toplevel
         from tvm.topi.arm_cpu.pstate_attributes import SMEAttributes
         from tvm.tir.tensor_intrin.arm_cpu import (
-            ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE,
             ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA,
             ARM_SME_INIT,
             get_sme_gemm_interleaved_mopa_2svlx2svl_intrin,
         )
 
+        transpose_interleave_intrin_name = 
_get_transpose_interleave_intrin_name(
+            in_dtype, out_dtype
+        )
+
         # Interleave the padded im2col matrix utilizing the matrix tile
         interleave_t_A_block = sch.cache_read(gemm_block, 0, "global")
         sch.transform_layout(interleave_t_A_block, ("write", 0), lambda b, m, 
k: (b, k, m))
@@ -743,24 +787,40 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: 
tvm.tir.Schedule):
         ko, ki = sch.split(k, factors=(None, tile_K), disable_predication=True)
         sch.parallel(b)
         sch.reorder(b, ko, mo, ki, mi)
-        sch.tensorize(ki, ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE)
+        sch.tensorize(ki, transpose_interleave_intrin_name)
+
+        # Interleave the padded weights matrix utilizing the matrix tile
+        if in_dtype == "float16":
+            interleave_b_block = sch.cache_read(gemm_block, 1, "global")
+            sch.transform_layout(interleave_b_block, ("write", 0), lambda n, 
k: (k, n))
+            n, k = sch.get_loops(interleave_b_block)
+            ko, ki = sch.split(k, factors=(None, tile_K), 
disable_predication=True)
+            no, ni = sch.split(n, factors=(None, tile_N), 
disable_predication=True)
+            sch.reorder(ko, no, ki, ni)
+            sch.tensorize(ki, transpose_interleave_intrin_name)
 
         # Split and reorder the loops of the GeMM for tensorization
         b, m, n, k = sch.get_loops(gemm_block)
+        tile_M, _ = get_tiling_A(False, out_dtype, True)
+        tile_N, _ = get_tiling_B_transformed(False, out_dtype, True, True)
+        tile_M = T.cast(tile_M, M_padded.dtype)
+        tile_N = T.cast(tile_N, N_padded.dtype)
         mo, mi = sch.split(m, factors=(None, tile_M), disable_predication=True)
         no, ni = sch.split(n, factors=(None, tile_N), disable_predication=True)
         sch.parallel(b)
         sch.reorder(b, mo, no, mi, ni, k)
 
-        # Tensorize the GeMM output matrix initialization to zero
+        # Tensorize the GeMM initialization
         init_block = sch.decompose_reduction(gemm_block, mi)
         sch.tensorize(sch.get_loops(init_block)[-2], ARM_SME_INIT)
 
         # Tensorize the GeMM update
-        sme_gemm_interleaved_intrin_name = 
ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{K_padded}"
+        sme_gemm_interleaved_intrin_name = (
+            ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{K_padded}_{in_dtype}"
+        )
         tvm.tir.TensorIntrin.register(
             sme_gemm_interleaved_intrin_name,
-            *get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K_padded, dtype),
+            *get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K_padded, 
in_dtype),
             override=True,
         )
         sch.tensorize(mi, sme_gemm_interleaved_intrin_name)
@@ -878,6 +938,11 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
         weight_flatten_block = func_blocks["weight_flatten"]
         sch.compute_inline(weight_flatten_block)
 
+    # Weight transpose
+    if func_blocks["weight_transpose"] and func_blocks["weight_padding"]:
+        weight_padding_block = func_blocks["weight_padding"]
+        sch.compute_inline(weight_padding_block)
+
     # Conv2d output block
     output_block = func_blocks["conv2d_gemm_output"]
     n, h, w, c = sch.get_loops(output_block)
diff --git a/python/tvm/topi/arm_cpu/conv2d_alter_op.py 
b/python/tvm/topi/arm_cpu/conv2d_alter_op.py
index fe4569ceb1..2476cb92b9 100644
--- a/python/tvm/topi/arm_cpu/conv2d_alter_op.py
+++ b/python/tvm/topi/arm_cpu/conv2d_alter_op.py
@@ -162,6 +162,34 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
             inputs[0], new_kernel_expr, **new_attrs
         )
 
+    if (
+        topi_tmpl == "conv2d_NHWC_hybrid_SME.arm_cpu"
+        and data_dtype == "float16"
+        and kernel_dtype == "float16"
+        and out_dtype == "float32"
+    ):
+        assert data_layout == "NHWC" and kernel_layout == "HWIO"
+        KH, KW, IC, OC = get_const_tuple(kernel.shape)
+        K = KH * KW * IC
+        N = OC
+        # The SME schedule for float16->float32 prearranges the two matrices 
to be multiplied
+        # using the ARM_SME_BLOCK2_2SVLx1SVL_FP16_TRANSPOSE_INTERLEAVE 
intrinsic which expects
+        # the reduction axis K as the second dimension of the matrix (i.e. 
shape = (_, K)).
+        # This means that the flattened weights matrix B needs to be 
transposed to (N, K).
+        transposed_kernel_expr = relay.transpose(inputs[1], axes=[3, 0, 1, 2])
+        transposed_flattened_kernel_expr = 
relay.reshape(transposed_kernel_expr, newshape=(N, K))
+        new_kernel_expr = transposed_flattened_kernel_expr
+        new_kernel = te.placeholder((N, K), kernel.dtype)
+        new_workload_name = "conv2d_NHWC_hybrid_SME_transposed_B.arm_cpu"
+        new_workload = autotvm.task.args_to_workload(
+            [data, new_kernel, strides, padding, dilation, out_dtype, (KH, 
KW), OC],
+            new_workload_name,
+        )
+        dispatch_ctx.update(target, new_workload, cfg)
+        return relay.nn.contrib_conv2d_gemm_without_weight_transform(
+            inputs[0], new_kernel_expr, **new_attrs
+        )
+
     # Only microTVM does layout alteration for NHWC layout with real data types
     if data_layout == "NHWC" and data_dtype not in ["uint8", "int8"]:
         return None
diff --git a/python/tvm/topi/arm_cpu/conv2d_gemm.py 
b/python/tvm/topi/arm_cpu/conv2d_gemm.py
index 0c3908bb70..e637aa91e5 100644
--- a/python/tvm/topi/arm_cpu/conv2d_gemm.py
+++ b/python/tvm/topi/arm_cpu/conv2d_gemm.py
@@ -289,6 +289,17 @@ def compute_conv2d_gemm_without_weight_transform(
                 tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1]
                 - tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1]
             )
+    elif use_sme and in_dtype == "float16" and out_dtype == "float32":
+        assert len(B_interleaved_t.shape) == 2
+        C = te.compute(
+            (batches, M_padded, N_padded),
+            lambda b, x, y: te.sum(
+                A[b, x, k].astype(out_dtype) * B_interleaved_t[y, 
k].astype(out_dtype),
+                axis=k,
+            ),
+            name="C",
+        )
+        zero = tvm.tir.const(0)
     elif use_scalable_vectors or use_sme:
         assert len(B_interleaved_t.shape) == 2
         C = te.compute(
diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py
index 8d61c62250..205730ff22 100644
--- a/python/tvm/topi/nn/conv2d.py
+++ b/python/tvm/topi/nn/conv2d.py
@@ -654,7 +654,12 @@ def conv2d_gemm_weight_transform(kernel, tile_N, tile_K, 
use_scalable_vectors=Fa
             kernel_flat, pad_before=(0, 0), pad_after=(pad_K, pad_N), 
name="weight_padding"
         )
 
-    if use_sme or use_scalable_vectors:
+    if use_sme and kernel.dtype == "float16":
+        return te.compute(
+            (N_padded, K_padded), lambda x, y: kernel_flat[y, x], 
name="weight_transpose"
+        )
+
+    if use_scalable_vectors or use_sme:
         return kernel_flat
 
     if kernel.dtype in ["int8", "uint8"]:
diff --git a/tests/python/relay/strategy/arm_cpu/test_conv2d.py 
b/tests/python/relay/strategy/arm_cpu/test_conv2d.py
index 2708094afb..f4fa250ecf 100644
--- a/tests/python/relay/strategy/arm_cpu/test_conv2d.py
+++ b/tests/python/relay/strategy/arm_cpu/test_conv2d.py
@@ -120,7 +120,8 @@ class TestConv2d_NCHW_Spatial_Pack(Conv2dTests):
     schedule_name = parameter("conv2d_nchw_spatial_pack.arm_cpu")
 
 
-dtype = tvm.testing.parameter("float32")
+in_dtype = tvm.testing.parameter("float16", "float32")
+out_dtype = tvm.testing.parameter("float32")
 
 batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation = 
tvm.testing.parameters(
     # Pad M, N, K
@@ -154,30 +155,35 @@ batch, in_channel, in_size, num_filter, kernel, stride, 
padding, dilation = tvm.
 
 
 @tvm.testing.fixture()
-def ref_data(dtype, batch, in_channel, in_size, num_filter, kernel, stride, 
padding, dilation):
+def ref_data(
+    in_dtype, out_dtype, batch, in_channel, in_size, num_filter, kernel, 
stride, padding, dilation
+):
     np.random.seed(0)
     in_height = in_width = in_size
     a_shape = (batch, in_height, in_width, in_channel)
     w_shape = (kernel, kernel, in_channel, num_filter)
 
-    a_np = np.random.uniform(size=a_shape).astype(dtype)
-    w_np = np.random.uniform(size=w_shape).astype(dtype)
-    return a_np, w_np
+    a_np = np.random.uniform(size=a_shape).astype(in_dtype)
+    w_np = np.random.uniform(size=w_shape).astype(in_dtype)
+    dw_np = tvm.topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1))
+    b_np = tvm.topi.testing.conv2d_nhwc_python(
+        a_np.astype(out_dtype), dw_np.astype(out_dtype), stride, padding
+    ).astype(out_dtype)
+    return a_np, w_np, dw_np, b_np
 
 
 @pytest.mark.skipif(
     llvm_version_major() < 16, reason="SME is not supported in earlier 
versions of LLVM"
 )
 @tvm.testing.requires_aprofile_aem_fvp
-def test_conv2d_fp32(target, ref_data, dtype, stride, padding, dilation):
-    a_np, w_np = ref_data
-    dw_np = tvm.topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1))
+def test_conv2d_sme(target, ref_data, in_dtype, out_dtype, stride, padding, 
dilation):
+    a_np, w_np, dw_np, b_np = ref_data
 
     kernel_size = get_const_tuple(w_np.shape[:2])
     out_channels = w_np.shape[3]
 
-    x = relay.var("data", shape=a_np.shape, dtype=dtype)
-    weight = relay.const(w_np, dtype=dtype)
+    x = relay.var("data", shape=a_np.shape, dtype=in_dtype)
+    weight = relay.const(w_np, dtype=in_dtype)
     conv2d = relay.nn.conv2d(
         x,
         weight,
@@ -188,7 +194,7 @@ def test_conv2d_fp32(target, ref_data, dtype, stride, 
padding, dilation):
         padding=get_pad_tuple(padding, dw_np.shape[:2]),
         data_layout="NHWC",
         kernel_layout="HWIO",
-        out_dtype=dtype,
+        out_dtype=out_dtype,
     )
 
     func = relay.Function(relay.analysis.free_vars(conv2d), conv2d)
@@ -198,7 +204,7 @@ def test_conv2d_fp32(target, ref_data, dtype, stride, 
padding, dilation):
 
     inputs = {"data": a_np}
     params = {}
-    ref_outputs = generate_ref_data(ir_mod, inputs, params)
+    ref_outputs = {"output": b_np}
 
     target = tvm.target.Target("llvm -mtriple=aarch64-none-elf 
-mattr=+v9.2a,+sme")
     runtime = tvm.relay.backend.Runtime("crt", {"system-lib": True})
@@ -220,9 +226,12 @@ def test_conv2d_fp32(target, ref_data, dtype, stride, 
padding, dilation):
             runtime=runtime,
             params=params,
         )
-    generated_func = executor_factory.lowered_ir_mods.items()[0][1][
-        "tvmgen_default_fused_nn_conv2d"
-    ]
+
+    if in_dtype == "float16":
+        func_name = 
"tvmgen_default_fused_nn_contrib_conv2d_gemm_without_weight_transform"
+    else:
+        func_name = "tvmgen_default_fused_nn_conv2d"
+    generated_func = executor_factory.lowered_ir_mods.items()[0][1][func_name]
     extra_memory_in_bytes = 
calculate_extra_workspace_size_from_scalable_extents(generated_func, 4)
 
     test_model = AOTTestModel(
diff --git a/tests/python/relay/strategy/test_select_implementation.py 
b/tests/python/relay/strategy/test_select_implementation.py
index 01a914e793..b95bd4072a 100644
--- a/tests/python/relay/strategy/test_select_implementation.py
+++ b/tests/python/relay/strategy/test_select_implementation.py
@@ -58,7 +58,7 @@ def test_concatenate(target, expected_implementation):
     assert impl.name == expected_implementation
 
 
-def _get_conv2d_impl(dtype, target):
+def _get_conv2d_impl(in_dtype, out_dtype, target):
     """Returns selected conv2d implementation for a given datatype and 
target"""
     data_shape = (1, 1, 1, 4)
     weight_shape = (1, 1, 4, 4)
@@ -68,21 +68,24 @@ def _get_conv2d_impl(dtype, target):
     kernel_size = (1, 1)
 
     out = relay.nn.conv2d(
-        relay.var("data", shape=data_shape, dtype=dtype),
-        relay.var("weight", shape=weight_shape, dtype=dtype),
+        relay.var("data", shape=data_shape, dtype=in_dtype),
+        relay.var("weight", shape=weight_shape, dtype=in_dtype),
         kernel_size=kernel_size,
         channels=channels,
         data_layout=data_layout,
         kernel_layout=kernel_layout,
-        out_dtype=dtype,
+        out_dtype=out_dtype,
     )
 
     with target:
         out = run_opt_pass(out, relay.transform.AlterOpLayout())
+        data_shape = out.type_args[0].shape
+        weight_shape = out.type_args[1].shape
+
         impl, _ = relay.backend.te_compiler.select_implementation(
             out.op,
             out.attrs,
-            [te.placeholder(data_shape, dtype), te.placeholder(weight_shape, 
dtype)],
+            [te.placeholder(data_shape, in_dtype), 
te.placeholder(weight_shape, in_dtype)],
             out.checked_type,
             target,
             use_autotvm=False,
@@ -131,7 +134,7 @@ def test_int8_conv2d(target, expected_impl):
     target = tvm.target.Target(target)
     dtype = "int8"
 
-    selected_impl = _get_conv2d_impl(dtype, target)
+    selected_impl = _get_conv2d_impl(dtype, dtype, target)
     assert selected_impl == expected_impl
 
 
@@ -171,7 +174,7 @@ def test_fp32_conv2d(target, expected_impl):
     target = tvm.target.Target(target)
     dtype = "float32"
 
-    selected_impl = _get_conv2d_impl(dtype, target)
+    selected_impl = _get_conv2d_impl(dtype, dtype, target)
     assert selected_impl == expected_impl
 
 
@@ -211,7 +214,48 @@ def test_fp16_conv2d(target, expected_impl):
     target = tvm.target.Target(target)
     dtype = "float16"
 
-    selected_impl = _get_conv2d_impl(dtype, target)
+    selected_impl = _get_conv2d_impl(dtype, dtype, target)
+    assert selected_impl == expected_impl
+
+
[email protected](
+    llvm_version_major() < 15, reason=f"Requires LLVM 15+, got 
{llvm_version_major()}"
+)
[email protected](
+    "target,expected_impl",
+    [
+        (
+            "llvm -device=arm_cpu -mtriple=armv8l-linux-gnu -mattr=+neon",
+            "conv2d_nhwc_spatial_pack.arm_cpu",
+        ),
+        (
+            "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu",
+            "conv2d_NHWC_hybrid_without_transform.arm_cpu",
+        ),
+        (
+            "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon",
+            "conv2d_NHWC_hybrid_without_transform.arm_cpu",
+        ),
+        (
+            "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu 
-mattr=+v8.2a,+neon",
+            "conv2d_NHWC_hybrid_without_transform.arm_cpu",
+        ),
+        (
+            "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v9a",
+            "conv2d_NHWC_hybrid_without_transform.arm_cpu",
+        ),
+        (
+            "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu 
-mattr=+v9.2a,+sme",
+            "conv2d_NHWC_hybrid_SME_transposed_B.arm_cpu",
+        ),
+    ],
+)
+def test_fp16_to_fp32_conv2d(target, expected_impl):
+    target = tvm.target.Target(target)
+    in_dtype = "float16"
+    out_dtype = "float32"
+
+    selected_impl = _get_conv2d_impl(in_dtype, out_dtype, target)
     assert selected_impl == expected_impl
 
 
diff --git a/tests/python/topi/test_topi_conv2d_nhwc.py 
b/tests/python/topi/test_topi_conv2d_nhwc.py
index 02f16b59c0..d46db1b28b 100644
--- a/tests/python/topi/test_topi_conv2d_nhwc.py
+++ b/tests/python/topi/test_topi_conv2d_nhwc.py
@@ -68,7 +68,7 @@ device = tvm.testing.parameter(
         False,
     ),
     (
-        "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a",
+        "llvm --device arm_cpu --mtriple aarch64-linux-gnu 
-mattr=+v8.2a,+fullfp16",
         topi.arm_cpu.compute_conv2d_NHWC_hybrid,
         topi.arm_cpu.schedule_conv2d_NHWC_hybrid_TIR,
         True,
@@ -173,13 +173,14 @@ def test_conv2d_nhwc_gemm(device, ref_data, dtype, 
stride, padding, dilation):
     if target.features.has_sme and llvm_version_major() < 16:
         pytest.skip(f"LLVM {llvm_version_major()} does not support targeting 
SME.")
 
-    if target.features.has_sme and dtype == "float16":
-        pytest.skip(f"Conv2d fp16 targetting SME not implemented.")
+    # SME schedule always outputs float32 results, regardless of input dtype.
+    # Otherwise, output dtype is the same as input dtype.
+    out_dtype = "float32" if target.features.has_sme else dtype
 
     with target:
         a = tvm.nd.array(a_np, dev)
         w = tvm.nd.array(w_np, dev)
-        B = compute(A, W, stride, padding, dilation, dtype)
+        B = compute(A, W, stride, padding, dilation, out_dtype)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), 
dev)
         if use_tir_schedule:
             primfunc = te.create_prim_func([A, W, B])
@@ -190,22 +191,22 @@ def test_conv2d_nhwc_gemm(device, ref_data, dtype, 
stride, padding, dilation):
             func = tvm.build(s, [A, W, B], target)
 
         # Run only on AArch64 devices
-        # Do not run SVE schedules on non-SVE devices
+        # Do not run SVE/SME schedules on non-SVE/SME devices
         build_only = (
             platform.machine() != "aarch64"
-            or (target.features.has_sve and not 
tvm.testing.requires_aarch64_sve.run_time_check())
             or (
                 dtype == "float16"
                 and target.features.has_fp16_simd
                 and not tvm.testing.requires_arm_fp16.run_time_check()
             )
+            or (target.features.has_sve and not 
tvm.testing.requires_aarch64_sve.run_time_check())
             or (target.features.has_sme and not 
tvm.testing.requires_aarch64_sme.run_time_check())
         )
         if build_only:
             return
 
         func(a, w, b)
-    tol = get_tolerance(dtype, w_np, b_np)
+    tol = get_tolerance(out_dtype, w_np, b_np)
     tvm.testing.assert_allclose(b.numpy(), b_np, rtol=tol["rtol"], 
atol=tol["atol"])
 
 

Reply via email to