This is an automated email from the ASF dual-hosted git repository.

ekalda pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d3011ab609 [SME] Utilize predication in fp32 matmul and conv2d 
schedules (#17054)
d3011ab609 is described below

commit d3011ab609f30ef3363b230bd0f3702ba00aa270
Author: Luke Hutton <[email protected]>
AuthorDate: Fri Jun 14 10:47:00 2024 +0100

    [SME] Utilize predication in fp32 matmul and conv2d schedules (#17054)
    
    Prior to this commit, the matmul and conv2d schedules required padding
    of the inputs to some multiple of vscale and a final "unpadding" stage.
    
    Instead, we can leverage predicated operations to avoid the
    requirement for padding. Both the transpose interleave and outer
    product fp32 intrinsics are updated to use predication. The
    `get_active_lane_mask` intrinsic is utilized to generate a variably
    sized mask of active lanes depending on the global position the tensor
    intrinsic is operating on.
    
    For now this relies on using `offset_of` and `stride` information from
    the tensor we're predicating an access on. Likely we will want to
    build on this in the future with a more intuitive API for determining
    the current tile location.
    
    Support for batched conv2d was removed since this causes numerical
    issues, which are suspected to be due to how the current tile is
    determined (see the paragraph above).
---
 python/tvm/relay/op/strategy/arm_cpu.py            |   7 ++
 python/tvm/tir/tensor_intrin/arm_cpu.py            | 134 +++++++++++++++++----
 python/tvm/topi/arm_cpu/conv2d.py                  |  40 +++---
 python/tvm/topi/arm_cpu/conv2d_gemm.py             |  39 ++++--
 python/tvm/topi/arm_cpu/matmul.py                  |  58 ++++-----
 .../python/codegen/test_target_codegen_aarch64.py  |   4 +
 tests/python/topi/test_topi_conv2d_nhwc.py         |  10 +-
 7 files changed, 197 insertions(+), 95 deletions(-)

diff --git a/python/tvm/relay/op/strategy/arm_cpu.py 
b/python/tvm/relay/op/strategy/arm_cpu.py
index 35fd2b7a78..f4b4708401 100644
--- a/python/tvm/relay/op/strategy/arm_cpu.py
+++ b/python/tvm/relay/op/strategy/arm_cpu.py
@@ -110,6 +110,8 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, 
target):
     """conv2d arm cpu strategy"""
     strategy = _op.OpStrategy()
     data, kernel = inputs
+    data_shape = data.shape
+    kernel_shape = kernel.shape
     dilation_h, dilation_w = attrs.get_int_tuple("dilation")
     stride_h, stride_w = attrs.get_int_tuple("strides")
     padding = attrs.get_int_tuple("padding")
@@ -258,6 +260,11 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, 
target):
                         target.features.has_sme
                         and kernel.dtype == data.dtype
                         and out_type.dtype == "float32"
+                        and data_shape[0] == 1
+                        # The schedule uses tensorization which does not work 
when the
+                        # reduction axis of the gemm has unit iters. See
+                        # https://github.com/apache/tvm/issues/16566
+                        and (data_shape[3] * kernel_shape[0] * 
kernel_shape[1]) > 1
                     ):
                         strategy.add_implementation(
                             
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_hybrid_SME),
diff --git a/python/tvm/tir/tensor_intrin/arm_cpu.py 
b/python/tvm/tir/tensor_intrin/arm_cpu.py
index 3a3430af51..a6f3538846 100644
--- a/python/tvm/tir/tensor_intrin/arm_cpu.py
+++ b/python/tvm/tir/tensor_intrin/arm_cpu.py
@@ -176,7 +176,51 @@ def _create_ptrue_mask(dtype):
     return T.broadcast(T.bool(True), tir.get_vscale_expr(dtype))
 
 
