This is an automated email from the ASF dual-hosted git repository.

leandron pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 92eeef6  Calculate CMSIS-NN buffer size with respect to architecture extensions (#9338)
92eeef6 is described below

commit 92eeef6ddde9b113fd8263f1fc20f08c1c3fea0d
Author: Christopher Sidebottom <[email protected]>
AuthorDate: Wed Jan 5 10:14:55 2022 +0000

    Calculate CMSIS-NN buffer size with respect to architecture extensions (#9338)
    
    This correctly calculates the buffer sizes for a variety of targets
    based on the `-mcpu` and `-mattr` flags passed to the `cmsis-nn` code
    generator.
    
    Added for Conv2d, Depthwise Conv2d and Average Pool.
---
 cmake/modules/contrib/CMSISNN.cmake                |   1 +
 python/tvm/driver/tvmc/composite_target.py         |   2 +-
 src/relay/backend/contrib/cmsisnn/buffer_size.cc   |  78 ++++++++
 src/relay/backend/contrib/cmsisnn/buffer_size.h    |  94 ++++++++++
 .../backend/contrib/cmsisnn/compiler_attrs.cc      |  75 ++++++++
 src/relay/backend/contrib/cmsisnn/compiler_attrs.h |  75 ++++++++
 src/relay/backend/contrib/cmsisnn/relay_to_tir.cc  |  33 +++-
 .../backend/contrib/cmsisnn/buffer_size_test.cc    | 206 +++++++++++++++++++++
 .../backend/contrib/cmsisnn/compiler_attrs_test.cc | 157 ++++++++++++++++
 tests/python/relay/aot/aot_test_utils.py           |   5 +
 10 files changed, 718 insertions(+), 8 deletions(-)

diff --git a/cmake/modules/contrib/CMSISNN.cmake b/cmake/modules/contrib/CMSISNN.cmake
index 50a3464..73ecd59 100644
--- a/cmake/modules/contrib/CMSISNN.cmake
+++ b/cmake/modules/contrib/CMSISNN.cmake
@@ -16,6 +16,7 @@
 # under the License.
 
 if(USE_CMSISNN)
+  add_definitions(-DTVM_USE_CMSISNN)
   message(STATUS "Build with CMSIS-NN support")
   tvm_file_glob(GLOB RELAY_CONTRIB_CMSISNN_SRCS src/relay/backend/contrib/cmsisnn/*.cc)
   list(APPEND COMPILER_SRCS ${RELAY_CONTRIB_CMSISNN_SRCS})
diff --git a/python/tvm/driver/tvmc/composite_target.py b/python/tvm/driver/tvmc/composite_target.py
index f347158..3b5ba9d 100644
--- a/python/tvm/driver/tvmc/composite_target.py
+++ b/python/tvm/driver/tvmc/composite_target.py
@@ -53,7 +53,7 @@ REGISTERED_CODEGEN = {
         "pass_pipeline": partition_for_arm_compute_lib,
     },
     "cmsis-nn": {
-        "config_key": None,
+        "config_key": "relay.ext.cmsisnn.options",
         "pass_pipeline": partition_for_cmsisnn,
     },
     "ethos-n77": {
diff --git a/src/relay/backend/contrib/cmsisnn/buffer_size.cc b/src/relay/backend/contrib/cmsisnn/buffer_size.cc
new file mode 100644
index 0000000..2502a09
--- /dev/null
+++ b/src/relay/backend/contrib/cmsisnn/buffer_size.cc
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "buffer_size.h"
+
+#include <tvm/ir/attrs.h>
+#include <tvm/ir/transform.h>
+
+#include "compiler_attrs.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+int Conv2dBufferSize(CMSISNNFlags flags, int32_t padding_w, int32_t padding_h, int32_t input_n,
+                     int32_t input_h, int32_t input_c, int32_t output_h, int32_t output_w,
+                     int32_t stride_w, int32_t stride_h, int32_t filter_w, int32_t filter_h) {
+  bool is1x1 = (padding_w == 0) && (padding_h == 0) && (input_c % 4 == 0) && (stride_w == 1) &&
+               (stride_h == 1) && (filter_w == 1) && (filter_h == 1);
+  bool is1xN =
+      (output_h == 1) && (input_h == 1) && (filter_h == 1) && (output_w % 4 == 0) && (input_n == 1);
+
+  if (is1x1) {
+    return 0;
+  }
+
+  if (is1xN) {
+    if (flags.dsp && !flags.mve) {
+      return (2 * input_c * filter_w * filter_h) * (int32_t)sizeof(int16_t);
+    }
+    return 0;
+  }
+
+  if (flags.dsp) {
+    return (2 * input_c * filter_w * filter_h) * (int32_t)sizeof(int16_t);
+  }
+  return 0;
+}
+
+int DepthwiseConv2dBufferSize(CMSISNNFlags flags, int32_t input_n, int32_t input_c,
+                              int32_t output_c, int32_t filter_w, int32_t filter_h) {
+  if (input_c == output_c && input_n == 1) {
+    if (flags.mve) {
+      return (2 * input_c * filter_w * filter_h) * (int32_t)sizeof(int16_t) + 4;
+    }
+    if (flags.dsp) {
+      return (input_c * filter_w * filter_h) * (int32_t)sizeof(int16_t);
+    }
+  }
+  return 0;
+}
+
+int AvgPoolBufferSize(CMSISNNFlags flags, int32_t input_c) {
+  if (flags.dsp && !flags.mve) {
+    return input_c * (int32_t)sizeof(int32_t);
+  }
+  return 0;
+}
+
+}  // namespace cmsisnn
+}  // namespace contrib
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/backend/contrib/cmsisnn/buffer_size.h b/src/relay/backend/contrib/cmsisnn/buffer_size.h
new file mode 100644
index 0000000..dec3c3e
--- /dev/null
+++ b/src/relay/backend/contrib/cmsisnn/buffer_size.h
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/contrib/cmsisnn/buffer_size.h
+ * \brief CMSIS-NN Buffer Size calculation functions
+ */
+
+#ifndef TVM_RELAY_BACKEND_CONTRIB_CMSISNN_BUFFER_SIZE_H_
+#define TVM_RELAY_BACKEND_CONTRIB_CMSISNN_BUFFER_SIZE_H_
+
+#include <tvm/ir/transform.h>
+
+#include "compiler_attrs.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+/*!
+ * \brief Calculates the appropriate buffer size for CMSIS-NN Convolutions
+ * See:
+ * https://github.com/ARM-software/CMSIS_5/blob/8c60448c0e1e50e426180b26db9bc31ddf774361/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_wrapper_s8.c#L108-L127
+ *
+ * \param flags - CMSIS-NN feature flags
+ * \param padding_w - Width padding
+ * \param padding_h - Height padding
+ * \param input_n - Input batch size
+ * \param input_h - Input height
+ * \param input_c - Input channels
+ * \param output_h - Output height
+ * \param output_w - Output width
+ * \param stride_w - Stride width
+ * \param stride_h - Stride height
+ * \param filter_w - Filter width
+ * \param filter_h - Filter height
+ *
+ * \return Size of buffer to allocate for convolution
+ */
+int Conv2dBufferSize(CMSISNNFlags flags, int32_t padding_w, int32_t padding_h, int32_t input_n,
+                     int32_t input_h, int32_t input_c, int32_t output_h, int32_t output_w,
+                     int32_t stride_w, int32_t stride_h, int32_t filter_w, int32_t filter_h);
+
+/*!
+ * \brief Calculates the appropriate buffer size for CMSIS-NN Depthwise Convolutions
+ * See:
+ * https://github.com/ARM-software/CMSIS_5/blob/325443e52637b6c7eedbd160d238a6c462e89c9f/CMSIS/NN/Source/ConvolutionFunctions/arm_depthwise_conv_wrapper_s8.c#L115-L129
+ *
+ * \param flags - CMSIS-NN feature flags
+ * \param input_n - Input batch size
+ * \param input_c - Input channels
+ * \param output_c - Output channels
+ * \param filter_w - Filter width
+ * \param filter_h - Filter height
+ *
+ * \return Size of buffer to allocate for depthwise convolution
+ */
+int DepthwiseConv2dBufferSize(CMSISNNFlags flags, int32_t input_n, int32_t input_c,
+                              int32_t output_c, int32_t filter_w, int32_t filter_h);
+
+/*!
+ * \brief Calculates the appropriate buffer size for CMSIS-NN Average Pooling
+ * See:
+ * https://github.com/ARM-software/CMSIS_5/blob/bff28575f0c96a4ee9008947fea2b018a69b4900/CMSIS/NN/Source/PoolingFunctions/arm_avgpool_s8.c#L388-L398
+ *
+ * \param flags - CMSIS-NN feature flags
+ * \param input_c - Input channels
+ *
+ * \return Size of buffer to allocate for average pooling
+ */
+int AvgPoolBufferSize(CMSISNNFlags flags, int32_t input_c);
+
+}  // namespace cmsisnn
+}  // namespace contrib
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_BACKEND_CONTRIB_CMSISNN_BUFFER_SIZE_H_
diff --git a/src/relay/backend/contrib/cmsisnn/compiler_attrs.cc b/src/relay/backend/contrib/cmsisnn/compiler_attrs.cc
new file mode 100644
index 0000000..c6fb8f3
--- /dev/null
+++ b/src/relay/backend/contrib/cmsisnn/compiler_attrs.cc
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "compiler_attrs.h"
+
+#include <tvm/ir/attrs.h>
+#include <tvm/ir/transform.h>
+
+#include <algorithm>
+#include <string>
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+static const char* mveCPUs[] = {"cortex-m55"};
+static const char* dspCPUs[] = {"cortex-m7", "cortex-m33", "cortex-m35p"};
+
+TVM_REGISTER_NODE_TYPE(CMSISNNCompilerConfigNode);
+TVM_REGISTER_PASS_CONFIG_OPTION("relay.ext.cmsisnn.options", CMSISNNCompilerConfig);
+
+template <typename Container>
+static inline bool MatchesCpu(const std::string& mcpu, const Container& cpus) {
+  auto matches_cpu = [&mcpu](const char* cpu) { return mcpu.find(cpu) == 0; };
+  return std::find_if(std::begin(cpus), std::end(cpus), matches_cpu) != std::end(cpus);
+}
+
+static inline bool HasFlag(const std::string& attr, const std::string& flag) {
+  return attr.find(flag) != std::string::npos;
+}
+
+CMSISNNFlags GetCompilerFlags(const tvm::transform::PassContext& ctx) {
+  auto cfg = ctx->GetConfig<CMSISNNCompilerConfig>("relay.ext.cmsisnn.options");
+  if (!cfg.defined()) {
+    return kNoExt;
+  }
+
+  std::string mcpu = cfg.value()->mcpu;
+  std::string mattr = cfg.value()->mattr;
+
+  bool nomve = HasFlag(mcpu, "+nomve") || HasFlag(mattr, "+nomve");
+  bool nodsp = HasFlag(mcpu, "+nodsp") || HasFlag(mattr, "+nodsp");
+
+  auto has_mve = MatchesCpu(mcpu, mveCPUs);
+  if (has_mve && !nomve && !nodsp) {
+    return kHasMVE;
+  }
+
+  auto has_dsp = MatchesCpu(mcpu, dspCPUs);
+  if (has_dsp && !nodsp) {
+    return kHasDSP;
+  }
+
+  return kNoExt;
+}
+
+}  // namespace cmsisnn
+}  // namespace contrib
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/backend/contrib/cmsisnn/compiler_attrs.h b/src/relay/backend/contrib/cmsisnn/compiler_attrs.h
new file mode 100644
index 0000000..75005d9
--- /dev/null
+++ b/src/relay/backend/contrib/cmsisnn/compiler_attrs.h
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/contrib/cmsisnn/compiler_attrs.h
+ * \brief CMSIS-NN Compiler Attribute functionality
+ */
+
+#ifndef TVM_RELAY_BACKEND_CONTRIB_CMSISNN_COMPILER_ATTRS_H_
+#define TVM_RELAY_BACKEND_CONTRIB_CMSISNN_COMPILER_ATTRS_H_
+
+#include <tvm/ir/transform.h>
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+/*! \brief Attributes to store the compiler options for CMSIS-NN. */
+struct CMSISNNCompilerConfigNode : public tvm::AttrsNode<CMSISNNCompilerConfigNode> {
+  String mcpu;
+  String mattr;
+
+  TVM_DECLARE_ATTRS(CMSISNNCompilerConfigNode, "ext.attrs.CMSISNNCompilerConfigNode") {
+    TVM_ATTR_FIELD(mcpu)
+        .describe(
+            "The CPU to configure CMSIS-NN for (i.e. cortex-m55, cortex-m4), can also include "
+            "attributes (i.e. cortex-m55+nomve)")
+        .set_default("");
+    TVM_ATTR_FIELD(mattr)
+        .describe("The attributes to configure CMSIS-NN (i.e. +nodsp, +nomve)")
+        .set_default("");
+  }
+};
+
+class CMSISNNCompilerConfig : public Attrs {
+ public:
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(CMSISNNCompilerConfig, Attrs,
+                                            CMSISNNCompilerConfigNode);
+};
+
+/*! \brief Flags to configure the calculations for CMSIS-NN. */
+struct CMSISNNFlags {
+  bool dsp;  // Enable or disable dsp buffers
+  bool mve;  // Enable or disable mve buffers
+};
+
+constexpr CMSISNNFlags kNoExt = {/*dsp=*/false, /*mve=*/false};
+constexpr CMSISNNFlags kHasDSP = {/*dsp=*/true, /*mve=*/false};
+constexpr CMSISNNFlags kHasMVE = {/*dsp=*/true, /*mve=*/true};
+
+CMSISNNFlags GetCompilerFlags(const tvm::transform::PassContext& ctx);
+
+}  // namespace cmsisnn
+}  // namespace contrib
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_BACKEND_CONTRIB_CMSISNN_COMPILER_ATTRS_H_
diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
index 7f1582a..b874424 100644
--- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
+++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
@@ -29,6 +29,8 @@
 
 #include "../../../qnn/utils.h"
 #include "../../../transforms/pattern_utils.h"
