This is an automated email from the ASF dual-hosted git repository.
ekalda pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4b8297480d [SME][TOPI] Add conv2d NHWC SME fp16->fp32 schedule (#17048)
4b8297480d is described below
commit 4b8297480d52d637d762544a8d68b01b7d01ff7a
Author: Andrei Hutu <[email protected]>
AuthorDate: Wed Jun 5 17:02:59 2024 +0100
[SME][TOPI] Add conv2d NHWC SME fp16->fp32 schedule (#17048)
This commit extends the SME conv2d NHWC schedule to support convolutions
with float16 inputs (data and kernel) and a float32 output using the tensor
intrinsics added in #16981.
---
python/tvm/relay/op/strategy/arm_cpu.py | 39 ++++++++---
python/tvm/topi/arm_cpu/arm_utils.py | 22 +++---
python/tvm/topi/arm_cpu/conv2d.py | 81 +++++++++++++++++++---
python/tvm/topi/arm_cpu/conv2d_alter_op.py | 28 ++++++++
python/tvm/topi/arm_cpu/conv2d_gemm.py | 11 +++
python/tvm/topi/nn/conv2d.py | 7 +-
tests/python/relay/strategy/arm_cpu/test_conv2d.py | 39 +++++++----
.../relay/strategy/test_select_implementation.py | 60 +++++++++++++---
tests/python/topi/test_topi_conv2d_nhwc.py | 15 ++--
9 files changed, 244 insertions(+), 58 deletions(-)
diff --git a/python/tvm/relay/op/strategy/arm_cpu.py
b/python/tvm/relay/op/strategy/arm_cpu.py
index 12f19462f7..35fd2b7a78 100644
--- a/python/tvm/relay/op/strategy/arm_cpu.py
+++ b/python/tvm/relay/op/strategy/arm_cpu.py
@@ -23,6 +23,7 @@ import re
from tvm import relay, topi, tir
from tvm.tir.schedule.analysis import has_block
+from tvm.dlight.gpu.matmul import auto_inline_consumers
from ....auto_scheduler import is_auto_scheduler_enabled
from ....meta_schedule import is_meta_schedule_enabled
@@ -255,9 +256,8 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type,
target):
if is_aarch64 and data.dtype in ["float32", "float16"]:
if (
target.features.has_sme
- and data.dtype in ["float32"]
- and kernel.dtype in ["float32"]
- and out_type.dtype in ["float32"]
+ and kernel.dtype == data.dtype
+ and out_type.dtype == "float32"
):
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_hybrid_SME),
@@ -536,6 +536,7 @@ def
conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_typ
"""conv2d_winograd_without_weight_transform arm cpu strategy"""
layout = attrs.data_layout
data = inputs[0]
+ kernel = inputs[1]
strategy = _op.OpStrategy()
is_aarch64 = target.features.is_aarch64
has_dot_prod = target.features.has_dotprod
@@ -581,13 +582,31 @@ def
conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_typ
wrap_topi_schedule(interleaved_schedule),
name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
)
+ # Non-quantized cases
elif data.dtype in ["float32", "float16"]:
- # Non-quantized cases
- strategy.add_implementation(
-
wrap_compute_conv2d_gemm(topi.arm_cpu.compute_conv2d_NHWC_hybrid_without_transform),
-
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_hybrid_without_transform),
- name="conv2d_NHWC_hybrid_without_transform.arm_cpu",
- )
+ # The SME schedule for float16->float32 prearranges the two
matrices to be multiplied
+ # using the ARM_SME_BLOCK2_2SVLx1SVL_FP16_TRANSPOSE_INTERLEAVE
intrinsic which expects
+ # the reduction axis K as the second dimension of the matrix (i.e.
shape = (_, K)).
+ # This means that the flattened weights matrix B needs to be
transposed to (N, K).
+ if (
+ target.features.has_sme
+ and kernel.dtype == "float16"
+ and data.dtype == "float16"
+ and out_type.dtype == "float32"
+ ):
+ strategy.add_implementation(
+
wrap_compute_conv2d_gemm(topi.arm_cpu.compute_conv2d_NHWC_SME_transposed_B),
+ lambda: None,
+ name="conv2d_NHWC_hybrid_SME_transposed_B.arm_cpu",
+ )
+ else:
+ strategy.add_implementation(
+ wrap_compute_conv2d_gemm(
+
topi.arm_cpu.compute_conv2d_NHWC_hybrid_without_transform
+ ),
+
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_hybrid_without_transform),
+ name="conv2d_NHWC_hybrid_without_transform.arm_cpu",
+ )
else:
raise RuntimeError(
f"Unsupported conv2d_NHWC_without_transform layout {layout}"
@@ -819,6 +838,8 @@ def arm_cpu_tir_strategy(sch: tir.Schedule) -> bool:
topi.arm_cpu.matmul.tir_schedule_matmul_sme(sch)
return True
elif has_block(sch, "conv2d_gemm_output"):
+ conv2d_block = sch.get_block("conv2d_gemm_output")
+ auto_inline_consumers(sch, conv2d_block)
topi.arm_cpu.schedule_conv2d_NHWC_hybrid_TIR(sch)
return True
diff --git a/python/tvm/topi/arm_cpu/arm_utils.py
b/python/tvm/topi/arm_cpu/arm_utils.py
index 5c4b3c0456..f690b22731 100644
--- a/python/tvm/topi/arm_cpu/arm_utils.py
+++ b/python/tvm/topi/arm_cpu/arm_utils.py
@@ -68,8 +68,11 @@ def get_tiling_A(interleave_A, in_dtype, use_sme=False):
tile_M = 4
tile_K = 16
elif use_sme:
- tile_M = 2 * 4 * tvm.tir.vscale()
- tile_K = 2 * 4 * tvm.tir.vscale()
+ tile_M = 2 * tvm.tir.get_vscale_expr(in_dtype)
+ if in_dtype == "float16":
+ tile_K = tvm.tir.get_vscale_expr(in_dtype)
+ else:
+ tile_K = 2 * tvm.tir.get_vscale_expr(in_dtype)
else:
# In non-SME, non-quantized cases, A is not interleaved.
# We are loading 4 rows from A.
@@ -139,17 +142,16 @@ def get_tiling_B_transformed(interleave_A, in_dtype,
use_scalable_vectors=False,
tile_N = 4
tile_K = 16
elif use_sme:
- tile_N = 2 * 4 * tvm.tir.vscale()
- tile_K = 2 * 4 * tvm.tir.vscale()
- # In non-SME, non-quantized cases, A is not interleaved.
- elif use_scalable_vectors:
+ tile_N = 2 * tvm.tir.get_vscale_expr(in_dtype)
if in_dtype == "float16":
- # Each load from B' contains 32 * vscale elements (i.e. 32 *
vscale columns from B)
- tile_N = 32 * tvm.tir.vscale()
+ tile_K = tvm.tir.get_vscale_expr(in_dtype)
else:
- # Each load from B' contains 16 * vscale elements (i.e. 16 *
vscale columns from B)
- tile_N = 16 * tvm.tir.vscale()
+ tile_K = 2 * tvm.tir.get_vscale_expr(in_dtype)
+ # In non-SME, non-quantized cases, A is not interleaved.
+ elif use_scalable_vectors:
+ # Each load from B' contains 4 * scalable vectors (i.e. 4 * SVL
columns from B)
# We are loading 4 rows from B', in the dimension of reduction (i.e. 4
rows from B)
+ tile_N = 4 * tvm.tir.get_vscale_expr(in_dtype)
tile_K = 4
elif in_dtype == "float16" and target.features.has_fp16_simd:
# Each load from B' contains 32 elements (i.e. 32 columns from B)
diff --git a/python/tvm/topi/arm_cpu/conv2d.py
b/python/tvm/topi/arm_cpu/conv2d.py
index d0fe251e7e..a6c951c078 100644
--- a/python/tvm/topi/arm_cpu/conv2d.py
+++ b/python/tvm/topi/arm_cpu/conv2d.py
@@ -24,6 +24,7 @@ from tvm import autotvm
from tvm.script import tir as T
import tvm.contrib.nnpack
from tvm.tir.schedule.analysis import has_block
+from tvm.topi.arm_cpu.matmul import _get_transpose_interleave_intrin_name
from ..utils import traverse_inline, get_const_tuple
from .. import nn
@@ -680,6 +681,43 @@ def compute_conv2d_NHWC_hybrid_SME(cfg, data, kernel,
strides, padding, dilation
)
+@autotvm.register_topi_compute("conv2d_NHWC_hybrid_SME_transposed_B.arm_cpu")
+def compute_conv2d_NHWC_SME_transposed_B(
+ cfg,
+ data,
+ kernel,
+ strides,
+ padding,
+ dilation,
+ out_dtype,
+ kernel_size,
+ output_channels,
+):
+ """Compute conv2d NHWC hybrid SME transposed B"""
+ N, K = get_const_tuple(kernel.shape)
+ tile_N, tile_K = get_tiling_B_transformed(False, data.dtype, True, True)
+ pad_N, pad_K = tvm.topi.arm_cpu.arm_utils.get_conv2d_weights_padding(N, K,
tile_N, tile_K)
+
+ kernel = tvm.topi.nn.pad(
+ kernel, pad_before=(0, 0), pad_after=(pad_N, pad_K),
name="weight_padding"
+ )
+
+ return compute_conv2d_gemm_without_weight_transform(
+ cfg,
+ data,
+ kernel,
+ strides,
+ padding,
+ dilation,
+ out_dtype,
+ kernel_size,
+ output_channels,
+ interleave_A=False,
+ use_scalable_vectors=True,
+ use_sme=True,
+ )
+
+
def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
"""
Perform TIR scheduling for conv2d NHWC.
@@ -688,7 +726,8 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
primfunc = sch.mod["main"]
buffer_names = primfunc.params
buffer_list = [primfunc.buffer_map[buf] for buf in buffer_names]
- dtype = buffer_list[0].dtype
+ in_dtype = buffer_list[0].dtype
+ out_dtype = "float32"
# Determine PrimFunc blocks
block_list = [
@@ -698,6 +737,8 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
"A_padded_K",
"A_padded_M",
"weight_flatten",
+ "weight_padding",
+ "weight_transpose",
"C",
"conv2d_gemm_output",
]
@@ -716,8 +757,8 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
M_padded = sch.get(m).extent
N_padded = sch.get(n).extent
K_padded = sch.get(k).extent
- tile_M, tile_K = get_tiling_A(False, dtype, use_sme)
- tile_N, _ = get_tiling_B_transformed(False, dtype, use_scalable_vectors,
use_sme)
+ tile_M, tile_K = get_tiling_A(False, in_dtype, use_sme)
+ tile_N, _ = get_tiling_B_transformed(False, in_dtype,
use_scalable_vectors, use_sme)
tile_M = T.cast(tile_M, M_padded.dtype)
tile_N = T.cast(tile_N, N_padded.dtype)
tile_K = T.cast(tile_K, K_padded.dtype)
@@ -729,12 +770,15 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch:
tvm.tir.Schedule):
# pylint: disable=import-outside-toplevel
from tvm.topi.arm_cpu.pstate_attributes import SMEAttributes
from tvm.tir.tensor_intrin.arm_cpu import (
- ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE,
ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA,
ARM_SME_INIT,
get_sme_gemm_interleaved_mopa_2svlx2svl_intrin,
)
+ transpose_interleave_intrin_name =
_get_transpose_interleave_intrin_name(
+ in_dtype, out_dtype
+ )
+
# Interleave the padded im2col matrix utilizing the matrix tile
interleave_t_A_block = sch.cache_read(gemm_block, 0, "global")
sch.transform_layout(interleave_t_A_block, ("write", 0), lambda b, m,
k: (b, k, m))
@@ -743,24 +787,40 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch:
tvm.tir.Schedule):
ko, ki = sch.split(k, factors=(None, tile_K), disable_predication=True)
sch.parallel(b)
sch.reorder(b, ko, mo, ki, mi)
- sch.tensorize(ki, ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE)
+ sch.tensorize(ki, transpose_interleave_intrin_name)
+
+ # Interleave the padded weights matrix utilizing the matrix tile
+ if in_dtype == "float16":
+ interleave_b_block = sch.cache_read(gemm_block, 1, "global")
+ sch.transform_layout(interleave_b_block, ("write", 0), lambda n,
k: (k, n))
+ n, k = sch.get_loops(interleave_b_block)
+ ko, ki = sch.split(k, factors=(None, tile_K),
disable_predication=True)
+ no, ni = sch.split(n, factors=(None, tile_N),
disable_predication=True)
+ sch.reorder(ko, no, ki, ni)
+ sch.tensorize(ki, transpose_interleave_intrin_name)
# Split and reorder the loops of the GeMM for tensorization
b, m, n, k = sch.get_loops(gemm_block)
+ tile_M, _ = get_tiling_A(False, out_dtype, True)
+ tile_N, _ = get_tiling_B_transformed(False, out_dtype, True, True)
+ tile_M = T.cast(tile_M, M_padded.dtype)
+ tile_N = T.cast(tile_N, N_padded.dtype)
mo, mi = sch.split(m, factors=(None, tile_M), disable_predication=True)
no, ni = sch.split(n, factors=(None, tile_N), disable_predication=True)
sch.parallel(b)
sch.reorder(b, mo, no, mi, ni, k)
- # Tensorize the GeMM output matrix initialization to zero
+ # Tensorize the GeMM initialization
init_block = sch.decompose_reduction(gemm_block, mi)
sch.tensorize(sch.get_loops(init_block)[-2], ARM_SME_INIT)
# Tensorize the GeMM update
- sme_gemm_interleaved_intrin_name =
ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{K_padded}"
+ sme_gemm_interleaved_intrin_name = (
+ ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{K_padded}_{in_dtype}"
+ )
tvm.tir.TensorIntrin.register(
sme_gemm_interleaved_intrin_name,
- *get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K_padded, dtype),
+ *get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K_padded,
in_dtype),
override=True,
)
sch.tensorize(mi, sme_gemm_interleaved_intrin_name)
@@ -878,6 +938,11 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
weight_flatten_block = func_blocks["weight_flatten"]
sch.compute_inline(weight_flatten_block)
+ # Weight transpose
+ if func_blocks["weight_transpose"] and func_blocks["weight_padding"]:
+ weight_padding_block = func_blocks["weight_padding"]
+ sch.compute_inline(weight_padding_block)
+
# Conv2d output block
output_block = func_blocks["conv2d_gemm_output"]
n, h, w, c = sch.get_loops(output_block)
diff --git a/python/tvm/topi/arm_cpu/conv2d_alter_op.py
b/python/tvm/topi/arm_cpu/conv2d_alter_op.py
index fe4569ceb1..2476cb92b9 100644
--- a/python/tvm/topi/arm_cpu/conv2d_alter_op.py
+++ b/python/tvm/topi/arm_cpu/conv2d_alter_op.py
@@ -162,6 +162,34 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
inputs[0], new_kernel_expr, **new_attrs
)
+ if (
+ topi_tmpl == "conv2d_NHWC_hybrid_SME.arm_cpu"
+ and data_dtype == "float16"
+ and kernel_dtype == "float16"
+ and out_dtype == "float32"
+ ):
+ assert data_layout == "NHWC" and kernel_layout == "HWIO"
+ KH, KW, IC, OC = get_const_tuple(kernel.shape)
+ K = KH * KW * IC
+ N = OC
+ # The SME schedule for float16->float32 prearranges the two matrices
to be multiplied
+ # using the ARM_SME_BLOCK2_2SVLx1SVL_FP16_TRANSPOSE_INTERLEAVE
intrinsic which expects
+ # the reduction axis K as the second dimension of the matrix (i.e.
shape = (_, K)).
+ # This means that the flattened weights matrix B needs to be
transposed to (N, K).
+ transposed_kernel_expr = relay.transpose(inputs[1], axes=[3, 0, 1, 2])
+ transposed_flattened_kernel_expr =
relay.reshape(transposed_kernel_expr, newshape=(N, K))
+ new_kernel_expr = transposed_flattened_kernel_expr
+ new_kernel = te.placeholder((N, K), kernel.dtype)
+ new_workload_name = "conv2d_NHWC_hybrid_SME_transposed_B.arm_cpu"
+ new_workload = autotvm.task.args_to_workload(
+ [data, new_kernel, strides, padding, dilation, out_dtype, (KH,
KW), OC],
+ new_workload_name,
+ )
+ dispatch_ctx.update(target, new_workload, cfg)
+ return relay.nn.contrib_conv2d_gemm_without_weight_transform(
+ inputs[0], new_kernel_expr, **new_attrs
+ )
+
# Only microTVM does layout alteration for NHWC layout with real data types
if data_layout == "NHWC" and data_dtype not in ["uint8", "int8"]:
return None
diff --git a/python/tvm/topi/arm_cpu/conv2d_gemm.py
b/python/tvm/topi/arm_cpu/conv2d_gemm.py
index 0c3908bb70..e637aa91e5 100644
--- a/python/tvm/topi/arm_cpu/conv2d_gemm.py
+++ b/python/tvm/topi/arm_cpu/conv2d_gemm.py
@@ -289,6 +289,17 @@ def compute_conv2d_gemm_without_weight_transform(
tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1]
- tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1]
)
+ elif use_sme and in_dtype == "float16" and out_dtype == "float32":
+ assert len(B_interleaved_t.shape) == 2
+ C = te.compute(
+ (batches, M_padded, N_padded),
+ lambda b, x, y: te.sum(
+ A[b, x, k].astype(out_dtype) * B_interleaved_t[y,
k].astype(out_dtype),
+ axis=k,
+ ),
+ name="C",
+ )
+ zero = tvm.tir.const(0)
elif use_scalable_vectors or use_sme:
assert len(B_interleaved_t.shape) == 2
C = te.compute(
diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py
index 8d61c62250..205730ff22 100644
--- a/python/tvm/topi/nn/conv2d.py
+++ b/python/tvm/topi/nn/conv2d.py
@@ -654,7 +654,12 @@ def conv2d_gemm_weight_transform(kernel, tile_N, tile_K,
use_scalable_vectors=Fa
kernel_flat, pad_before=(0, 0), pad_after=(pad_K, pad_N),
name="weight_padding"
)
- if use_sme or use_scalable_vectors:
+ if use_sme and kernel.dtype == "float16":
+ return te.compute(
+ (N_padded, K_padded), lambda x, y: kernel_flat[y, x],
name="weight_transpose"
+ )
+
+ if use_scalable_vectors or use_sme:
return kernel_flat
if kernel.dtype in ["int8", "uint8"]:
diff --git a/tests/python/relay/strategy/arm_cpu/test_conv2d.py
b/tests/python/relay/strategy/arm_cpu/test_conv2d.py
index 2708094afb..f4fa250ecf 100644
--- a/tests/python/relay/strategy/arm_cpu/test_conv2d.py
+++ b/tests/python/relay/strategy/arm_cpu/test_conv2d.py
@@ -120,7 +120,8 @@ class TestConv2d_NCHW_Spatial_Pack(Conv2dTests):
schedule_name = parameter("conv2d_nchw_spatial_pack.arm_cpu")
-dtype = tvm.testing.parameter("float32")
+in_dtype = tvm.testing.parameter("float16", "float32")
+out_dtype = tvm.testing.parameter("float32")
batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation =
tvm.testing.parameters(
# Pad M, N, K
@@ -154,30 +155,35 @@ batch, in_channel, in_size, num_filter, kernel, stride,
padding, dilation = tvm.
@tvm.testing.fixture()
-def ref_data(dtype, batch, in_channel, in_size, num_filter, kernel, stride,
padding, dilation):
+def ref_data(
+ in_dtype, out_dtype, batch, in_channel, in_size, num_filter, kernel,
stride, padding, dilation
+):
np.random.seed(0)
in_height = in_width = in_size
a_shape = (batch, in_height, in_width, in_channel)
w_shape = (kernel, kernel, in_channel, num_filter)
- a_np = np.random.uniform(size=a_shape).astype(dtype)
- w_np = np.random.uniform(size=w_shape).astype(dtype)
- return a_np, w_np
+ a_np = np.random.uniform(size=a_shape).astype(in_dtype)
+ w_np = np.random.uniform(size=w_shape).astype(in_dtype)
+ dw_np = tvm.topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1))
+ b_np = tvm.topi.testing.conv2d_nhwc_python(
+ a_np.astype(out_dtype), dw_np.astype(out_dtype), stride, padding
+ ).astype(out_dtype)
+ return a_np, w_np, dw_np, b_np
@pytest.mark.skipif(
llvm_version_major() < 16, reason="SME is not supported in earlier
versions of LLVM"
)
@tvm.testing.requires_aprofile_aem_fvp
-def test_conv2d_fp32(target, ref_data, dtype, stride, padding, dilation):
- a_np, w_np = ref_data
- dw_np = tvm.topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1))
+def test_conv2d_sme(target, ref_data, in_dtype, out_dtype, stride, padding,
dilation):
+ a_np, w_np, dw_np, b_np = ref_data
kernel_size = get_const_tuple(w_np.shape[:2])
out_channels = w_np.shape[3]
- x = relay.var("data", shape=a_np.shape, dtype=dtype)
- weight = relay.const(w_np, dtype=dtype)
+ x = relay.var("data", shape=a_np.shape, dtype=in_dtype)
+ weight = relay.const(w_np, dtype=in_dtype)
conv2d = relay.nn.conv2d(
x,
weight,
@@ -188,7 +194,7 @@ def test_conv2d_fp32(target, ref_data, dtype, stride,
padding, dilation):
padding=get_pad_tuple(padding, dw_np.shape[:2]),
data_layout="NHWC",
kernel_layout="HWIO",
- out_dtype=dtype,
+ out_dtype=out_dtype,
)
func = relay.Function(relay.analysis.free_vars(conv2d), conv2d)
@@ -198,7 +204,7 @@ def test_conv2d_fp32(target, ref_data, dtype, stride,
padding, dilation):
inputs = {"data": a_np}
params = {}
- ref_outputs = generate_ref_data(ir_mod, inputs, params)
+ ref_outputs = {"output": b_np}
target = tvm.target.Target("llvm -mtriple=aarch64-none-elf
-mattr=+v9.2a,+sme")
runtime = tvm.relay.backend.Runtime("crt", {"system-lib": True})
@@ -220,9 +226,12 @@ def test_conv2d_fp32(target, ref_data, dtype, stride,
padding, dilation):
runtime=runtime,
params=params,
)
- generated_func = executor_factory.lowered_ir_mods.items()[0][1][
- "tvmgen_default_fused_nn_conv2d"
- ]
+
+ if in_dtype == "float16":
+ func_name =
"tvmgen_default_fused_nn_contrib_conv2d_gemm_without_weight_transform"
+ else:
+ func_name = "tvmgen_default_fused_nn_conv2d"
+ generated_func = executor_factory.lowered_ir_mods.items()[0][1][func_name]
extra_memory_in_bytes =
calculate_extra_workspace_size_from_scalable_extents(generated_func, 4)
test_model = AOTTestModel(
diff --git a/tests/python/relay/strategy/test_select_implementation.py
b/tests/python/relay/strategy/test_select_implementation.py
index 01a914e793..b95bd4072a 100644
--- a/tests/python/relay/strategy/test_select_implementation.py
+++ b/tests/python/relay/strategy/test_select_implementation.py
@@ -58,7 +58,7 @@ def test_concatenate(target, expected_implementation):
assert impl.name == expected_implementation
-def _get_conv2d_impl(dtype, target):
+def _get_conv2d_impl(in_dtype, out_dtype, target):
"""Returns selected conv2d implementation for a given datatype and
target"""
data_shape = (1, 1, 1, 4)
weight_shape = (1, 1, 4, 4)
@@ -68,21 +68,24 @@ def _get_conv2d_impl(dtype, target):
kernel_size = (1, 1)
out = relay.nn.conv2d(
- relay.var("data", shape=data_shape, dtype=dtype),
- relay.var("weight", shape=weight_shape, dtype=dtype),
+ relay.var("data", shape=data_shape, dtype=in_dtype),
+ relay.var("weight", shape=weight_shape, dtype=in_dtype),
kernel_size=kernel_size,
channels=channels,
data_layout=data_layout,
kernel_layout=kernel_layout,
- out_dtype=dtype,
+ out_dtype=out_dtype,
)
with target:
out = run_opt_pass(out, relay.transform.AlterOpLayout())
+ data_shape = out.type_args[0].shape
+ weight_shape = out.type_args[1].shape
+
impl, _ = relay.backend.te_compiler.select_implementation(
out.op,
out.attrs,
- [te.placeholder(data_shape, dtype), te.placeholder(weight_shape,
dtype)],
+ [te.placeholder(data_shape, in_dtype),
te.placeholder(weight_shape, in_dtype)],
out.checked_type,
target,
use_autotvm=False,
@@ -131,7 +134,7 @@ def test_int8_conv2d(target, expected_impl):
target = tvm.target.Target(target)
dtype = "int8"
- selected_impl = _get_conv2d_impl(dtype, target)
+ selected_impl = _get_conv2d_impl(dtype, dtype, target)
assert selected_impl == expected_impl
@@ -171,7 +174,7 @@ def test_fp32_conv2d(target, expected_impl):
target = tvm.target.Target(target)
dtype = "float32"
- selected_impl = _get_conv2d_impl(dtype, target)
+ selected_impl = _get_conv2d_impl(dtype, dtype, target)
assert selected_impl == expected_impl
@@ -211,7 +214,48 @@ def test_fp16_conv2d(target, expected_impl):
target = tvm.target.Target(target)
dtype = "float16"
- selected_impl = _get_conv2d_impl(dtype, target)
+ selected_impl = _get_conv2d_impl(dtype, dtype, target)
+ assert selected_impl == expected_impl
+
+
+@pytest.mark.skipif(
+ llvm_version_major() < 15, reason=f"Requires LLVM 15+, got
{llvm_version_major()}"
+)
+@pytest.mark.parametrize(
+ "target,expected_impl",
+ [
+ (
+ "llvm -device=arm_cpu -mtriple=armv8l-linux-gnu -mattr=+neon",
+ "conv2d_nhwc_spatial_pack.arm_cpu",
+ ),
+ (
+ "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu",
+ "conv2d_NHWC_hybrid_without_transform.arm_cpu",
+ ),
+ (
+ "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon",
+ "conv2d_NHWC_hybrid_without_transform.arm_cpu",
+ ),
+ (
+ "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu
-mattr=+v8.2a,+neon",
+ "conv2d_NHWC_hybrid_without_transform.arm_cpu",
+ ),
+ (
+ "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v9a",
+ "conv2d_NHWC_hybrid_without_transform.arm_cpu",
+ ),
+ (
+ "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu
-mattr=+v9.2a,+sme",
+ "conv2d_NHWC_hybrid_SME_transposed_B.arm_cpu",
+ ),
+ ],
+)
+def test_fp16_to_fp32_conv2d(target, expected_impl):
+ target = tvm.target.Target(target)
+ in_dtype = "float16"
+ out_dtype = "float32"
+
+ selected_impl = _get_conv2d_impl(in_dtype, out_dtype, target)
assert selected_impl == expected_impl
diff --git a/tests/python/topi/test_topi_conv2d_nhwc.py
b/tests/python/topi/test_topi_conv2d_nhwc.py
index 02f16b59c0..d46db1b28b 100644
--- a/tests/python/topi/test_topi_conv2d_nhwc.py
+++ b/tests/python/topi/test_topi_conv2d_nhwc.py
@@ -68,7 +68,7 @@ device = tvm.testing.parameter(
False,
),
(
- "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a",
+ "llvm --device arm_cpu --mtriple aarch64-linux-gnu
-mattr=+v8.2a,+fullfp16",
topi.arm_cpu.compute_conv2d_NHWC_hybrid,
topi.arm_cpu.schedule_conv2d_NHWC_hybrid_TIR,
True,
@@ -173,13 +173,14 @@ def test_conv2d_nhwc_gemm(device, ref_data, dtype,
stride, padding, dilation):
if target.features.has_sme and llvm_version_major() < 16:
pytest.skip(f"LLVM {llvm_version_major()} does not support targetting
SME.")
- if target.features.has_sme and dtype == "float16":
- pytest.skip(f"Conv2d fp16 targetting SME not implemented.")
+ # SME schedule always outputs float32 results, regardless of input dtype.
+ # Otherwise, output dtype is the same as input dtype.
+ out_dtype = "float32" if target.features.has_sme else dtype
with target:
a = tvm.nd.array(a_np, dev)
w = tvm.nd.array(w_np, dev)
- B = compute(A, W, stride, padding, dilation, dtype)
+ B = compute(A, W, stride, padding, dilation, out_dtype)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype),
dev)
if use_tir_schedule:
primfunc = te.create_prim_func([A, W, B])
@@ -190,22 +191,22 @@ def test_conv2d_nhwc_gemm(device, ref_data, dtype,
stride, padding, dilation):
func = tvm.build(s, [A, W, B], target)
# Run only on AArch64 devices
- # Do not run SVE schedules on non-SVE devices
+ # Do not run SVE/SME schedules on non-SVE/SME devices
build_only = (
platform.machine() != "aarch64"
- or (target.features.has_sve and not
tvm.testing.requires_aarch64_sve.run_time_check())
or (
dtype == "float16"
and target.features.has_fp16_simd
and not tvm.testing.requires_arm_fp16.run_time_check()
)
+ or (target.features.has_sve and not
tvm.testing.requires_aarch64_sve.run_time_check())
or (target.features.has_sme and not
tvm.testing.requires_aarch64_sme.run_time_check())
)
if build_only:
return
func(a, w, b)
- tol = get_tolerance(dtype, w_np, b_np)
+ tol = get_tolerance(out_dtype, w_np, b_np)
tvm.testing.assert_allclose(b.numpy(), b_np, rtol=tol["rtol"],
atol=tol["atol"])