-def get_sme_transpose_interleave_2svlx2svl_fp32_intrin():
+def _create_active_lane_mask(tensor, relative_offsets, vertical_limit):
+    """
+    Get the active lane mask intrinsic call for predicated accesses.
+
+    Parameters
+    ----------
+    tensor : tvm.tir.Buffer
+        The tensor the buffer access will be performed on.
+    relative_offsets : Tuple[PrimExpr, PrimExpr]
+        The vertical and horizontal offsets into the accumulator tile.
+    vertical_limit : PrimExpr
+        An absolute offset specifying the limit at which rows should be stored.
+
+    Returns
+    -------
+    PrimExpr
+        The active lane mask intrinsic.
+    """
+    vertical_offset, horizontal_offset = relative_offsets
+    stride = tensor.strides[0]
+
+    # The base is the offset of the first value we wish to store
+    base = T.int32(tensor.offset_of([vertical_offset, horizontal_offset])[0])
+
+    # The limit is the maximum offset in the current row of 'base' that we 
wish to allow values
+    # to be stored. Calculating this limit is a bit tricky since we can only 
request offsets of
+    # elements in the tensorized tile of the output tensor. One way to 
calculate this is to find
+    # the offset of the first value in the row of the output tensor that 
'base' is in and add
+    # 'stride' to it.
+    limit = (
+        base
+        - T.int32(horizontal_offset)
+        - T.int32((tensor.offset_of([0, 0])[0] % stride))
+        + T.int32(stride)
+    )
+    limit = T.Min(limit, T.Cast("int32", vertical_limit) * stride)
+
+    return T.get_active_lane_mask(
+        "uint1xvscalex4",
+        T.Cast("int32", base),
+        T.Cast("int32", limit),
+    )
+
+
+def get_sme_transpose_interleave_2svlx2svl_fp32_intrin(cols, rows):
     """
     Transpose a matrix of size 2SVL x 2SVL (where 'SVL' is the Scalable Vector 
Length) using
     the Scalable Matrix Extension (SME).
@@ -247,9 +291,6 @@ def get_sme_transpose_interleave_2svlx2svl_fp32_intrin():
                     strides=[T.int32(), 1],
                 )
 
-                # Disable predication
-                ptrue = _create_ptrue_mask("float32")
-
                 with T.block("root"):
                     T.reads(A[0:SVF2, 0:SVF2])
                     T.writes(A_t[0:SVF2, 0:SVF2])
@@ -263,19 +304,22 @@ def get_sme_transpose_interleave_2svlx2svl_fp32_intrin():
 
                             input_ptr = A.access_ptr("r", offset=offset)
                             sub_tile = T.int32(sub_tile_idx)
+                            predicate = _create_active_lane_mask(
+                                A, (row_offset + slice_idx, col_offset), cols
+                            )
                             T.evaluate(
                                 T.call_llvm_intrin(
                                     "void",
                                     "llvm.aarch64.sme.ld1w.horiz",
                                     T.uint32(4),
-                                    ptrue,
+                                    predicate,
                                     input_ptr,
                                     sub_tile,
                                     slice_idx,
                                 )
                             )
 
-                    # Store columns to the ouptut matrix
+                    # Store columns to the output matrix
                     with T.serial(0, SVF) as slice_idx:
                         for sub_tile_idx in range(0, sub_tile_count):
                             col_offset = SVF if sub_tile_idx >= 
(sub_tile_count // 2) else 0
@@ -284,12 +328,15 @@ def get_sme_transpose_interleave_2svlx2svl_fp32_intrin():
 
                             output_ptr = A_t.access_ptr("w", offset=offset)
                             sub_tile = T.int32(sub_tile_idx)
+                            predicate = _create_active_lane_mask(
+                                A_t, (row_offset + slice_idx, col_offset), rows
+                            )
                             T.evaluate(
                                 T.call_llvm_intrin(
                                     "void",
                                     "llvm.aarch64.sme.st1w.vert",
                                     T.uint32(4),
-                                    ptrue,
+                                    predicate,
                                     output_ptr,
                                     sub_tile,
                                     slice_idx,
@@ -445,7 +492,24 @@ def get_sme_transpose_interleave_block2_2svl_fp16_intrin():
     return desc, impl()
 
 
-def get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K, in_dtype):
+def get_transpose_interleave_intrin_name(in_dtype, out_dtype, extent_cols, 
extent_rows):
+    if in_dtype == "float32" and out_dtype == "float32":
+        sme_transpose_interleave_intrin_name = (
+            ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE + 
f"_{extent_cols}_{extent_rows}"
+        )
+        tir.TensorIntrin.register(
+            sme_transpose_interleave_intrin_name,
+            *get_sme_transpose_interleave_2svlx2svl_fp32_intrin(extent_cols, 
extent_rows),
+            override=True,
+        )
+        return sme_transpose_interleave_intrin_name
+    elif in_dtype == "float16" and out_dtype == "float32":
+        return ARM_SME_BLOCK2_2SVLx1SVL_FP16_TRANSPOSE_INTERLEAVE
+    else:
+        raise ValueError("Input/output data type combination not supported.")
+
+
+def get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(M, K, in_dtype):
     """
     Compute a GEMM of size 2SVL x 2SVL (where 'SVL' is the Scalable Vector 
Length using
     outer product operations from the Scalable Matrix Extension (SME).
@@ -579,15 +643,39 @@ def get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K, 
in_dtype):
                         k_row = k * rows_per_iter
                         in_dtype_svf = tir.get_vscale_expr(in_dtype)
 