+#include "buffer_size.h"
+#include "compiler_attrs.h"
 
 namespace tvm {
 namespace relay {
@@ -159,6 +161,9 @@ class RelayToTIRVisitor : public MixedModeMutator {
     // CMSIS-NN data structure "cmsis_nn_dims" for ifm expects input layout as NHWC
     // This is the same layout we expect in Relay
     Array<PrimExpr> input_shape = conv2d_call->args[0]->type_as<TensorTypeNode>()->shape;
+    int32_t input_n = qnn::get_const_int(input_shape[0]);
+    int32_t input_h = qnn::get_const_int(input_shape[1]);
+    int32_t input_c = qnn::get_const_int(input_shape[3]);
 
     // CMSIS-NN data structure "cmsis_nn_dims" for weights expects following layouts
     // OHWI for Conv2D and IHWO for Depthwise convolutions
@@ -167,6 +172,9 @@ class RelayToTIRVisitor : public MixedModeMutator {
     Array<PrimExpr> bias_shape{1, 1, 1, out_channels};
 
     Array<PrimExpr> output_shape = conv2d_call->type_as<TensorTypeNode>()->shape;
+    int32_t output_h = qnn::get_const_int(output_shape[1]);
+    int32_t output_w = qnn::get_const_int(output_shape[2]);
+    int32_t output_c = qnn::get_const_int(output_shape[3]);
 
     int32_t depth_multiplier = -1;
     int kernel_pos_o = kernel_layout.find("O");
@@ -179,7 +187,8 @@ class RelayToTIRVisitor : public MixedModeMutator {
 
     // original filter_layout for depthwise is HWOI
     std::string cmsisnn_api = "arm_convolve_wrapper_s8";
-    if (depth_multiplier != -1) {
+    bool is_depthwise = depth_multiplier != -1;
+    if (is_depthwise) {
       cmsisnn_api = "arm_depthwise_conv_wrapper_s8";
       int filter_pos_h = kernel_layout.find("H");
       int filter_pos_w = kernel_layout.find("W");
@@ -187,6 +196,8 @@ class RelayToTIRVisitor : public MixedModeMutator {
                                              filter_shape[filter_pos_w], out_channels};
       filter_shape = depthwise_filter_shape;
     }
+    int32_t filter_h = qnn::get_const_int(filter_shape[1]);
+    int32_t filter_w = qnn::get_const_int(filter_shape[2]);
 
     tvm::Array<PrimExpr> call_ext_args = {tir::StringImm(cmsisnn_api), input, filter, multiplier};
     if (bias_add_call) {
@@ -195,11 +206,18 @@ class RelayToTIRVisitor : public MixedModeMutator {
     call_ext_args.push_back(shift);
     call_ext_args.push_back(output);
 
-    // https://github.com/ARM-software/CMSIS_5/blob/d788fd583984388553391de18afd8b4d2a146868/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_s8.c#L367
     std::string context_buffer_name = "NULL";
-    size_t context_buffer_size =
-        (2 * qnn::get_const_int(input_shape[3]) * qnn::get_const_int(filter_shape[2]) *
-         qnn::get_const_int(filter_shape[1]) * (int32_t)sizeof(int16_t));
+    CMSISNNFlags flags = GetCompilerFlags(transform::PassContext::Current());
+    size_t context_buffer_size;
+    if (is_depthwise) {
+      context_buffer_size =
+          DepthwiseConv2dBufferSize(flags, input_n, input_c, output_c, filter_w, filter_h);
+    } else {
+      context_buffer_size =
+          Conv2dBufferSize(flags, padding_w, padding_h, input_n, input_h, input_c, output_h,
+                           output_w, stride_w, stride_h, filter_w, filter_h);
+    }
+
     if (context_buffer_size) {
       context_buffer_name = "context_buffer_" + std::to_string(context_buffer_id_++);
     }
@@ -397,8 +415,11 @@ class RelayToTIRVisitor : public MixedModeMutator {
     int context_buffer_size = 0;
     std::string context_buffer_name = "NULL";
     if (pool_name == "cmsisnn.qnn_avg_pool2d") {
-      // TODO(@Mousius): Need to move this into buffer_size calculations
-      context_buffer_size = qnn::get_const_int(input_shape[3]) * sizeof(int32_t);
-      context_buffer_name = "context_buffer_" + std::to_string(context_buffer_id_++);
+      CMSISNNFlags flags = GetCompilerFlags(transform::PassContext::Current());
+      int32_t input_c = qnn::get_const_int(input_shape[3]);
+      context_buffer_size = AvgPoolBufferSize(flags, input_c);
+      if (context_buffer_size) {
+        context_buffer_name = "context_buffer_" + std::to_string(context_buffer_id_++);
+      }
     }
     tvm::Array<PrimExpr> context_buffer_args = {tir::StringImm(context_buffer_name),
diff --git a/tests/cpp/relay/backend/contrib/cmsisnn/buffer_size_test.cc b/tests/cpp/relay/backend/contrib/cmsisnn/buffer_size_test.cc
new file mode 100644
index 0000000..c83daaf
--- /dev/null
+++ b/tests/cpp/relay/backend/contrib/cmsisnn/buffer_size_test.cc
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifdef TVM_USE_CMSISNN
+
+#include "../../../../../../src/relay/backend/contrib/cmsisnn/buffer_size.h"
+
+#include <gtest/gtest.h>
+#include <tvm/ir/transform.h>
+
+#include <cmath>
+#include <random>
+#include <string>
+
+#include "../../../../../../src/relay/backend/contrib/cmsisnn/compiler_attrs.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+static std::random_device rd;
+static std::mt19937 gen(rd());
+static std::uniform_int_distribution<> fake_parameters(1, 100);
+
+class CMSISNNCalculatedBufferSize : public testing::TestWithParam<std::array<int32_t, 3>> {};
+
+TEST(CMSISNNConv2dBufferSize, Conv1x1) {
+  int32_t any = fake_parameters(gen);
+  auto conv2d_1x1 = [=](CMSISNNFlags flags, int32_t input_c) {
+    return Conv2dBufferSize(flags, 0, 0, any, any, input_c, any, any, 1, 1, 1, 1);
+  };
+
+  ASSERT_EQ(conv2d_1x1(kNoExt, 4), 0);
+  ASSERT_EQ(conv2d_1x1(kNoExt, 8), 0);
+  ASSERT_EQ(conv2d_1x1(kNoExt, 12), 0);
+  ASSERT_EQ(conv2d_1x1(kNoExt, 16), 0);
+  ASSERT_EQ(conv2d_1x1(kNoExt, 32), 0);
+
+  ASSERT_EQ(conv2d_1x1(kHasDSP, 4), 0);
+  ASSERT_EQ(conv2d_1x1(kHasDSP, 8), 0);
+  ASSERT_EQ(conv2d_1x1(kHasDSP, 12), 0);
+  ASSERT_EQ(conv2d_1x1(kHasDSP, 16), 0);
+  ASSERT_EQ(conv2d_1x1(kHasDSP, 32), 0);
+
+  ASSERT_EQ(conv2d_1x1(kHasMVE, 4), 0);
+  ASSERT_EQ(conv2d_1x1(kHasMVE, 8), 0);
+  ASSERT_EQ(conv2d_1x1(kHasMVE, 12), 0);
+  ASSERT_EQ(conv2d_1x1(kHasMVE, 16), 0);
+  ASSERT_EQ(conv2d_1x1(kHasMVE, 32), 0);
+}
+
+TEST(CMSISNNConv2dBufferSize, Conv1xN) {
+  int32_t any = fake_parameters(gen);
+  int32_t input_c = fake_parameters(gen);
+  int32_t filter_w = fake_parameters(gen);
+  int32_t filter_h = 1;
+  int32_t calculated_buffer = (2 * input_c * filter_w * filter_h) * (int32_t)sizeof(int16_t);
+
+  auto conv2d_1xn = [=](CMSISNNFlags flags, int32_t output_w) {
+    return Conv2dBufferSize(flags, any, any, 1, 1, input_c, 1, output_w, any, any, filter_w,
+                            filter_h);
+  };
+
+  ASSERT_EQ(conv2d_1xn(kNoExt, 4), 0);
+  ASSERT_EQ(conv2d_1xn(kNoExt, 8), 0);
+  ASSERT_EQ(conv2d_1xn(kNoExt, 12), 0);
+  ASSERT_EQ(conv2d_1xn(kNoExt, 16), 0);
+  ASSERT_EQ(conv2d_1xn(kNoExt, 32), 0);
+
+  ASSERT_EQ(conv2d_1xn(kHasDSP, 4), calculated_buffer);
+  ASSERT_EQ(conv2d_1xn(kHasDSP, 8), calculated_buffer);
+  ASSERT_EQ(conv2d_1xn(kHasDSP, 12), calculated_buffer);
+  ASSERT_EQ(conv2d_1xn(kHasDSP, 16), calculated_buffer);
+  ASSERT_EQ(conv2d_1xn(kHasDSP, 32), calculated_buffer);
+
+  ASSERT_EQ(conv2d_1xn(kHasMVE, 4), 0);
+  ASSERT_EQ(conv2d_1xn(kHasMVE, 8), 0);
+  ASSERT_EQ(conv2d_1xn(kHasMVE, 12), 0);
+  ASSERT_EQ(conv2d_1xn(kHasMVE, 16), 0);
+  ASSERT_EQ(conv2d_1xn(kHasMVE, 32), 0);
+}
+
+TEST(CMSISNNConv2dBufferSize, Default) {
+  int32_t any = fake_parameters(gen);
+
+  int32_t input_c = fake_parameters(gen);
+  int32_t filter_w = fake_parameters(gen);
+  int32_t filter_h = fake_parameters(gen) + 1;  // >= 2, so the 1xN special case is never hit
+  int32_t calculated_buffer = (2 * input_c * filter_w * filter_h) * (int32_t)sizeof(int16_t);
+
+  auto conv2d = [=](CMSISNNFlags flags, int32_t output_w) {
+    return Conv2dBufferSize(flags, any, any, 1, 1, input_c, 1, output_w, any, any, filter_w,
+                            filter_h);
+  };
+
+  ASSERT_EQ(conv2d(kNoExt, 4), 0);
+  ASSERT_EQ(conv2d(kNoExt, 8), 0);
+  ASSERT_EQ(conv2d(kNoExt, 12), 0);
+  ASSERT_EQ(conv2d(kNoExt, 16), 0);
+  ASSERT_EQ(conv2d(kNoExt, 32), 0);
+
+  ASSERT_EQ(conv2d(kHasDSP, 4), calculated_buffer);
+  ASSERT_EQ(conv2d(kHasDSP, 8), calculated_buffer);
+  ASSERT_EQ(conv2d(kHasDSP, 12), calculated_buffer);
+  ASSERT_EQ(conv2d(kHasDSP, 16), calculated_buffer);
+  ASSERT_EQ(conv2d(kHasDSP, 32), calculated_buffer);
+
+  ASSERT_EQ(conv2d(kHasMVE, 4), calculated_buffer);
+  ASSERT_EQ(conv2d(kHasMVE, 8), calculated_buffer);
+  ASSERT_EQ(conv2d(kHasMVE, 12), calculated_buffer);
+  ASSERT_EQ(conv2d(kHasMVE, 16), calculated_buffer);
+  ASSERT_EQ(conv2d(kHasMVE, 32), calculated_buffer);
+}
+
+TEST(CMSISNNDepthwiseConv2dBufferSize, UnEvenChannels) {
+  int32_t filter_w = fake_parameters(gen);
+  int32_t filter_h = fake_parameters(gen);
+  int32_t input_n = 1;
+
+  auto depthwise_conv2d_with_channels = [=](CMSISNNFlags flags, int32_t input_c, int32_t output_c) {
+    return DepthwiseConv2dBufferSize(flags, input_n, input_c, output_c, filter_w, filter_h);
+  };
+
+  ASSERT_EQ(depthwise_conv2d_with_channels(kNoExt, 4, 6), 0);
+  ASSERT_EQ(depthwise_conv2d_with_channels(kNoExt, 8, 7), 0);
+  ASSERT_EQ(depthwise_conv2d_with_channels(kHasDSP, 4, 6), 0);
+  ASSERT_EQ(depthwise_conv2d_with_channels(kHasDSP, 8, 7), 0);
+  ASSERT_EQ(depthwise_conv2d_with_channels(kHasMVE, 4, 6), 0);
+  ASSERT_EQ(depthwise_conv2d_with_channels(kHasMVE, 8, 7), 0);
+}
+
+TEST(CMSISNNDepthwiseConv2dBufferSize, MultipleBatches) {
+  int32_t input_output_c = fake_parameters(gen);
+  int32_t filter_w = fake_parameters(gen);
+  int32_t filter_h = fake_parameters(gen);
+
+  auto depthwise_conv2d_with_batch = [=](CMSISNNFlags flags, int32_t input_n) {
+    return DepthwiseConv2dBufferSize(flags, input_n, input_output_c, input_output_c, filter_w,
+                                     filter_h);
+  };
+
+  ASSERT_EQ(depthwise_conv2d_with_batch(kNoExt, 4), 0);
+  ASSERT_EQ(depthwise_conv2d_with_batch(kNoExt, 7), 0);
+  ASSERT_EQ(depthwise_conv2d_with_batch(kHasDSP, 4), 0);
+  ASSERT_EQ(depthwise_conv2d_with_batch(kHasDSP, 7), 0);
+  ASSERT_EQ(depthwise_conv2d_with_batch(kHasMVE, 4), 0);
+  ASSERT_EQ(depthwise_conv2d_with_batch(kHasMVE, 7), 0);
+}
+
+TEST(CMSISNNDepthwiseConv2dBufferSize, Default) {
+  int32_t input_output_c = fake_parameters(gen);
+  int32_t filter_w = fake_parameters(gen);
+  int32_t filter_h = fake_parameters(gen);
+  int32_t input_n = 1;
+
+  int32_t mve_calculated_buffer =
+      (2 * input_output_c * filter_w * filter_h) * (int32_t)sizeof(int16_t) + 4;
+  int32_t dsp_calculated_buffer = (input_output_c * filter_w * filter_h) * (int32_t)sizeof(int16_t);
+
+  auto depthwise_conv2d = [=](CMSISNNFlags flags) {
+    return DepthwiseConv2dBufferSize(flags, input_n, input_output_c, input_output_c, filter_w,
+                                     filter_h);
+  };
+
+  ASSERT_EQ(depthwise_conv2d(kNoExt), 0);
+  ASSERT_EQ(depthwise_conv2d(kNoExt), 0);
+  ASSERT_EQ(depthwise_conv2d(kHasDSP), dsp_calculated_buffer);
+  ASSERT_EQ(depthwise_conv2d(kHasDSP), dsp_calculated_buffer);
+  ASSERT_EQ(depthwise_conv2d(kHasMVE), mve_calculated_buffer);
+  ASSERT_EQ(depthwise_conv2d(kHasMVE), mve_calculated_buffer);
+}
+
+TEST(CMSISNNAvgPoolBufferSize, Default) {
+  int32_t input_c = fake_parameters(gen);
+  int32_t calculated_buffer = (input_c * sizeof(int32_t));
+
+  auto avg_pool = [=](CMSISNNFlags flags) { return AvgPoolBufferSize(flags, input_c); };
+
+  ASSERT_EQ(avg_pool(kNoExt), 0);
+  ASSERT_EQ(avg_pool(kHasDSP), calculated_buffer);
+  ASSERT_EQ(avg_pool(kHasMVE), 0);
+}
+
+}  // namespace cmsisnn
+}  // namespace contrib
+}  // namespace relay
+}  // namespace tvm
+
+#endif
diff --git a/tests/cpp/relay/backend/contrib/cmsisnn/compiler_attrs_test.cc b/tests/cpp/relay/backend/contrib/cmsisnn/compiler_attrs_test.cc
new file mode 100644
index 0000000..0eb980f
--- /dev/null
+++ b/tests/cpp/relay/backend/contrib/cmsisnn/compiler_attrs_test.cc
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifdef TVM_USE_CMSISNN
+
+#include "../../../../../../src/relay/backend/contrib/cmsisnn/compiler_attrs.h"
+
+#include <gtest/gtest.h>
+#include <tvm/ir/transform.h>
+
+#include <cmath>
+#include <string>
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+static const char* mveCPUs[] = {"cortex-m55"};
+static const char* dspCPUs[] = {"cortex-m7", "cortex-m33", "cortex-m35p"};
+static const char* noExtensions[] = {"cortex-m0", "cortex-m3", "cortex-m4"};
+
+class CMSISNNFlagsMVECPUs : public testing::TestWithParam<const char*> {};
+class CMSISNNFlagsDSPCPUs : public testing::TestWithParam<const char*> {};
+class CMSISNNFlagsNoExtensions : public testing::TestWithParam<const char*> {};
+
+static CMSISNNFlags GetFlagsWithCompilerAttrs(String mcpu, String mattr) {
+  auto context_node = make_object<tvm::transform::PassContextNode>();
+  auto cmsisnn_config_node = make_object<CMSISNNCompilerConfigNode>();
+  cmsisnn_config_node->InitBySeq("mcpu", mcpu, "mattr", mattr);
+
+  context_node->config = {
+      {"relay.ext.cmsisnn.options", CMSISNNCompilerConfig(cmsisnn_config_node)}};
+
+  tvm::transform::PassContext context = tvm::transform::PassContext(context_node);
+  return GetCompilerFlags(context);
+}
+
+TEST(CMSISNNFlags, CreateFromUndefined) {
+  auto context_node = make_object<tvm::transform::PassContextNode>();
+  tvm::transform::PassContext context = tvm::transform::PassContext(context_node);
+  CMSISNNFlags flags = GetCompilerFlags(context);
+  ASSERT_EQ(flags.mve, false);
+  ASSERT_EQ(flags.dsp, false);
+}
+
+TEST_P(CMSISNNFlagsMVECPUs, CheckMVESet) {
+  CMSISNNFlags flags = GetFlagsWithCompilerAttrs(GetParam(), "");
+  ASSERT_EQ(flags.dsp, true);
+  ASSERT_EQ(flags.mve, true);
+}
+
+TEST_P(CMSISNNFlagsMVECPUs, CheckMVEOverrideCPU) {
+  std::string mcpu = GetParam();
+  CMSISNNFlags flags = GetFlagsWithCompilerAttrs(mcpu + "+nomve", "");
+  ASSERT_EQ(flags.dsp, false);
+  ASSERT_EQ(flags.mve, false);
+}
+
+TEST_P(CMSISNNFlagsMVECPUs, CheckDSPOverrideCPU) {
+  std::string mcpu = GetParam();
+  CMSISNNFlags flags = GetFlagsWithCompilerAttrs(mcpu + "+nodsp", "");
+  ASSERT_EQ(flags.dsp, false);
+  ASSERT_EQ(flags.mve, false);
+}
+
+TEST_P(CMSISNNFlagsMVECPUs, CheckCombinedOverrideCPU) {
+  std::string mcpu = GetParam();
+  CMSISNNFlags flags = GetFlagsWithCompilerAttrs(mcpu + "+nodsp+nomve", "");
+  ASSERT_EQ(flags.dsp, false);
+  ASSERT_EQ(flags.mve, false);
+  flags = GetFlagsWithCompilerAttrs(mcpu + "+nomve+nodsp", "");
+  ASSERT_EQ(flags.dsp, false);
+  ASSERT_EQ(flags.mve, false);
+}
+
+TEST_P(CMSISNNFlagsMVECPUs, CheckMVEOverrideMAttr) {
+  CMSISNNFlags flags = GetFlagsWithCompilerAttrs(GetParam(), "+nomve");
+  ASSERT_EQ(flags.dsp, false);
+  ASSERT_EQ(flags.mve, false);
+}
+
+TEST_P(CMSISNNFlagsMVECPUs, CheckDSPOverrideMattr) {
+  CMSISNNFlags flags = GetFlagsWithCompilerAttrs(GetParam(), "+nodsp");
+  ASSERT_EQ(flags.dsp, false);
+  ASSERT_EQ(flags.mve, false);
+}
+
+TEST_P(CMSISNNFlagsMVECPUs, CheckCombinedOverrideMattr) {
+  CMSISNNFlags flags = GetFlagsWithCompilerAttrs(GetParam(), "+nodsp+nomve");
+  ASSERT_EQ(flags.dsp, false);
+  ASSERT_EQ(flags.mve, false);
+  flags = GetFlagsWithCompilerAttrs(GetParam(), "+nomve+nodsp");
+  ASSERT_EQ(flags.dsp, false);
+  ASSERT_EQ(flags.mve, false);
+  flags = GetFlagsWithCompilerAttrs(GetParam(), "+woofles+nomve+nodsp");
+  ASSERT_EQ(flags.dsp, false);
+  ASSERT_EQ(flags.mve, false);
+}
+
+TEST_P(CMSISNNFlagsDSPCPUs, CheckDSPSet) {
+  CMSISNNFlags flags = GetFlagsWithCompilerAttrs(GetParam(), "");
+  ASSERT_EQ(flags.dsp, true);
+  ASSERT_EQ(flags.mve, false);
+}
+
+TEST_P(CMSISNNFlagsDSPCPUs, CheckDSPOverrideCPU) {
+  std::string mcpu = GetParam();
+  CMSISNNFlags flags = GetFlagsWithCompilerAttrs(mcpu + "+nodsp", "");
+  ASSERT_EQ(flags.dsp, false);
+  ASSERT_EQ(flags.mve, false);
+  flags = GetFlagsWithCompilerAttrs(mcpu + "+nodsp+woofles", "");
+  ASSERT_EQ(flags.dsp, false);
+  ASSERT_EQ(flags.mve, false);
+}
+
+TEST_P(CMSISNNFlagsDSPCPUs, CheckDSPOverrideMattr) {
+  CMSISNNFlags flags = GetFlagsWithCompilerAttrs(GetParam(), "+nodsp");
+  ASSERT_EQ(flags.dsp, false);
+  ASSERT_EQ(flags.mve, false);
+  flags = GetFlagsWithCompilerAttrs(GetParam(), "+nodsp+woofles");
+  ASSERT_EQ(flags.dsp, false);
+  ASSERT_EQ(flags.mve, false);
+}
+
+TEST_P(CMSISNNFlagsNoExtensions, CheckNoFlags) {
+  CMSISNNFlags flags = GetFlagsWithCompilerAttrs(GetParam(), "");
+  ASSERT_EQ(flags.dsp, false);
+  ASSERT_EQ(flags.mve, false);
+}
+
+INSTANTIATE_TEST_CASE_P(CMSISNNFlags, CMSISNNFlagsMVECPUs, ::testing::ValuesIn(mveCPUs));
+INSTANTIATE_TEST_CASE_P(CMSISNNFlags, CMSISNNFlagsDSPCPUs, ::testing::ValuesIn(dspCPUs));
+INSTANTIATE_TEST_CASE_P(CMSISNNFlags, CMSISNNFlagsNoExtensions, ::testing::ValuesIn(noExtensions));
+
+}  // namespace cmsisnn
+}  // namespace contrib
+}  // namespace relay
+}  // namespace tvm
+
+#endif
diff --git a/tests/python/relay/aot/aot_test_utils.py b/tests/python/relay/aot/aot_test_utils.py
index d335528..e7ca5c6 100644
--- a/tests/python/relay/aot/aot_test_utils.py
+++ b/tests/python/relay/aot/aot_test_utils.py
@@ -145,6 +145,11 @@ AOT_CORSTONE300_RUNNER = AOTTestRunner(
     """,
     includes=["uart.h"],
     parameters={"NPU_VARIANT": "256"},
+    pass_config={
+        "relay.ext.cmsisnn.options": {
+            "mcpu": "cortex-m55",
+        }
+    },
 )
 
 

Reply via email to