Anndrey24 commented on code in PR #17091:
URL: https://github.com/apache/tvm/pull/17091#discussion_r1638442104


##########
python/tvm/topi/arm_cpu/dense_alter_op.py:
##########
@@ -82,6 +83,31 @@ def _alter_dense(attrs, inputs, tinfos, out_type):
             False,
             transpose_b,
         )
+    elif topi_impl == "dense_gemm.arm_cpu":
+        # Pre-compute transposed weights and convert to a matmul
+        assert isinstance(
+            inputs[1], relay.Constant
+        ), "dense_gemm.arm_cpu requires weights be a Relay Constant"
+
+        weight_dtype = tinfos[1].dtype
+        weight_data = inputs[1].data.numpy()
+        interleaved = weight_data.transpose()
+        encoded_weight = relay.const(interleaved, weight_dtype)

Review Comment:
   It seems that some models fail at this check because their weights are not 
`relay.Constant`, but rather `relay.Call` (e.g. `inputs[1] = 
CallNode(Op(transpose), [CallNode(Op(reshape), ...`).  
   
   Instead of retrieving `data.numpy()` to perform the transpose 
here, which requires the weights to be a `relay.Constant`, I 
think we could introduce the transposition at the Relay level and give the compiler 
the opportunity to optimise it (e.g. by fusing it into the previous operation) using 
something like: `encoded_weight = relay.transpose(inputs[1])`. What do you 
think?



##########
cmake/config.cmake:
##########


Review Comment:
   I assume this was an accidental change?



##########
python/tvm/relay/op/strategy/arm_cpu.py:
##########
@@ -773,6 +784,18 @@ def matmul_strategy_arm_cpu(attrs, inputs, out_type, 
target):
             lambda: None,
             name="matmul.arm_cpu.sme",
         )
+    elif (
+        data.dtype in ["float16", "float32"]
+        and weight.dtype in ["float16", "float32"]
+        and out_type.dtype in ["float16", "float32"]
+        and not (attrs.transpose_a or attrs.transpose_b)
+        and len(data.shape) == 2
+    ):

Review Comment:
   Same comment about a `target.features.is_aarch64` condition here too.



##########
tests/python/relay/test_dense.py:
##########
@@ -0,0 +1,49 @@
+import tvm
+from tvm import relay
+from tvm.testing import assert_allclose
+import numpy as np
+from tvm.ir.instrument import pass_instrument
+
+
+def _test_accuracy(input_values, output_values, build_mod):
+
+    dev = tvm.cpu(0)
+
+    input_buf = tvm.nd.array(input_values, device=dev)
+    rt = tvm.contrib.graph_executor.GraphModule(build_mod["default"](dev))
+    rt.set_input("data", input_buf)
+    rt.run()
+    out = rt.get_output(0)
+
+    tvm.testing.assert_allclose(out.numpy(), output_values)
+
+
+# Define input shape and data type
+data_size = (64, 64)

Review Comment:
   We should probably add a few more test cases for different sizes.



##########
tests/python/relay/test_dense.py:
##########
@@ -0,0 +1,49 @@
+import tvm
+from tvm import relay
+from tvm.testing import assert_allclose
+import numpy as np
+from tvm.ir.instrument import pass_instrument
+
+
+def _test_accuracy(input_values, output_values, build_mod):
+
+    dev = tvm.cpu(0)
+
+    input_buf = tvm.nd.array(input_values, device=dev)
+    rt = tvm.contrib.graph_executor.GraphModule(build_mod["default"](dev))
+    rt.set_input("data", input_buf)
+    rt.run()
+    out = rt.get_output(0)
+
+    tvm.testing.assert_allclose(out.numpy(), output_values)
+
+
+# Define input shape and data type
+data_size = (64, 64)
+data_shape = data_size  # Input shape
+data_type = "float32"  # Data type
+weight_shape = data_size
+
+# Create Relay input variable
+d = relay.var("data", shape=data_shape, dtype=data_type)
+w1 = np.ones(weight_shape, dtype=data_type)
+w = relay.const(w1)
+
+# Create Relay dense layer
+y = relay.nn.dense(d, w)
+
+# Create Relay module
+mod = tvm.IRModule()
+
+# Define a Relay function with the dense layer
+mod["main"] = relay.Function([d], y)
+
+# Compile the Relay module
+target = "llvm -mtriple=aarch64-linux-gnu -device=arm_cpu -mattr=+v8.2a,+neon" 
 # Example target, you can change this to your desired target