-                        a_low = T.BufferLoad(A, [k_row, T.Ramp(0, 1, 
in_dtype_svf)])
-                        b_low = T.BufferLoad(B, [k_row, T.Ramp(0, 1, 
in_dtype_svf)])
-
+                        # Ideally we'd rely on predicating the loads and use 
the same predicate
+                        # for the outer product operation. However, support 
for predicated
+                        # buffers is not currently supported by multiple 
lowering passes such as
+                        # "LowerMatchBuffer", therefore the predicate is 
passed directly to the
+                        # outer product operation for now.
                         if in_dtype == "float32":
-                            a_high = T.BufferLoad(A, [k_row, 
T.Ramp(in_dtype_svf, 1, in_dtype_svf)])
-                            b_high = T.BufferLoad(B, [k_row, 
T.Ramp(in_dtype_svf, 1, in_dtype_svf)])
+                            a_low = (
+                                T.BufferLoad(A, [k_row, T.Ramp(0, 1, 
in_dtype_svf)]),
+                                _create_active_lane_mask(A, (k_row, 0), K),
+                            )
+                            b_low = (
+                                T.BufferLoad(B, [k_row, T.Ramp(0, 1, 
in_dtype_svf)]),
+                                _create_active_lane_mask(B, (k_row, 0), K),
+                            )
+                            a_high = (
+                                T.BufferLoad(A, [k_row, T.Ramp(in_dtype_svf, 
1, in_dtype_svf)]),
+                                _create_active_lane_mask(A, (k_row, 
in_dtype_svf), K),
+                            )
+                            b_high = (
+                                T.BufferLoad(B, [k_row, T.Ramp(in_dtype_svf, 
1, in_dtype_svf)]),
+                                _create_active_lane_mask(B, (k_row, 
in_dtype_svf), K),
+                            )
                         else:
-                            a_high = T.BufferLoad(A, [k_row + 1, T.Ramp(0, 1, 
in_dtype_svf)])
-                            b_high = T.BufferLoad(B, [k_row + 1, T.Ramp(0, 1, 
in_dtype_svf)])
+                            a_low = (T.BufferLoad(A, [k_row, T.Ramp(0, 1, 
in_dtype_svf)]), ptrue)
+                            b_low = (T.BufferLoad(B, [k_row, T.Ramp(0, 1, 
in_dtype_svf)]), ptrue)
+                            a_high = (
+                                T.BufferLoad(A, [k_row + 1, T.Ramp(0, 1, 
in_dtype_svf)]),
+                                ptrue,
+                            )
+                            b_high = (
+                                T.BufferLoad(B, [k_row + 1, T.Ramp(0, 1, 
in_dtype_svf)]),
+                                ptrue,
+                            )
 
                         input_combinations = [
                             (a_low, b_low),
@@ -606,10 +694,10 @@ def get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K, 
in_dtype):
                                     fmopa_intrin,
                                     T.uint32(5),
                                     sub_tile,
-                                    ptrue,
-                                    ptrue,
-                                    input_1,
-                                    input_2,
+                                    input_1[1],
+                                    input_2[1],
+                                    input_1[0],
+                                    input_2[0],
                                 )
                             )
 
@@ -626,7 +714,9 @@ def get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K, 
in_dtype):
                                     "void",
                                     "llvm.aarch64.sme.st1w.horiz",
                                     T.uint32(4),
-                                    _create_ptrue_mask("float32"),
+                                    _create_active_lane_mask(
+                                        C, (vert_offset + slice_idx, 
horiz_offset), M
+                                    ),
                                     output_ptr,
                                     T.int32(sub_tile_idx),
                                     T.int32(slice_idx),
@@ -691,10 +781,6 @@ ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA = 
"sme_2svlx2svl_gemm_interleaved_mopa"
 # in versions of LLVM >= 15. Installations with older versions of LLVM will
 # not be able to use them.
 if llvm_version_major() >= 15:
-    TensorIntrin.register(
-        ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE,
-        *get_sme_transpose_interleave_2svlx2svl_fp32_intrin(),
-    )
     TensorIntrin.register(
         ARM_SME_BLOCK2_2SVLx1SVL_FP16_TRANSPOSE_INTERLEAVE,
         *get_sme_transpose_interleave_block2_2svl_fp16_intrin(),
diff --git a/python/tvm/topi/arm_cpu/conv2d.py 
b/python/tvm/topi/arm_cpu/conv2d.py
index a6c951c078..b7327d5b52 100644
--- a/python/tvm/topi/arm_cpu/conv2d.py
+++ b/python/tvm/topi/arm_cpu/conv2d.py
@@ -24,7 +24,6 @@ from tvm import autotvm
 from tvm.script import tir as T
 import tvm.contrib.nnpack
 from tvm.tir.schedule.analysis import has_block
-from tvm.topi.arm_cpu.matmul import _get_transpose_interleave_intrin_name
 
 from ..utils import traverse_inline, get_const_tuple
 from .. import nn
@@ -773,10 +772,7 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
             ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA,
             ARM_SME_INIT,
             get_sme_gemm_interleaved_mopa_2svlx2svl_intrin,
-        )
-
-        transpose_interleave_intrin_name = 
_get_transpose_interleave_intrin_name(
-            in_dtype, out_dtype
+            get_transpose_interleave_intrin_name,
         )
 
         # Interleave the padded im2col matrix utilizing the matrix tile
@@ -787,7 +783,9 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
         ko, ki = sch.split(k, factors=(None, tile_K), disable_predication=True)
         sch.parallel(b)
         sch.reorder(b, ko, mo, ki, mi)
-        sch.tensorize(ki, transpose_interleave_intrin_name)
+        sch.tensorize(
+            ki, get_transpose_interleave_intrin_name(in_dtype, out_dtype, 
M_padded, K_padded)
+        )
 
         # Interleave the padded weights matrix utilizing the matrix tile
         if in_dtype == "float16":
@@ -797,7 +795,9 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
             ko, ki = sch.split(k, factors=(None, tile_K), 
disable_predication=True)
             no, ni = sch.split(n, factors=(None, tile_N), 
disable_predication=True)
             sch.reorder(ko, no, ki, ni)
-            sch.tensorize(ki, transpose_interleave_intrin_name)
+            sch.tensorize(
+                ki, get_transpose_interleave_intrin_name(in_dtype, out_dtype, 
M_padded, K_padded)
+            )
 
         # Split and reorder the loops of the GeMM for tensorization
         b, m, n, k = sch.get_loops(gemm_block)
@@ -816,11 +816,11 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: 
tvm.tir.Schedule):
 
         # Tensorize the GeMM update
         sme_gemm_interleaved_intrin_name = (
-            ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{K_padded}_{in_dtype}"
+            ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + 
f"_{M_padded}_{K_padded}_{in_dtype}"
         )
         tvm.tir.TensorIntrin.register(
             sme_gemm_interleaved_intrin_name,
-            *get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K_padded, 
in_dtype),
+            *get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(M_padded, 
K_padded, in_dtype),
             override=True,
         )
         sch.tensorize(mi, sme_gemm_interleaved_intrin_name)
@@ -922,16 +922,18 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: 
tvm.tir.Schedule):
         reshape_block = func_blocks["T_reshape"]
         A_pad_block = func_blocks["A_padded_K"] if func_blocks["A_padded_K"] 
else None
         A_pad_block = func_blocks["A_padded_M"] if func_blocks["A_padded_M"] 