+lib = relay.build(mod, target=target, params=None)

Review Comment:
   To differentiate this Relay test from the TOPI one, I think it would be good 
to make sure the AlterOpLayout pass runs, either by calling it manually as you 
did 
[here](https://github.com/apache/tvm/pull/17091/files#diff-1acd41854c82cb9d55e9e17c34c5f972e98a6a6c517dcf6f448d39acef293278R1500),
 or by adding a pass context with `opt_level=3` like 
[here](https://github.com/apache/tvm/blob/main/tests/python/relay/strategy/arm_cpu/test_dense.py#L146-L147).
  
   
   That way we're also testing that the AlterOpLayout pass has not made the 
computation incorrect.



##########
python/tvm/topi/arm_cpu/dense_gemm.py:
##########
@@ -0,0 +1,157 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable, too-many-locals
+"""GEMM Convolution schedule on AArch64"""
+import tvm
+from tvm.target import Target
+from tvm import te
+from tvm.topi import nn
+from tvm.topi.arm_cpu.arm_utils import get_tiling_A, get_tiling_B_transformed
+from ..utils import get_const_tuple, traverse_inline
+from ..nn.utils import get_pad_tuple
+from .. import tag
+
+# Compute function
+def dense_gemm_compute(
+    cfg, data, weight, bias=None, out_dtype=None, transpose_a=False, 
transpose_b=True
+):
+    """
+    Compute dense using GeMM.
+
+    transpose_b : Optional[bool] = True
+    Whether the weight tensor is in transposed format.
+    """
+
+    if out_dtype is None:
+        out_dtype = data.dtype
+    M, K = get_const_tuple(data.shape)  # batch, in_dim
+    if bool(transpose_b):  # out_dim
+        (N, _) = get_const_tuple(weight.shape)
+    else:
+        (_, N) = get_const_tuple(weight.shape)
+
+    in_dtype = data.dtype
+
+    tile_M, tile_K_A = get_tiling_A(False, in_dtype)
+    tile_N, tile_K_B = get_tiling_B_transformed(False, out_dtype, False)
+
+    pad_M = 0
+    pad_K = 0
+    pad_N = 0
+
+    if M % tile_M != 0:
+        pad_M = tile_M - (M % tile_M)
+
+    if K % tile_K_A != 0:
+        pad_K = tile_K_A - (K % tile_K_A)
+
+    M_padded = M + pad_M
+    K_padded = K + pad_K

Review Comment:
   We can replace these lines with calls to `arm_utils.pad_dim_to_multiple()` 
like in the [SME 
schedule](https://github.com/apache/tvm/blob/main/python/tvm/topi/arm_cpu/matmul.py#L56-L58)
 to make the schedule easier to read.



##########
python/tvm/topi/arm_cpu/dense_gemm.py:
##########
@@ -0,0 +1,157 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable, too-many-locals
+"""GEMM Convolution schedule on AArch64"""

Review Comment:
   Nit: "Dense" instead of "Convolution"



##########
python/tvm/topi/arm_cpu/dense_gemm.py:
##########
@@ -0,0 +1,157 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable, too-many-locals
+"""GEMM Convolution schedule on AArch64"""
+import tvm
+from tvm.target import Target
+from tvm import te
+from tvm.topi import nn
+from tvm.topi.arm_cpu.arm_utils import get_tiling_A, get_tiling_B_transformed
+from ..utils import get_const_tuple, traverse_inline
+from ..nn.utils import get_pad_tuple
+from .. import tag
+
+# Compute function
+def dense_gemm_compute(
+    cfg, data, weight, bias=None, out_dtype=None, transpose_a=False, 
transpose_b=True
+):
+    """
+    Compute dense using GeMM.
+
+    transpose_b : Optional[bool] = True
+    Whether the weight tensor is in transposed format.

Review Comment:
   Nit: We should probably either include all function parameters and return 
variables in the docstring or none at all.



##########
python/tvm/relay/op/strategy/arm_cpu.py:
##########
@@ -729,6 +729,17 @@ def schedule_dense_arm_cpu(attrs, inputs, out_type, 
target):
             plevel=12,
         )
 
+    if (
+        data.dtype in ["float16", "float32"]
+        and weight.dtype in ["float16", "float32"]
+        and out_type.dtype in ["float16", "float32"]
+    ):

Review Comment:
   We should probably also guard the selection of this strategy by a 
`target.features.is_aarch64` condition to make sure the target can run the SIMD 
instructions.



##########
python/tvm/topi/arm_cpu/dense_gemm.py:
##########
@@ -0,0 +1,157 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable, too-many-locals
+"""GEMM Convolution schedule on AArch64"""
+import tvm
+from tvm.target import Target
+from tvm import te
+from tvm.topi import nn
+from tvm.topi.arm_cpu.arm_utils import get_tiling_A, get_tiling_B_transformed
+from ..utils import get_const_tuple, traverse_inline
+from ..nn.utils import get_pad_tuple
+from .. import tag
+
+# Compute function
+def dense_gemm_compute(
+    cfg, data, weight, bias=None, out_dtype=None, transpose_a=False, 
transpose_b=True
+):
+    """
+    Compute dense using GeMM.
+
+    transpose_b : Optional[bool] = True
+    Whether the weight tensor is in transposed format.
+    """
+
+    if out_dtype is None:
+        out_dtype = data.dtype
+    M, K = get_const_tuple(data.shape)  # batch, in_dim
+    if bool(transpose_b):  # out_dim
+        (N, _) = get_const_tuple(weight.shape)
+    else:
+        (_, N) = get_const_tuple(weight.shape)
+
+    in_dtype = data.dtype
+
+    tile_M, tile_K_A = get_tiling_A(False, in_dtype)
+    tile_N, tile_K_B = get_tiling_B_transformed(False, out_dtype, False)

Review Comment:
   We should consistently pass either `in_dtype` or `out_dtype` when selecting 
a tiling strategy.  
   Based on the compute definition for `C` later on, passing `out_dtype` would 
make sense here, as that is the dtype the computation is performed in.



##########
python/tvm/topi/arm_cpu/dense_gemm.py:
##########
@@ -0,0 +1,157 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable, too-many-locals
+"""GEMM Convolution schedule on AArch64"""
+import tvm
+from tvm.target import Target
+from tvm import te
+from tvm.topi import nn
+from tvm.topi.arm_cpu.arm_utils import get_tiling_A, get_tiling_B_transformed
+from ..utils import get_const_tuple, traverse_inline
+from ..nn.utils import get_pad_tuple
+from .. import tag
+
+# Compute function
+def dense_gemm_compute(
+    cfg, data, weight, bias=None, out_dtype=None, transpose_a=False, 
transpose_b=True
+):
+    """
+    Compute dense using GeMM.
+
+    transpose_b : Optional[bool] = True
+    Whether the weight tensor is in transposed format.
+    """
+
+    if out_dtype is None:
+        out_dtype = data.dtype
+    M, K = get_const_tuple(data.shape)  # batch, in_dim
+    if bool(transpose_b):  # out_dim
+        (N, _) = get_const_tuple(weight.shape)
+    else:
+        (_, N) = get_const_tuple(weight.shape)
+
+    in_dtype = data.dtype
+
+    tile_M, tile_K_A = get_tiling_A(False, in_dtype)
+    tile_N, tile_K_B = get_tiling_B_transformed(False, out_dtype, False)
+
+    pad_M = 0
+    pad_K = 0
+    pad_N = 0
+
+    if M % tile_M != 0:
+        pad_M = tile_M - (M % tile_M)
+
+    if K % tile_K_A != 0:
+        pad_K = tile_K_A - (K % tile_K_A)
+
+    M_padded = M + pad_M
+    K_padded = K + pad_K
+    k = te.reduce_axis((0, K_padded), name="k")
+
+    pad_before = (0, 0)
+    pad_after = (pad_M, pad_K)
+
+    if pad_K != 0:
+        data = nn.pad(data, pad_before=pad_before, pad_after=pad_after, 
name="A_padded_K")
+    elif pad_M != 0:
+        data = nn.pad(data, pad_before=pad_before, pad_after=pad_after, 
name="A_padded_M")
+
+    if N % tile_N != 0:
+        pad_N = tile_N - (N % tile_N)
+    N_padded = N + pad_N

Review Comment:
   Same comment about replacing with an `arm_utils.pad_dim_to_multiple()` call 
and maybe grouping it alongside the others above.



##########
python/tvm/topi/arm_cpu/dense_gemm.py:
##########
@@ -0,0 +1,157 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable, too-many-locals
+"""GEMM Convolution schedule on AArch64"""
+import tvm
+from tvm.target import Target
+from tvm import te
+from tvm.topi import nn
+from tvm.topi.arm_cpu.arm_utils import get_tiling_A, get_tiling_B_transformed
+from ..utils import get_const_tuple, traverse_inline
+from ..nn.utils import get_pad_tuple
+from .. import tag
+
+# Compute function
+def dense_gemm_compute(
+    cfg, data, weight, bias=None, out_dtype=None, transpose_a=False, 
transpose_b=True
+):
+    """
+    Compute dense using GeMM.
+
+    transpose_b : Optional[bool] = True
+    Whether the weight tensor is in transposed format.
+    """
+
+    if out_dtype is None:
+        out_dtype = data.dtype
+    M, K = get_const_tuple(data.shape)  # batch, in_dim
+    if bool(transpose_b):  # out_dim
+        (N, _) = get_const_tuple(weight.shape)
+    else:
+        (_, N) = get_const_tuple(weight.shape)
+
+    in_dtype = data.dtype
+
+    tile_M, tile_K_A = get_tiling_A(False, in_dtype)
+    tile_N, tile_K_B = get_tiling_B_transformed(False, out_dtype, False)
+
+    pad_M = 0
+    pad_K = 0
+    pad_N = 0
+
+    if M % tile_M != 0:
+        pad_M = tile_M - (M % tile_M)
+
+    if K % tile_K_A != 0:
+        pad_K = tile_K_A - (K % tile_K_A)
+
+    M_padded = M + pad_M
+    K_padded = K + pad_K
+    k = te.reduce_axis((0, K_padded), name="k")
+
+    pad_before = (0, 0)
+    pad_after = (pad_M, pad_K)
+
+    if pad_K != 0:
+        data = nn.pad(data, pad_before=pad_before, pad_after=pad_after, 
name="A_padded_K")
+    elif pad_M != 0:
+        data = nn.pad(data, pad_before=pad_before, pad_after=pad_after, 
name="A_padded_M")
+
+    if N % tile_N != 0:
+        pad_N = tile_N - (N % tile_N)
+    N_padded = N + pad_N
+
+    if bool(transpose_b):
+        weight = te.compute(
+            (K_padded, N_padded), lambda x, y: weight[y, x], 
name="weight_transposed"
+        )
+
+    if pad_K != 0 or pad_N != 0:
+        weight = nn.pad(weight, pad_before=(0, 0), pad_after=(pad_N, pad_K), 
name="weight_padded")
+
+    C = te.compute(
+        (M_padded, N_padded),
+        lambda x, y: te.sum(
+            data[x, k].astype(out_dtype) * weight[k, y].astype(out_dtype),
+            axis=k,
+        ).astype(out_dtype),
+        name="C",
+    )
+
+    if bias is not None:
+        C = te.compute(
+            (M_padded, N_padded),
+            lambda i, j: C[i, j] + bias[j].astype(out_dtype),
+            tag=tag.BROADCAST,
+            name="dense_biased_output",
+        )
+
+    zero = (

Review Comment:
   Nit: We should include a 
[comment](https://github.com/apache/tvm/blob/main/python/tvm/topi/arm_cpu/conv2d_gemm.py#L313-L314)
 explaining what this `zero` variable is for, since it looks quite odd 
otherwise.



##########
python/tvm/topi/arm_cpu/dense_gemm.py:
##########
@@ -0,0 +1,157 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable, too-many-locals
+"""GEMM Convolution schedule on AArch64"""
+import tvm
+from tvm.target import Target
+from tvm import te
+from tvm.topi import nn
+from tvm.topi.arm_cpu.arm_utils import get_tiling_A, get_tiling_B_transformed
+from ..utils import get_const_tuple, traverse_inline
+from ..nn.utils import get_pad_tuple
+from .. import tag
+
+# Compute function
+def dense_gemm_compute(
+    cfg, data, weight, bias=None, out_dtype=None, transpose_a=False, 
transpose_b=True
+):
+    """
+    Compute dense using GeMM.
+
+    transpose_b : Optional[bool] = True
+    Whether the weight tensor is in transposed format.
+    """
+
+    if out_dtype is None:
+        out_dtype = data.dtype
+    M, K = get_const_tuple(data.shape)  # batch, in_dim
+    if bool(transpose_b):  # out_dim
+        (N, _) = get_const_tuple(weight.shape)
+    else:
+        (_, N) = get_const_tuple(weight.shape)
+
+    in_dtype = data.dtype
+
+    tile_M, tile_K_A = get_tiling_A(False, in_dtype)
+    tile_N, tile_K_B = get_tiling_B_transformed(False, out_dtype, False)
+
+    pad_M = 0
+    pad_K = 0
+    pad_N = 0
+
+    if M % tile_M != 0:
+        pad_M = tile_M - (M % tile_M)
+
+    if K % tile_K_A != 0:
+        pad_K = tile_K_A - (K % tile_K_A)
+
+    M_padded = M + pad_M
+    K_padded = K + pad_K
+    k = te.reduce_axis((0, K_padded), name="k")
+
+    pad_before = (0, 0)
+    pad_after = (pad_M, pad_K)
+
+    if pad_K != 0:
+        data = nn.pad(data, pad_before=pad_before, pad_after=pad_after, 
name="A_padded_K")
+    elif pad_M != 0:
+        data = nn.pad(data, pad_before=pad_before, pad_after=pad_after, 
name="A_padded_M")
+
+    if N % tile_N != 0:
+        pad_N = tile_N - (N % tile_N)
+    N_padded = N + pad_N
+
+    if bool(transpose_b):
+        weight = te.compute(
+            (K_padded, N_padded), lambda x, y: weight[y, x], 
name="weight_transposed"
+        )
+
+    if pad_K != 0 or pad_N != 0:
+        weight = nn.pad(weight, pad_before=(0, 0), pad_after=(pad_N, pad_K), 
name="weight_padded")
+
+    C = te.compute(
+        (M_padded, N_padded),
+        lambda x, y: te.sum(
+            data[x, k].astype(out_dtype) * weight[k, y].astype(out_dtype),
+            axis=k,
+        ).astype(out_dtype),
+        name="C",
+    )
+
+    if bias is not None:
+        C = te.compute(
+            (M_padded, N_padded),
+            lambda i, j: C[i, j] + bias[j].astype(out_dtype),
+            tag=tag.BROADCAST,
+            name="dense_biased_output",
+        )
+
+    zero = (
+        tvm.tir.const(1, C.dtype) * C[0, N_padded - 1]
+        - tvm.tir.const(1, C.dtype) * C[0, N_padded - 1]
+    )
+
+    out = te.compute(
+        (M, N), lambda x, y: (C[x, y] + zero).astype(out_dtype), 
name="dense_gemm_output"
+    )
+
+    return out
+
+
+def _dense_gemm_schedule_template(s, out):

Review Comment:
   Just curious, why the "_template" suffix?



##########
python/tvm/topi/arm_cpu/dense_gemm.py:
##########
@@ -0,0 +1,157 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable, too-many-locals
+"""GEMM Convolution schedule on AArch64"""
+import tvm
+from tvm.target import Target
+from tvm import te
+from tvm.topi import nn
+from tvm.topi.arm_cpu.arm_utils import get_tiling_A, get_tiling_B_transformed
+from ..utils import get_const_tuple, traverse_inline
+from ..nn.utils import get_pad_tuple
+from .. import tag
+
+# Compute function
+def dense_gemm_compute(
+    cfg, data, weight, bias=None, out_dtype=None, transpose_a=False, 
transpose_b=True
+):
+    """
+    Compute dense using GeMM.
+
+    transpose_b : Optional[bool] = True
+    Whether the weight tensor is in transposed format.
+    """
+
+    if out_dtype is None:
+        out_dtype = data.dtype
+    M, K = get_const_tuple(data.shape)  # batch, in_dim
+    if bool(transpose_b):  # out_dim
+        (N, _) = get_const_tuple(weight.shape)
+    else:
+        (_, N) = get_const_tuple(weight.shape)
+
+    in_dtype = data.dtype
+
+    tile_M, tile_K_A = get_tiling_A(False, in_dtype)
+    tile_N, tile_K_B = get_tiling_B_transformed(False, out_dtype, False)
+
+    pad_M = 0
+    pad_K = 0
+    pad_N = 0
+
+    if M % tile_M != 0:
+        pad_M = tile_M - (M % tile_M)
+
+    if K % tile_K_A != 0:
+        pad_K = tile_K_A - (K % tile_K_A)
+
+    M_padded = M + pad_M
+    K_padded = K + pad_K
+    k = te.reduce_axis((0, K_padded), name="k")
+
+    pad_before = (0, 0)
+    pad_after = (pad_M, pad_K)
+
+    if pad_K != 0:
+        data = nn.pad(data, pad_before=pad_before, pad_after=pad_after, 
name="A_padded_K")
+    elif pad_M != 0:
+        data = nn.pad(data, pad_before=pad_before, pad_after=pad_after, 
name="A_padded_M")

Review Comment:
   Since we don't treat "A_padded_K" and "A_padded_M" differently in the 
schedule, I think we can merge the two cases into a single "A_padded" 
definition like you did [for the 
weights](https://github.com/apache/tvm/pull/17091/files#diff-e6cbc7c422b4f68f0b7a793ca4e158929b4366eaa92700c32116f08b643d8f3dR83-R84).



##########
python/tvm/topi/arm_cpu/dense_gemm.py:
##########
@@ -0,0 +1,157 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable, too-many-locals
+"""GEMM Convolution schedule on AArch64"""
+import tvm
+from tvm.target import Target
+from tvm import te
+from tvm.topi import nn
+from tvm.topi.arm_cpu.arm_utils import get_tiling_A, get_tiling_B_transformed
+from ..utils import get_const_tuple, traverse_inline
+from ..nn.utils import get_pad_tuple
+from .. import tag
+
+# Compute function
+def dense_gemm_compute(
+    cfg, data, weight, bias=None, out_dtype=None, transpose_a=False, 
transpose_b=True
+):
+    """
+    Compute dense using GeMM.
+
+    transpose_b : Optional[bool] = True
+    Whether the weight tensor is in transposed format.
+    """
+
+    if out_dtype is None:
+        out_dtype = data.dtype
+    M, K = get_const_tuple(data.shape)  # batch, in_dim
+    if bool(transpose_b):  # out_dim
+        (N, _) = get_const_tuple(weight.shape)
+    else:
+        (_, N) = get_const_tuple(weight.shape)
+
+    in_dtype = data.dtype
+
+    tile_M, tile_K_A = get_tiling_A(False, in_dtype)
+    tile_N, tile_K_B = get_tiling_B_transformed(False, out_dtype, False)
+
+    pad_M = 0
+    pad_K = 0
+    pad_N = 0
+
+    if M % tile_M != 0:
+        pad_M = tile_M - (M % tile_M)
+
+    if K % tile_K_A != 0:
+        pad_K = tile_K_A - (K % tile_K_A)
+
+    M_padded = M + pad_M
+    K_padded = K + pad_K
+    k = te.reduce_axis((0, K_padded), name="k")
+
+    pad_before = (0, 0)
+    pad_after = (pad_M, pad_K)
+
+    if pad_K != 0:
+        data = nn.pad(data, pad_before=pad_before, pad_after=pad_after, 
name="A_padded_K")
+    elif pad_M != 0:
+        data = nn.pad(data, pad_before=pad_before, pad_after=pad_after, 
name="A_padded_M")
+
+    if N % tile_N != 0:
+        pad_N = tile_N - (N % tile_N)
+    N_padded = N + pad_N
+
+    if bool(transpose_b):
+        weight = te.compute(
+            (K_padded, N_padded), lambda x, y: weight[y, x], 
name="weight_transposed"
+        )
+
+    if pad_K != 0 or pad_N != 0:
+        weight = nn.pad(weight, pad_before=(0, 0), pad_after=(pad_N, pad_K), 
name="weight_padded")
+
+    C = te.compute(
+        (M_padded, N_padded),
+        lambda x, y: te.sum(
+            data[x, k].astype(out_dtype) * weight[k, y].astype(out_dtype),
+            axis=k,
+        ).astype(out_dtype),
+        name="C",
+    )
+
+    if bias is not None:
+        C = te.compute(
+            (M_padded, N_padded),
+            lambda i, j: C[i, j] + bias[j].astype(out_dtype),
+            tag=tag.BROADCAST,
+            name="dense_biased_output",
+        )
+
+    zero = (
+        tvm.tir.const(1, C.dtype) * C[0, N_padded - 1]
+        - tvm.tir.const(1, C.dtype) * C[0, N_padded - 1]
+    )
+
+    out = te.compute(
+        (M, N), lambda x, y: (C[x, y] + zero).astype(out_dtype), 
name="dense_gemm_output"
+    )
+
+    return out
+
+
+def _dense_gemm_schedule_template(s, out):
+    C = out.op.input_tensors[0]
+    A = C.op.input_tensors[0]
+    in_type = A.dtype
+    y_tile_size, _ = get_tiling_B_transformed(False, in_type)
+    if C.op.name == "dense_biased_output":
+        s[C].compute_inline()
+        C = C.op.input_tensors[0]
+    x, y = s[C].op.axis
+    (k,) = s[C].op.reduce_axis
+    k_outer, k_inner = s[C].split(k, factor=4)
+    x_outer, x_inner = s[C].split(x, factor=4)
+    y_outer, y_inner = s[C].split(y, factor=y_tile_size)

Review Comment:
   When I tried simplifying the conv2d schedule to only these 3 splits, I 
noticed a performance regression in some cases 
(https://github.com/apache/tvm/pull/16951), so I think it's best to stay on the 
safe side and use something like the version with 4 splits 
[here](https://github.com/apache/tvm/blob/main/python/tvm/topi/arm_cpu/conv2d_gemm.py#L491-L494).



##########
tests/python/relay/test_dense.py:
##########
@@ -0,0 +1,49 @@
+import tvm

Review Comment:
   This test could probably go alongside the SME one 
[here](https://github.com/apache/tvm/blob/main/tests/python/relay/strategy/arm_cpu/test_dense.py#L111),
 since it's `arm_cpu` specific, but I don't mind leaving it here either.



##########
python/tvm/topi/arm_cpu/dense_gemm.py:
##########
@@ -0,0 +1,157 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable, too-many-locals
+"""GEMM Convolution schedule on AArch64"""
+import tvm
+from tvm.target import Target
+from tvm import te
+from tvm.topi import nn
+from tvm.topi.arm_cpu.arm_utils import get_tiling_A, get_tiling_B_transformed
+from ..utils import get_const_tuple, traverse_inline
+from ..nn.utils import get_pad_tuple
+from .. import tag
+
+# Compute function
+def dense_gemm_compute(
+    cfg, data, weight, bias=None, out_dtype=None, transpose_a=False, 
transpose_b=True
+):
+    """
+    Compute dense using GeMM.
+
+    transpose_b : Optional[bool] = True
+    Whether the weight tensor is in transposed format.
+    """
+
+    if out_dtype is None:
+        out_dtype = data.dtype
+    M, K = get_const_tuple(data.shape)  # batch, in_dim
+    if bool(transpose_b):  # out_dim
+        (N, _) = get_const_tuple(weight.shape)
+    else:
+        (_, N) = get_const_tuple(weight.shape)
+
+    in_dtype = data.dtype
+
+    tile_M, tile_K_A = get_tiling_A(False, in_dtype)
+    tile_N, tile_K_B = get_tiling_B_transformed(False, out_dtype, False)
+
+    pad_M = 0
+    pad_K = 0
+    pad_N = 0
+
+    if M % tile_M != 0:
+        pad_M = tile_M - (M % tile_M)
+
+    if K % tile_K_A != 0:
+        pad_K = tile_K_A - (K % tile_K_A)
+
+    M_padded = M + pad_M
+    K_padded = K + pad_K
+    k = te.reduce_axis((0, K_padded), name="k")
+
+    pad_before = (0, 0)
+    pad_after = (pad_M, pad_K)
+
+    if pad_K != 0:
+        data = nn.pad(data, pad_before=pad_before, pad_after=pad_after, 
name="A_padded_K")
+    elif pad_M != 0:
+        data = nn.pad(data, pad_before=pad_before, pad_after=pad_after, 
name="A_padded_M")
+
+    if N % tile_N != 0:
+        pad_N = tile_N - (N % tile_N)
+    N_padded = N + pad_N
+
+    if bool(transpose_b):
+        weight = te.compute(
+            (K_padded, N_padded), lambda x, y: weight[y, x], 
name="weight_transposed"
+        )
+
+    if pad_K != 0 or pad_N != 0:
+        weight = nn.pad(weight, pad_before=(0, 0), pad_after=(pad_N, pad_K), 
name="weight_padded")
+
+    C = te.compute(
+        (M_padded, N_padded),
+        lambda x, y: te.sum(
+            data[x, k].astype(out_dtype) * weight[k, y].astype(out_dtype),
+            axis=k,
+        ).astype(out_dtype),
+        name="C",
+    )
+
+    if bias is not None:
+        C = te.compute(
+            (M_padded, N_padded),
+            lambda i, j: C[i, j] + bias[j].astype(out_dtype),
+            tag=tag.BROADCAST,
+            name="dense_biased_output",
+        )
+
+    zero = (
+        tvm.tir.const(1, C.dtype) * C[0, N_padded - 1]
+        - tvm.tir.const(1, C.dtype) * C[0, N_padded - 1]
+    )
+
+    out = te.compute(
+        (M, N), lambda x, y: (C[x, y] + zero).astype(out_dtype), 
name="dense_gemm_output"
+    )
+
+    return out
+
+
+def _dense_gemm_schedule_template(s, out):
+    C = out.op.input_tensors[0]
+    A = C.op.input_tensors[0]
+    in_type = A.dtype
+    y_tile_size, _ = get_tiling_B_transformed(False, in_type)

Review Comment:
   We should pass `out_dtype` here as well (to match the compute definition).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to