else A_pad_block
-        if use_sme:
-            sch.compute_inline(reshape_block)
-        elif A_pad_block:
-            sch.compute_inline(reshape_block)
-            b, m, k = sch.get_loops(A_pad_block)
-            _, k_inner = sch.split(k, [None, tile_N])
-            sch.vectorize(k_inner)
-            sch.compute_at(A_pad_block, mi)
-        else:
-            sch.compute_at(reshape_block, mi)
+        use_explicit_predication = use_sme and in_dtype == "float32"
+        if not use_explicit_predication:
+            if use_sme:
+                sch.compute_inline(reshape_block)
+            elif A_pad_block:
+                sch.compute_inline(reshape_block)
+                b, m, k = sch.get_loops(A_pad_block)
+                _, k_inner = sch.split(k, [None, tile_N])
+                sch.vectorize(k_inner)
+                sch.compute_at(A_pad_block, mi)
+            else:
+                sch.compute_at(reshape_block, mi)
 
     # Weight flattening
     if func_blocks["weight_flatten"]:
diff --git a/python/tvm/topi/arm_cpu/conv2d_gemm.py 
b/python/tvm/topi/arm_cpu/conv2d_gemm.py
index e637aa91e5..bf6a9c7551 100644
--- a/python/tvm/topi/arm_cpu/conv2d_gemm.py
+++ b/python/tvm/topi/arm_cpu/conv2d_gemm.py
@@ -133,23 +133,25 @@ def compute_conv2d_gemm_without_weight_transform(
     )
 
     # Pad to tiles (if necessary)
-    pad_M, pad_K = arm_utils.get_conv2d_im2col_padding(M, K, tile_M, tile_K_A)
-    pad_N, _ = arm_utils.get_conv2d_weights_padding(N, K, tile_N, tile_K_B)
+    use_explicit_predication = use_sme and in_dtype == "float32"
+    if not use_explicit_predication:
+        pad_M, pad_K = arm_utils.get_conv2d_im2col_padding(M, K, tile_M, 
tile_K_A)
+        pad_N, _ = arm_utils.get_conv2d_weights_padding(N, K, tile_N, tile_K_B)
 
-    M_padded = M + pad_M
-    K_padded = K + pad_K
-    N_padded = N + pad_N
+        M_padded = M + pad_M
+        K_padded = K + pad_K
+        N_padded = N + pad_N
 
-    pad_before = (0, 0, 0)
-    pad_after = (0, pad_M, pad_K)
+        pad_before = (0, 0, 0)
+        pad_after = (0, pad_M, pad_K)
 
-    if pad_K != 0:
-        A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, 
name="A_padded_K")
-    elif pad_M != 0:
-        A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, 
name="A_padded_M")
+        if pad_K != 0:
+            A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, 
name="A_padded_K")
+        elif pad_M != 0:
+            A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, 
name="A_padded_M")
 
     idxm = tvm.tir.indexmod
-    k = te.reduce_axis((0, K_padded), "k")
+    k = te.reduce_axis((0, K if use_explicit_predication else K_padded), "k")
 
     # Determine matrix multiplication compute definition
     target = Target.current(allow_none=False)
@@ -300,7 +302,18 @@ def compute_conv2d_gemm_without_weight_transform(
             name="C",
         )
         zero = tvm.tir.const(0)
-    elif use_scalable_vectors or use_sme:
+    elif use_explicit_predication:
+        assert len(B_interleaved_t.shape) == 2
+        C = te.compute(
+            (batches, M, N),
+            lambda b, x, y: te.sum(
+                A[b, x, k].astype(in_dtype) * B_interleaved_t[k, 
y].astype(in_dtype),
+                axis=k,
+            ),
+            name="C",
+        )
+        zero = tvm.tir.const(0)
+    elif use_scalable_vectors:
         assert len(B_interleaved_t.shape) == 2
         C = te.compute(
             (batches, M_padded, N_padded),
diff --git a/python/tvm/topi/arm_cpu/matmul.py 
b/python/tvm/topi/arm_cpu/matmul.py
index 23b8734a0b..63f6289f0e 100644
--- a/python/tvm/topi/arm_cpu/matmul.py
+++ b/python/tvm/topi/arm_cpu/matmul.py
@@ -53,19 +53,16 @@ def compute_matmul_sme(cfg, data_a, data_b, _, out_dtype, 
transpose_a=False, tra
         tile_k *= 2
     tile_n = 2 * tvm.tir.get_vscale_expr(data_a.dtype)
 
-    M_padded, pad_M = pad_dim_to_multiple(M, tile_m)
-    _, pad_K = pad_dim_to_multiple(K, tile_k)
-    N_padded, pad_N = pad_dim_to_multiple(N, tile_n)
-
-    m_pad_after = (pad_M, pad_K)
-    n_pad_after = (pad_K, pad_N)
-    if transpose_b:
-        n_pad_after = (pad_N, pad_K)
-
-    if pad_M != 0:
-        data_a = nn.pad(data_a, pad_before=(0, 0), pad_after=m_pad_after)
-    if pad_N != 0:
-        data_b = nn.pad(data_b, pad_before=(0, 0), pad_after=n_pad_after)
+    if data_a.dtype == "float16":
+        _, pad_M = pad_dim_to_multiple(M, tile_m)
+        _, pad_K = pad_dim_to_multiple(K, tile_k)
+        _, pad_N = pad_dim_to_multiple(N, tile_n)
+        m_pad_after = (pad_M, pad_K)
+        n_pad_after = (pad_N, pad_K) if transpose_b else (pad_K, pad_N)
+        if pad_M != 0:
+            data_a = nn.pad(data_a, pad_before=(0, 0), pad_after=m_pad_after)
+        if pad_N != 0:
+            data_b = nn.pad(data_b, pad_before=(0, 0), pad_after=n_pad_after)
 
     if out_dtype is None:
         out_dtype = data_a.dtype
@@ -87,28 +84,12 @@ def compute_matmul_sme(cfg, data_a, data_b, _, out_dtype, 
transpose_a=False, tra
         (False, False): "T_matmul_NN",
     }[(transpose_a, transpose_b)]
 
-    C = te.compute(
-        (M_padded, N_padded),
+    return te.compute(
+        (M, N),
         compute,
         name=compute_name,
         attrs={"schedule_type": "sme"},
     )
-    return te.compute((M, N), lambda m, n: C[m, n])
-
-
-def _get_transpose_interleave_intrin_name(in_dtype, out_dtype):
-    # pylint: disable=import-outside-toplevel
-    from tvm.tir.tensor_intrin.arm_cpu import (
-        ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE,
-        ARM_SME_BLOCK2_2SVLx1SVL_FP16_TRANSPOSE_INTERLEAVE,
-    )
-
-    if in_dtype == "float32" and out_dtype == "float32":
-        return ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE
-    elif in_dtype == "float16" and out_dtype == "float32":
-        return ARM_SME_BLOCK2_2SVLx1SVL_FP16_TRANSPOSE_INTERLEAVE
-    else:
-        raise ValueError("Input/output data type combination not supported.")
 
 
 def tir_schedule_matmul_sme(sch):
@@ -120,6 +101,7 @@ def tir_schedule_matmul_sme(sch):
         ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA,
         ARM_SME_INIT,
         get_sme_gemm_interleaved_mopa_2svlx2svl_intrin,
+        get_transpose_interleave_intrin_name,
     )
 
     main_func = sch.mod["main"]
@@ -157,9 +139,9 @@ def tir_schedule_matmul_sme(sch):
     outer_m, inner_m = sch.split(m, factors=(None, tile_m), 
disable_predication=True)
     outer_k, inner_k = sch.split(k, factors=(None, tile_k), 
disable_predication=True)
     sch.reorder(outer_k, outer_m, inner_k, inner_m)
-
-    transpose_interleave_intrin_name = 
_get_transpose_interleave_intrin_name(in_dtype, out_dtype)
-    sch.tensorize(inner_k, transpose_interleave_intrin_name)
+    sch.tensorize(
+        inner_k, get_transpose_interleave_intrin_name(in_dtype, out_dtype, 
extent_m, extent_k)
+    )
 
     # Interleave the weights utilizing the matrix tile
     if transpose_b:
@@ -169,7 +151,9 @@ def tir_schedule_matmul_sme(sch):
         outer_k, inner_k = sch.split(k, factors=(None, tile_k), 
disable_predication=True)
         outer_n, inner_n = sch.split(n, factors=(None, tile_n), 
disable_predication=True)
         sch.reorder(outer_k, outer_n, inner_k, inner_n)
-        sch.tensorize(inner_k, transpose_interleave_intrin_name)
+        sch.tensorize(
+            inner_k, get_transpose_interleave_intrin_name(in_dtype, out_dtype, 
extent_k, extent_n)
+        )
 
     # Split and reorder the loops of the GeMM for tensorization
     tile_m = T.cast(2 * tvm.tir.get_vscale_expr(out_dtype), extent_m.dtype)
@@ -185,11 +169,11 @@ def tir_schedule_matmul_sme(sch):
 
     # Tensorize the GeMM update
     sme_gemm_interleaved_intrin_name = (
-        ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{extent_k}_{in_dtype}"
+        ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + 
f"_{extent_m}_{extent_k}_{in_dtype}"
     )
     tvm.tir.TensorIntrin.register(
         sme_gemm_interleaved_intrin_name,
-        *get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(extent_k, in_dtype),
+        *get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(extent_m, extent_k, 
in_dtype),
         override=True,
     )
     sch.tensorize(inner_m, sme_gemm_interleaved_intrin_name)
diff --git a/tests/python/codegen/test_target_codegen_aarch64.py 
b/tests/python/codegen/test_target_codegen_aarch64.py
index 9b0408b949..f596549a10 100644
--- a/tests/python/codegen/test_target_codegen_aarch64.py
+++ b/tests/python/codegen/test_target_codegen_aarch64.py
@@ -530,12 +530,14 @@ def test_matmul_sme(dtype):
         )
         stores = re.findall(r"st1[whdb]\t{\s?za", assembly)
         smstop = re.findall(r"smstop\t(sm|za)", assembly)
+        whilelo = re.findall(r"whilelo\tp[0-9].[shdb]", assembly)
 
         assert len(smstart) > 0
         assert len(loads) > 0
         assert len(mopa) > 0
         assert len(stores) > 0
         assert len(smstop) > 0
+        assert len(whilelo) > 0
 
     check_correct_assembly(dtype=dtype)
 
@@ -819,12 +821,14 @@ def test_conv2d_sme(dtype):
         )
         stores = re.findall(r"st1[whdb]\t{\s?za", assembly)
         smstop = re.findall(r"smstop\t(sm|za)", assembly)
+        whilelo = re.findall(r"whilelo\tp[0-9].[shdb]", assembly)
 
         assert len(smstart) > 0
         assert len(loads) > 0
         assert len(mopa) > 0
         assert len(stores) > 0
         assert len(smstop) > 0
+        assert len(whilelo) > 0
 
     with tvm.target.Target(target):
         check_correct_assembly(dtype=dtype)
diff --git a/tests/python/topi/test_topi_conv2d_nhwc.py 
b/tests/python/topi/test_topi_conv2d_nhwc.py
index d46db1b28b..e7009ed179 100644
--- a/tests/python/topi/test_topi_conv2d_nhwc.py
+++ b/tests/python/topi/test_topi_conv2d_nhwc.py
@@ -168,10 +168,16 @@ def test_conv2d_nhwc_gemm(device, ref_data, dtype, 
stride, padding, dilation):
     target = tvm.target.Target(target_string)
 
     if target.features.has_sve and llvm_version_major() < 15:
-        pytest.skip(f"LLVM {llvm_version_major()} does not support targetting 
SVE.")
+        pytest.skip(f"LLVM {llvm_version_major()} does not support targeting 
SVE.")
 
     if target.features.has_sme and llvm_version_major() < 16:
-        pytest.skip(f"LLVM {llvm_version_major()} does not support targetting 
SME.")
+        pytest.skip(f"LLVM {llvm_version_major()} does not support targeting 
SME.")
+
+    if target.features.has_sme and a_np.shape[0] > 1:
+        pytest.skip(f"Conv2d with batches > 1 targeting SME not implemented.")
+
+    if target.features.has_sme and (a_np.shape[3] * w_np.shape[0] * 
w_np.shape[1]) <= 1:
+        pytest.skip(f"Conv2d with unit reduction dimension targeting SME not 
supported.")
 
     # SME schedule always outputs float32 results, regardless of input dtype.
     # Otherwise, output dtype is the same as input dtype.

Reply via email to