This is an automated email from the ASF dual-hosted git repository.
leandron pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 92eeef6 Calculate CMSIS-NN buffer size with respect to architecture extensions (#9338)
92eeef6 is described below
commit 92eeef6ddde9b113fd8263f1fc20f08c1c3fea0d
Author: Christopher Sidebottom <[email protected]>
AuthorDate: Wed Jan 5 10:14:55 2022 +0000
Calculate CMSIS-NN buffer size with respect to architecture extensions (#9338)
This correctly calculates the buffer sizes for a variety of targets
based on the `-mcpu` and `-mattr` flags passed to the `cmsis-nn` code
generator.
Buffer size calculations are added for Conv2d, Depthwise Conv2d and Average Pool.
---
cmake/modules/contrib/CMSISNN.cmake | 1 +
python/tvm/driver/tvmc/composite_target.py | 2 +-
src/relay/backend/contrib/cmsisnn/buffer_size.cc | 78 ++++++++
src/relay/backend/contrib/cmsisnn/buffer_size.h | 94 ++++++++++
.../backend/contrib/cmsisnn/compiler_attrs.cc | 75 ++++++++
src/relay/backend/contrib/cmsisnn/compiler_attrs.h | 75 ++++++++
src/relay/backend/contrib/cmsisnn/relay_to_tir.cc | 33 +++-
.../backend/contrib/cmsisnn/buffer_size_test.cc | 206 +++++++++++++++++++++
.../backend/contrib/cmsisnn/compiler_attrs_test.cc | 157 ++++++++++++++++
tests/python/relay/aot/aot_test_utils.py | 5 +
10 files changed, 718 insertions(+), 8 deletions(-)
diff --git a/cmake/modules/contrib/CMSISNN.cmake b/cmake/modules/contrib/CMSISNN.cmake
index 50a3464..73ecd59 100644
--- a/cmake/modules/contrib/CMSISNN.cmake
+++ b/cmake/modules/contrib/CMSISNN.cmake
@@ -16,6 +16,7 @@
# under the License.
if(USE_CMSISNN)
+ add_definitions(-DTVM_USE_CMSISNN)
message(STATUS "Build with CMSIS-NN support")
tvm_file_glob(GLOB RELAY_CONTRIB_CMSISNN_SRCS
src/relay/backend/contrib/cmsisnn/*.cc)
list(APPEND COMPILER_SRCS ${RELAY_CONTRIB_CMSISNN_SRCS})
diff --git a/python/tvm/driver/tvmc/composite_target.py b/python/tvm/driver/tvmc/composite_target.py
index f347158..3b5ba9d 100644
--- a/python/tvm/driver/tvmc/composite_target.py
+++ b/python/tvm/driver/tvmc/composite_target.py
@@ -53,7 +53,7 @@ REGISTERED_CODEGEN = {
"pass_pipeline": partition_for_arm_compute_lib,
},
"cmsis-nn": {
- "config_key": None,
+ "config_key": "relay.ext.cmsisnn.options",
"pass_pipeline": partition_for_cmsisnn,
},
"ethos-n77": {
diff --git a/src/relay/backend/contrib/cmsisnn/buffer_size.cc b/src/relay/backend/contrib/cmsisnn/buffer_size.cc
new file mode 100644
index 0000000..2502a09
--- /dev/null
+++ b/src/relay/backend/contrib/cmsisnn/buffer_size.cc
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/ir/attrs.h>
+#include <tvm/ir/transform.h>
+
+#include "compiler_attrs.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+/*!
+ * \brief Converts a count of int16_t elements into a size in bytes.
+ *
+ * The convolution calculations below all size their buffers in int16_t elements.
+ */
+static int32_t Int16BufferBytes(int32_t num_elements) {
+  return num_elements * static_cast<int32_t>(sizeof(int16_t));
+}
+
+/*!
+ * \brief Buffer size in bytes for a (non-depthwise) CMSIS-NN convolution.
+ *
+ * The cases are checked in this order:
+ *  - 1x1: zero padding, unit strides, 1x1 filter and input channels a multiple of 4.
+ *    No buffer is needed, whatever the extensions.
+ *  - 1xN: single-row input, output and filter, batch of 1, output width a multiple of 4.
+ *    A buffer is needed only with DSP and without MVE.
+ *  - Otherwise a buffer is needed whenever DSP is enabled (which includes MVE).
+ */
+int Conv2dBufferSize(CMSISNNFlags flags, int32_t padding_w, int32_t padding_h, int32_t input_n,
+                     int32_t input_h, int32_t input_c, int32_t output_h, int32_t output_w,
+                     int32_t stride_w, int32_t stride_h, int32_t filter_w, int32_t filter_h) {
+  const bool is1x1 = (padding_w == 0) && (padding_h == 0) && (input_c % 4 == 0) &&
+                     (stride_w == 1) && (stride_h == 1) && (filter_w == 1) && (filter_h == 1);
+  const bool is1xN = (output_h == 1) && (input_h == 1) && (filter_h == 1) &&
+                     (output_w % 4 == 0) && (input_n == 1);
+
+  if (is1x1) {
+    return 0;
+  }
+
+  if (is1xN) {
+    if (flags.dsp && !flags.mve) {
+      return Int16BufferBytes(2 * input_c * filter_w * filter_h);
+    }
+    return 0;
+  }
+
+  if (flags.dsp) {
+    return Int16BufferBytes(2 * input_c * filter_w * filter_h);
+  }
+  return 0;
+}
+
+/*!
+ * \brief Buffer size in bytes for a CMSIS-NN depthwise convolution.
+ *
+ * A buffer is only requested when the input and output channel counts match and the
+ * batch size is 1; otherwise 0 is returned. MVE takes precedence over DSP.
+ */
+int DepthwiseConv2dBufferSize(CMSISNNFlags flags, int32_t input_n, int32_t input_c,
+                              int32_t output_c, int32_t filter_w, int32_t filter_h) {
+  if (input_c != output_c || input_n != 1) {
+    return 0;
+  }
+  if (flags.mve) {
+    // 2 * input_c * filter_w * filter_h int16_t elements plus 4 extra bytes, mirroring
+    // the CMSIS-NN calculation linked from buffer_size.h.
+    return Int16BufferBytes(2 * input_c * filter_w * filter_h) + 4;
+  }
+  if (flags.dsp) {
+    return Int16BufferBytes(input_c * filter_w * filter_h);
+  }
+  return 0;
+}
+
+/*!
+ * \brief Buffer size in bytes for CMSIS-NN average pooling.
+ *
+ * One int32_t per input channel when DSP is enabled without MVE, otherwise 0.
+ */
+int AvgPoolBufferSize(CMSISNNFlags flags, int32_t input_c) {
+  if (flags.dsp && !flags.mve) {
+    // Keep the product in int32_t, rather than computing it in size_t and implicitly
+    // narrowing it back to int on return.
+    return input_c * static_cast<int32_t>(sizeof(int32_t));
+  }
+  return 0;
+}
+
+} // namespace cmsisnn
+} // namespace contrib
+} // namespace relay
+} // namespace tvm
diff --git a/src/relay/backend/contrib/cmsisnn/buffer_size.h b/src/relay/backend/contrib/cmsisnn/buffer_size.h
new file mode 100644
index 0000000..dec3c3e
--- /dev/null
+++ b/src/relay/backend/contrib/cmsisnn/buffer_size.h
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/contrib/cmsisnn/buffer_size.h
+ * \brief CMSIS-NN Buffer Size calculation functions
+ */
+
+#ifndef TVM_RELAY_BACKEND_CONTRIB_CMSISNN_BUFFER_SIZE_H_
+#define TVM_RELAY_BACKEND_CONTRIB_CMSISNN_BUFFER_SIZE_H_
+
+#include <tvm/ir/transform.h>
+
+#include "compiler_attrs.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+/*!
+ * \brief Calculates the appropriate buffer size for CMSIS-NN Convolutions
+ * See:
+ * https://github.com/ARM-software/CMSIS_5/blob/8c60448c0e1e50e426180b26db9bc31ddf774361/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_wrapper_s8.c#L108-L127
+ *
+ * \param flags - CMSIS-NN feature flags
+ * \param padding_w - Width padding
+ * \param padding_h - Height padding
+ * \param input_n - Input batch size
+ * \param input_h - Input height
+ * \param input_c - Input channels
+ * \param output_h - Output height
+ * \param output_w - Output width
+ * \param stride_w - Stride width
+ * \param stride_h - Stride height
+ * \param filter_w - Filter width
+ * \param filter_h - Filter height
+ *
+ * \return Size of buffer, in bytes, to allocate for convolution (0 if no buffer is needed)
+ */
+int Conv2dBufferSize(CMSISNNFlags flags, int32_t padding_w, int32_t padding_h, int32_t input_n,
+                     int32_t input_h, int32_t input_c, int32_t output_h, int32_t output_w,
+                     int32_t stride_w, int32_t stride_h, int32_t filter_w, int32_t filter_h);
+
+/*!
+ * \brief Calculates the appropriate buffer size for CMSIS-NN Depthwise Convolutions
+ * See:
+ * https://github.com/ARM-software/CMSIS_5/blob/325443e52637b6c7eedbd160d238a6c462e89c9f/CMSIS/NN/Source/ConvolutionFunctions/arm_depthwise_conv_wrapper_s8.c#L115-L129
+ *
+ * \param flags - CMSIS-NN feature flags
+ * \param input_n - Input batch size
+ * \param input_c - Input channels
+ * \param output_c - Output channels
+ * \param filter_w - Filter width
+ * \param filter_h - Filter height
+ *
+ * \return Size of buffer, in bytes, to allocate for depthwise convolution (0 if no buffer is
+ *         needed)
+ */
+int DepthwiseConv2dBufferSize(CMSISNNFlags flags, int32_t input_n, int32_t input_c,
+                              int32_t output_c, int32_t filter_w, int32_t filter_h);
+
+/*!
+ * \brief Calculates the appropriate buffer size for CMSIS-NN Average Pooling
+ * See:
+ * https://github.com/ARM-software/CMSIS_5/blob/bff28575f0c96a4ee9008947fea2b018a69b4900/CMSIS/NN/Source/PoolingFunctions/arm_avgpool_s8.c#L388-L398
+ *
+ * \param flags - CMSIS-NN feature flags
+ * \param input_c - Input channels
+ *
+ * \return Size of buffer, in bytes, to allocate for average pooling (0 if no buffer is needed)
+ */
+int AvgPoolBufferSize(CMSISNNFlags flags, int32_t input_c);
+
+} // namespace cmsisnn
+} // namespace contrib
+} // namespace relay
+} // namespace tvm
+
+#endif // TVM_RELAY_BACKEND_CONTRIB_CMSISNN_BUFFER_SIZE_H_
diff --git a/src/relay/backend/contrib/cmsisnn/compiler_attrs.cc b/src/relay/backend/contrib/cmsisnn/compiler_attrs.cc
new file mode 100644
index 0000000..c6fb8f3
--- /dev/null
+++ b/src/relay/backend/contrib/cmsisnn/compiler_attrs.cc
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "compiler_attrs.h"
+
+#include <tvm/ir/attrs.h>
+#include <tvm/ir/transform.h>
+
+#include <algorithm>
+#include <string>
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+// CPUs whose -mcpu name (prefix) selects the MVE configuration (which also enables DSP).
+static const char* mveCPUs[] = {"cortex-m55"};
+// CPUs whose -mcpu name (prefix) selects the DSP-only configuration.
+// NOTE(review): Cortex-M4 (Armv7E-M) is believed to implement the DSP extension but is not
+// listed here; compiler_attrs_test.cc currently expects it to report no extensions, so
+// confirm before changing either list.
+static const char* dspCPUs[] = {"cortex-m7", "cortex-m33", "cortex-m35p"};
+
+TVM_REGISTER_NODE_TYPE(CMSISNNCompilerConfigNode);
+TVM_REGISTER_PASS_CONFIG_OPTION("relay.ext.cmsisnn.options", CMSISNNCompilerConfig);
+
+/*!
+ * \brief Returns true if \p mcpu starts with any CPU name in \p cpus.
+ *
+ * A prefix match lets names carrying attribute suffixes (e.g. "cortex-m55+nomve") match.
+ */
+template <typename Container>
+static inline bool MatchesCpu(const std::string& mcpu, const Container& cpus) {
+  // rfind(cpu, 0) only inspects position 0, so this is a pure prefix test.
+  return std::any_of(std::begin(cpus), std::end(cpus),
+                     [&mcpu](const char* cpu) { return mcpu.rfind(cpu, 0) == 0; });
+}
+
+/*! \brief Returns true if \p flag occurs anywhere in \p attr. */
+static inline bool HasFlag(const std::string& attr, const std::string& flag) {
+  return attr.find(flag) != std::string::npos;
+}
+
+/*!
+ * \brief Derives CMSISNNFlags from the "relay.ext.cmsisnn.options" pass config.
+ *
+ * Returns kNoExt when the option is not set. Otherwise returns kHasMVE for an MVE CPU
+ * with neither "+nomve" nor "+nodsp", kHasDSP for a DSP CPU without "+nodsp", and
+ * kNoExt in every other case. The "+no..." flags are honoured in either mcpu or mattr.
+ *
+ * NOTE(review): an MVE CPU with "+nomve" (but not "+nodsp") yields kNoExt rather than
+ * kHasDSP; compiler_attrs_test.cc expects this, but confirm it matches such targets.
+ */
+CMSISNNFlags GetCompilerFlags(const tvm::transform::PassContext& ctx) {
+  auto cfg = ctx->GetConfig<CMSISNNCompilerConfig>("relay.ext.cmsisnn.options");
+  if (!cfg.defined()) {
+    return kNoExt;
+  }
+
+  const std::string mcpu = cfg.value()->mcpu;
+  const std::string mattr = cfg.value()->mattr;
+
+  // Disabling flags may be appended to the CPU name (e.g. "cortex-m55+nomve") or given in mattr.
+  const bool nomve = HasFlag(mcpu, "+nomve") || HasFlag(mattr, "+nomve");
+  const bool nodsp = HasFlag(mcpu, "+nodsp") || HasFlag(mattr, "+nodsp");
+
+  // kHasMVE enables DSP as well, so disabling either extension rules it out.
+  if (MatchesCpu(mcpu, mveCPUs) && !nomve && !nodsp) {
+    return kHasMVE;
+  }
+
+  if (MatchesCpu(mcpu, dspCPUs) && !nodsp) {
+    return kHasDSP;
+  }
+
+  return kNoExt;
+}
+
+} // namespace cmsisnn
+} // namespace contrib
+} // namespace relay
+} // namespace tvm
diff --git a/src/relay/backend/contrib/cmsisnn/compiler_attrs.h b/src/relay/backend/contrib/cmsisnn/compiler_attrs.h
new file mode 100644
index 0000000..75005d9
--- /dev/null
+++ b/src/relay/backend/contrib/cmsisnn/compiler_attrs.h
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/contrib/cmsisnn/compiler_attrs.h
+ * \brief CMSIS-NN Compiler Attribute functionality
+ */
+
+#ifndef TVM_RELAY_BACKEND_CONTRIB_CMSISNN_COMPILER_ATTRS_H_
+#define TVM_RELAY_BACKEND_CONTRIB_CMSISNN_COMPILER_ATTRS_H_
+
+#include <tvm/ir/transform.h>
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+/*! \brief Attributes to store the compiler options for CMSIS-NN. */
+struct CMSISNNCompilerConfigNode : public tvm::AttrsNode<CMSISNNCompilerConfigNode> {
+  String mcpu;
+  String mattr;
+
+  TVM_DECLARE_ATTRS(CMSISNNCompilerConfigNode, "ext.attrs.CMSISNNCompilerConfigNode") {
+    TVM_ATTR_FIELD(mcpu)
+        .describe(
+            "The CPU to configure CMSIS-NN for (i.e. cortex-m55, cortex-m4), can also include "
+            "attributes (i.e. cortex-m55+nomve)")
+        .set_default("");
+    TVM_ATTR_FIELD(mattr)
+        .describe("The attributes to configure CMSIS-NN (i.e. +nodsp, +nomve)")
+        .set_default("");
+  }
+};
+
+/*! \brief Reference type for CMSISNNCompilerConfigNode (the "relay.ext.cmsisnn.options" value). */
+class CMSISNNCompilerConfig : public Attrs {
+ public:
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(CMSISNNCompilerConfig, Attrs,
+                                            CMSISNNCompilerConfigNode);
+};
+
+/*! \brief Flags to configure the calculations for CMSIS-NN. */
+struct CMSISNNFlags {
+  bool dsp;  // Enable or disable dsp buffers
+  bool mve;  // Enable or disable mve buffers
+};
+
+// Positional aggregate initialization in field order (dsp, mve). Designated initializers
+// are only standard from C++20 and are rejected by some C++14/17 compilers.
+constexpr CMSISNNFlags kNoExt = {/*dsp=*/false, /*mve=*/false};
+constexpr CMSISNNFlags kHasDSP = {/*dsp=*/true, /*mve=*/false};
+constexpr CMSISNNFlags kHasMVE = {/*dsp=*/true, /*mve=*/true};
+
+/*!
+ * \brief Derives CMSISNNFlags from the "relay.ext.cmsisnn.options" config of \p ctx.
+ * \return kNoExt when the option is not set.
+ */
+CMSISNNFlags GetCompilerFlags(const tvm::transform::PassContext& ctx);
+
+} // namespace cmsisnn
+} // namespace contrib
+} // namespace relay
+} // namespace tvm
+
+#endif // TVM_RELAY_BACKEND_CONTRIB_CMSISNN_COMPILER_ATTRS_H_
diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
index 7f1582a..b874424 100644
--- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
+++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
@@ -29,6 +29,8 @@
#include "../../../qnn/utils.h"
#include "../../../transforms/pattern_utils.h"
+#include "buffer_size.h"
+#include "compiler_attrs.h"
namespace tvm {
namespace relay {
@@ -159,6 +161,9 @@ class RelayToTIRVisitor : public MixedModeMutator {
// CMSIS-NN data structure "cmsis_nn_dims" for ifm expects input layout as
NHWC
// This is the same layout we expect in Relay
Array<PrimExpr> input_shape =
conv2d_call->args[0]->type_as<TensorTypeNode>()->shape;
+ int32_t input_n = qnn::get_const_int(input_shape[0]);
+ int32_t input_h = qnn::get_const_int(input_shape[1]);
+ int32_t input_c = qnn::get_const_int(input_shape[3]);
// CMSIS-NN data structure "cmsis_nn_dims" for weights expects following
layouts
// OHWI for Conv2D and IHWO for Depthwise convolutions
@@ -167,6 +172,9 @@ class RelayToTIRVisitor : public MixedModeMutator {
Array<PrimExpr> bias_shape{1, 1, 1, out_channels};
Array<PrimExpr> output_shape =
conv2d_call->type_as<TensorTypeNode>()->shape;
+ int32_t output_h = qnn::get_const_int(output_shape[1]);
+ int32_t output_w = qnn::get_const_int(output_shape[2]);
+ int32_t output_c = qnn::get_const_int(output_shape[3]);
int32_t depth_multiplier = -1;
int kernel_pos_o = kernel_layout.find("O");
@@ -179,7 +187,8 @@ class RelayToTIRVisitor : public MixedModeMutator {
// original filter_layout for depthwise is HWOI
std::string cmsisnn_api = "arm_convolve_wrapper_s8";
- if (depth_multiplier != -1) {
+ bool is_depthwise = depth_multiplier != -1;
+ if (is_depthwise) {
cmsisnn_api = "arm_depthwise_conv_wrapper_s8";
int filter_pos_h = kernel_layout.find("H");
int filter_pos_w = kernel_layout.find("W");
@@ -187,6 +196,8 @@ class RelayToTIRVisitor : public MixedModeMutator {
filter_shape[filter_pos_w],
out_channels};
filter_shape = depthwise_filter_shape;
}
+ int32_t filter_h = qnn::get_const_int(filter_shape[1]);
+ int32_t filter_w = qnn::get_const_int(filter_shape[2]);
tvm::Array<PrimExpr> call_ext_args = {tir::StringImm(cmsisnn_api), input,
filter, multiplier};
if (bias_add_call) {
@@ -195,11 +206,18 @@ class RelayToTIRVisitor : public MixedModeMutator {
call_ext_args.push_back(shift);
call_ext_args.push_back(output);
- //
https://github.com/ARM-software/CMSIS_5/blob/d788fd583984388553391de18afd8b4d2a146868/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_s8.c#L367
std::string context_buffer_name = "NULL";
- size_t context_buffer_size =
- (2 * qnn::get_const_int(input_shape[3]) *
qnn::get_const_int(filter_shape[2]) *
- qnn::get_const_int(filter_shape[1]) * (int32_t)sizeof(int16_t));
+ CMSISNNFlags flags = GetCompilerFlags(transform::PassContext::Current());
+ size_t context_buffer_size;
+ if (is_depthwise) {
+ context_buffer_size =
+ DepthwiseConv2dBufferSize(flags, input_n, input_c, output_c,
filter_w, filter_h);
+ } else {
+ context_buffer_size =
+ Conv2dBufferSize(flags, padding_w, padding_h, input_n, input_h,
input_c, output_h,
+ output_w, stride_w, stride_h, filter_w, filter_h);
+ }
+
if (context_buffer_size) {
context_buffer_name = "context_buffer_" +
std::to_string(context_buffer_id_++);
}
@@ -397,8 +415,9 @@ class RelayToTIRVisitor : public MixedModeMutator {
int context_buffer_size = 0;
std::string context_buffer_name = "NULL";
if (pool_name == "cmsisnn.qnn_avg_pool2d") {
- // TODO(@Mousius): Need to move this into buffer_size calculations
- context_buffer_size = qnn::get_const_int(input_shape[3]) *
sizeof(int32_t);
+ CMSISNNFlags flags = GetCompilerFlags(transform::PassContext::Current());
+ int32_t input_c = qnn::get_const_int(input_shape[3]);
+ context_buffer_size = AvgPoolBufferSize(flags, input_c);
context_buffer_name = "context_buffer_" +
std::to_string(context_buffer_id_++);
}
tvm::Array<PrimExpr> context_buffer_args =
{tir::StringImm(context_buffer_name),
diff --git a/tests/cpp/relay/backend/contrib/cmsisnn/buffer_size_test.cc b/tests/cpp/relay/backend/contrib/cmsisnn/buffer_size_test.cc
new file mode 100644
index 0000000..c83daaf
--- /dev/null
+++ b/tests/cpp/relay/backend/contrib/cmsisnn/buffer_size_test.cc
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifdef TVM_USE_CMSISNN
+
+#include "../../../../../../src/relay/backend/contrib/cmsisnn/buffer_size.h"
+
+#include <gtest/gtest.h>
+#include <tvm/ir/transform.h>
+
+#include <cmath>
+#include <random>
+#include <string>
+
+#include "../../../../../../src/relay/backend/contrib/cmsisnn/compiler_attrs.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+// Source of arbitrary values for parameters that must not influence the expected result.
+static std::random_device rd;
+static std::mt19937 gen(rd());
+static std::uniform_int_distribution<> fake_parameters(1, 100);
+
+// 1x1 convolutions (zero padding, unit strides, 1x1 filter, channels a multiple of 4)
+// never need a buffer, whatever the extensions.
+TEST(CMSISNNConv2dBufferSize, Conv1x1) {
+  int32_t any = fake_parameters(gen);
+  auto conv2d_1x1 = [=](CMSISNNFlags flags, int32_t input_c) {
+    return Conv2dBufferSize(flags, 0, 0, any, any, input_c, any, any, 1, 1, 1, 1);
+  };
+
+  ASSERT_EQ(conv2d_1x1(kNoExt, 4), 0);
+  ASSERT_EQ(conv2d_1x1(kNoExt, 8), 0);
+  ASSERT_EQ(conv2d_1x1(kNoExt, 12), 0);
+  ASSERT_EQ(conv2d_1x1(kNoExt, 16), 0);
+  ASSERT_EQ(conv2d_1x1(kNoExt, 32), 0);
+
+  ASSERT_EQ(conv2d_1x1(kHasDSP, 4), 0);
+  ASSERT_EQ(conv2d_1x1(kHasDSP, 8), 0);
+  ASSERT_EQ(conv2d_1x1(kHasDSP, 12), 0);
+  ASSERT_EQ(conv2d_1x1(kHasDSP, 16), 0);
+  ASSERT_EQ(conv2d_1x1(kHasDSP, 32), 0);
+
+  ASSERT_EQ(conv2d_1x1(kHasMVE, 4), 0);
+  ASSERT_EQ(conv2d_1x1(kHasMVE, 8), 0);
+  ASSERT_EQ(conv2d_1x1(kHasMVE, 12), 0);
+  ASSERT_EQ(conv2d_1x1(kHasMVE, 16), 0);
+  ASSERT_EQ(conv2d_1x1(kHasMVE, 32), 0);
+}
+
+// 1xN convolutions only need a buffer with DSP and without MVE. Padding is `any` (>= 1),
+// so the 1x1 path is never taken here.
+TEST(CMSISNNConv2dBufferSize, Conv1xN) {
+  int32_t any = fake_parameters(gen);
+  int32_t input_c = fake_parameters(gen);
+  int32_t filter_w = fake_parameters(gen);
+  int32_t filter_h = 1;
+  int32_t calculated_buffer =
+      (2 * input_c * filter_w * filter_h) * static_cast<int32_t>(sizeof(int16_t));
+
+  auto conv2d_1xn = [=](CMSISNNFlags flags, int32_t output_w) {
+    return Conv2dBufferSize(flags, any, any, 1, 1, input_c, 1, output_w, any, any, filter_w,
+                            filter_h);
+  };
+
+  ASSERT_EQ(conv2d_1xn(kNoExt, 4), 0);
+  ASSERT_EQ(conv2d_1xn(kNoExt, 8), 0);
+  ASSERT_EQ(conv2d_1xn(kNoExt, 12), 0);
+  ASSERT_EQ(conv2d_1xn(kNoExt, 16), 0);
+  ASSERT_EQ(conv2d_1xn(kNoExt, 32), 0);
+
+  ASSERT_EQ(conv2d_1xn(kHasDSP, 4), calculated_buffer);
+  ASSERT_EQ(conv2d_1xn(kHasDSP, 8), calculated_buffer);
+  ASSERT_EQ(conv2d_1xn(kHasDSP, 12), calculated_buffer);
+  ASSERT_EQ(conv2d_1xn(kHasDSP, 16), calculated_buffer);
+  ASSERT_EQ(conv2d_1xn(kHasDSP, 32), calculated_buffer);
+
+  ASSERT_EQ(conv2d_1xn(kHasMVE, 4), 0);
+  ASSERT_EQ(conv2d_1xn(kHasMVE, 8), 0);
+  ASSERT_EQ(conv2d_1xn(kHasMVE, 12), 0);
+  ASSERT_EQ(conv2d_1xn(kHasMVE, 16), 0);
+  ASSERT_EQ(conv2d_1xn(kHasMVE, 32), 0);
+}
+
+// All other convolutions need a buffer whenever DSP (or MVE) is enabled.
+TEST(CMSISNNConv2dBufferSize, Default) {
+  int32_t any = fake_parameters(gen);
+
+  int32_t input_c = fake_parameters(gen);
+  int32_t filter_w = fake_parameters(gen);
+  // filter_h must never be 1. Combined with the unit input/output heights, batch size 1
+  // and output widths that are multiples of 4 used below, a filter height of 1 would
+  // select the 1xN path (no buffer with MVE) and make this test fail intermittently.
+  int32_t filter_h = fake_parameters(gen) + 1;
+  int32_t calculated_buffer =
+      (2 * input_c * filter_w * filter_h) * static_cast<int32_t>(sizeof(int16_t));
+
+  auto conv2d = [=](CMSISNNFlags flags, int32_t output_w) {
+    return Conv2dBufferSize(flags, any, any, 1, 1, input_c, 1, output_w, any, any, filter_w,
+                            filter_h);
+  };
+
+  ASSERT_EQ(conv2d(kNoExt, 4), 0);
+  ASSERT_EQ(conv2d(kNoExt, 8), 0);
+  ASSERT_EQ(conv2d(kNoExt, 12), 0);
+  ASSERT_EQ(conv2d(kNoExt, 16), 0);
+  ASSERT_EQ(conv2d(kNoExt, 32), 0);
+
+  ASSERT_EQ(conv2d(kHasDSP, 4), calculated_buffer);
+  ASSERT_EQ(conv2d(kHasDSP, 8), calculated_buffer);
+  ASSERT_EQ(conv2d(kHasDSP, 12), calculated_buffer);
+  ASSERT_EQ(conv2d(kHasDSP, 16), calculated_buffer);
+  ASSERT_EQ(conv2d(kHasDSP, 32), calculated_buffer);
+
+  ASSERT_EQ(conv2d(kHasMVE, 4), calculated_buffer);
+  ASSERT_EQ(conv2d(kHasMVE, 8), calculated_buffer);
+  ASSERT_EQ(conv2d(kHasMVE, 12), calculated_buffer);
+  ASSERT_EQ(conv2d(kHasMVE, 16), calculated_buffer);
+  ASSERT_EQ(conv2d(kHasMVE, 32), calculated_buffer);
+}
+
+// Depthwise convolutions whose input and output channel counts differ need no buffer.
+TEST(CMSISNNDepthwiseConv2dBufferSize, UnEvenChannels) {
+  int32_t filter_w = fake_parameters(gen);
+  int32_t filter_h = fake_parameters(gen);
+  int32_t input_n = 1;
+
+  auto depthwise_conv2d_with_channels = [=](CMSISNNFlags flags, int32_t input_c,
+                                            int32_t output_c) {
+    return DepthwiseConv2dBufferSize(flags, input_n, input_c, output_c, filter_w, filter_h);
+  };
+
+  ASSERT_EQ(depthwise_conv2d_with_channels(kNoExt, 4, 6), 0);
+  ASSERT_EQ(depthwise_conv2d_with_channels(kNoExt, 8, 7), 0);
+  ASSERT_EQ(depthwise_conv2d_with_channels(kHasDSP, 4, 6), 0);
+  ASSERT_EQ(depthwise_conv2d_with_channels(kHasDSP, 8, 7), 0);
+  ASSERT_EQ(depthwise_conv2d_with_channels(kHasMVE, 4, 6), 0);
+  ASSERT_EQ(depthwise_conv2d_with_channels(kHasMVE, 8, 7), 0);
+}
+
+// Depthwise convolutions with a batch size other than 1 need no buffer.
+TEST(CMSISNNDepthwiseConv2dBufferSize, MultipleBatches) {
+  int32_t input_output_c = fake_parameters(gen);
+  int32_t filter_w = fake_parameters(gen);
+  int32_t filter_h = fake_parameters(gen);
+
+  auto depthwise_conv2d_with_batch = [=](CMSISNNFlags flags, int32_t input_n) {
+    return DepthwiseConv2dBufferSize(flags, input_n, input_output_c, input_output_c, filter_w,
+                                     filter_h);
+  };
+
+  ASSERT_EQ(depthwise_conv2d_with_batch(kNoExt, 4), 0);
+  ASSERT_EQ(depthwise_conv2d_with_batch(kNoExt, 7), 0);
+  ASSERT_EQ(depthwise_conv2d_with_batch(kHasDSP, 4), 0);
+  ASSERT_EQ(depthwise_conv2d_with_batch(kHasDSP, 7), 0);
+  ASSERT_EQ(depthwise_conv2d_with_batch(kHasMVE, 4), 0);
+  ASSERT_EQ(depthwise_conv2d_with_batch(kHasMVE, 7), 0);
+}
+
+// Matching channels with batch size 1: the buffer size depends on the extensions.
+TEST(CMSISNNDepthwiseConv2dBufferSize, Default) {
+  int32_t input_output_c = fake_parameters(gen);
+  int32_t filter_w = fake_parameters(gen);
+  int32_t filter_h = fake_parameters(gen);
+  int32_t input_n = 1;
+
+  int32_t mve_calculated_buffer =
+      (2 * input_output_c * filter_w * filter_h) * static_cast<int32_t>(sizeof(int16_t)) + 4;
+  int32_t dsp_calculated_buffer =
+      (input_output_c * filter_w * filter_h) * static_cast<int32_t>(sizeof(int16_t));
+
+  auto depthwise_conv2d = [=](CMSISNNFlags flags) {
+    return DepthwiseConv2dBufferSize(flags, input_n, input_output_c, input_output_c, filter_w,
+                                     filter_h);
+  };
+
+  ASSERT_EQ(depthwise_conv2d(kNoExt), 0);
+  ASSERT_EQ(depthwise_conv2d(kHasDSP), dsp_calculated_buffer);
+  ASSERT_EQ(depthwise_conv2d(kHasMVE), mve_calculated_buffer);
+}
+
+// Average pooling needs one int32_t per channel, but only with DSP and without MVE.
+TEST(CMSISNNAvgPoolBufferSize, Default) {
+  int32_t input_c = fake_parameters(gen);
+  int32_t calculated_buffer = input_c * static_cast<int32_t>(sizeof(int32_t));
+
+  auto avg_pool = [=](CMSISNNFlags flags) { return AvgPoolBufferSize(flags, input_c); };
+
+  ASSERT_EQ(avg_pool(kNoExt), 0);
+  ASSERT_EQ(avg_pool(kHasDSP), calculated_buffer);
+  ASSERT_EQ(avg_pool(kHasMVE), 0);
+}
+
+} // namespace cmsisnn
+} // namespace contrib
+} // namespace relay
+} // namespace tvm
+
+#endif
diff --git a/tests/cpp/relay/backend/contrib/cmsisnn/compiler_attrs_test.cc b/tests/cpp/relay/backend/contrib/cmsisnn/compiler_attrs_test.cc
new file mode 100644
index 0000000..0eb980f
--- /dev/null
+++ b/tests/cpp/relay/backend/contrib/cmsisnn/compiler_attrs_test.cc
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifdef TVM_USE_CMSISNN
+
+#include "../../../../../../src/relay/backend/contrib/cmsisnn/compiler_attrs.h"
+
+#include <gtest/gtest.h>
+#include <tvm/ir/transform.h>
+
+#include <cmath>
+#include <string>
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+// Expected CPU classification; the first two lists mirror those in compiler_attrs.cc.
+static const char* mveCPUs[] = {"cortex-m55"};
+static const char* dspCPUs[] = {"cortex-m7", "cortex-m33", "cortex-m35p"};
+// NOTE(review): Cortex-M4 is believed to implement the DSP extension (Armv7E-M); confirm
+// whether it should really be expected to report no extensions here.
+static const char* noExtensions[] = {"cortex-m0", "cortex-m3", "cortex-m4"};
+
+class CMSISNNFlagsMVECPUs : public testing::TestWithParam<const char*> {};
+class CMSISNNFlagsDSPCPUs : public testing::TestWithParam<const char*> {};
+class CMSISNNFlagsNoExtensions : public testing::TestWithParam<const char*> {};
+
+/*!
+ * \brief Builds a PassContext whose "relay.ext.cmsisnn.options" holds \p mcpu and \p mattr,
+ * and returns the flags GetCompilerFlags derives from it.
+ */
+static CMSISNNFlags GetFlagsWithCompilerAttrs(String mcpu, String mattr) {
+  auto context_node = make_object<tvm::transform::PassContextNode>();
+  auto cmsisnn_config_node = make_object<CMSISNNCompilerConfigNode>();
+  cmsisnn_config_node->InitBySeq("mcpu", mcpu, "mattr", mattr);
+
+  context_node->config = {
+      {"relay.ext.cmsisnn.options", CMSISNNCompilerConfig(cmsisnn_config_node)}};
+
+  tvm::transform::PassContext context = tvm::transform::PassContext(context_node);
+  return GetCompilerFlags(context);
+}
+
+// Without the config option, no extensions are assumed.
+TEST(CMSISNNFlags, CreateFromUndefined) {
+  auto context_node = make_object<tvm::transform::PassContextNode>();
+  tvm::transform::PassContext context = tvm::transform::PassContext(context_node);
+  CMSISNNFlags flags = GetCompilerFlags(context);
+  ASSERT_FALSE(flags.mve);
+  ASSERT_FALSE(flags.dsp);
+}
+
+TEST_P(CMSISNNFlagsMVECPUs, CheckMVESet) {
+  CMSISNNFlags flags = GetFlagsWithCompilerAttrs(GetParam(), "");
+  ASSERT_TRUE(flags.dsp);
+  ASSERT_TRUE(flags.mve);
+}
+
+TEST_P(CMSISNNFlagsMVECPUs, CheckMVEOverrideCPU) {
+  std::string mcpu = GetParam();
+  CMSISNNFlags flags = GetFlagsWithCompilerAttrs(mcpu + "+nomve", "");
+  ASSERT_FALSE(flags.dsp);
+  ASSERT_FALSE(flags.mve);
+}
+
+TEST_P(CMSISNNFlagsMVECPUs, CheckDSPOverrideCPU) {
+  std::string mcpu = GetParam();
+  CMSISNNFlags flags = GetFlagsWithCompilerAttrs(mcpu + "+nodsp", "");
+  ASSERT_FALSE(flags.dsp);
+  ASSERT_FALSE(flags.mve);
+}
+
+TEST_P(CMSISNNFlagsMVECPUs, CheckCombinedOverrideCPU) {
+  std::string mcpu = GetParam();
+  CMSISNNFlags flags = GetFlagsWithCompilerAttrs(mcpu + "+nodsp+nomve", "");
+  ASSERT_FALSE(flags.dsp);
+  ASSERT_FALSE(flags.mve);
+  flags = GetFlagsWithCompilerAttrs(mcpu + "+nomve+nodsp", "");
+  ASSERT_FALSE(flags.dsp);
+  ASSERT_FALSE(flags.mve);
+}
+
+TEST_P(CMSISNNFlagsMVECPUs, CheckMVEOverrideMAttr) {
+  CMSISNNFlags flags = GetFlagsWithCompilerAttrs(GetParam(), "+nomve");
+  ASSERT_FALSE(flags.dsp);
+  ASSERT_FALSE(flags.mve);
+}
+
+TEST_P(CMSISNNFlagsMVECPUs, CheckDSPOverrideMattr) {
+  CMSISNNFlags flags = GetFlagsWithCompilerAttrs(GetParam(), "+nodsp");
+  ASSERT_FALSE(flags.dsp);
+  ASSERT_FALSE(flags.mve);
+}
+
+// Unknown attributes (e.g. "+woofles") must not hide the disabling flags.
+TEST_P(CMSISNNFlagsMVECPUs, CheckCombinedOverrideMattr) {
+  CMSISNNFlags flags = GetFlagsWithCompilerAttrs(GetParam(), "+nodsp+nomve");
+  ASSERT_FALSE(flags.dsp);
+  ASSERT_FALSE(flags.mve);
+  flags = GetFlagsWithCompilerAttrs(GetParam(), "+nomve+nodsp");
+  ASSERT_FALSE(flags.dsp);
+  ASSERT_FALSE(flags.mve);
+  flags = GetFlagsWithCompilerAttrs(GetParam(), "+woofles+nomve+nodsp");
+  ASSERT_FALSE(flags.dsp);
+  ASSERT_FALSE(flags.mve);
+}
+
+TEST_P(CMSISNNFlagsDSPCPUs, CheckDSPSet) {
+  CMSISNNFlags flags = GetFlagsWithCompilerAttrs(GetParam(), "");
+  ASSERT_TRUE(flags.dsp);
+  ASSERT_FALSE(flags.mve);
+}
+
+TEST_P(CMSISNNFlagsDSPCPUs, CheckDSPOverrideCPU) {
+  std::string mcpu = GetParam();
+  CMSISNNFlags flags = GetFlagsWithCompilerAttrs(mcpu + "+nodsp", "");
+  ASSERT_FALSE(flags.dsp);
+  ASSERT_FALSE(flags.mve);
+  flags = GetFlagsWithCompilerAttrs(mcpu + "+nodsp+woofles", "");
+  ASSERT_FALSE(flags.dsp);
+  ASSERT_FALSE(flags.mve);
+}
+
+TEST_P(CMSISNNFlagsDSPCPUs, CheckDSPOverrideMattr) {
+  CMSISNNFlags flags = GetFlagsWithCompilerAttrs(GetParam(), "+nodsp");
+  ASSERT_FALSE(flags.dsp);
+  ASSERT_FALSE(flags.mve);
+  flags = GetFlagsWithCompilerAttrs(GetParam(), "+nodsp+woofles");
+  ASSERT_FALSE(flags.dsp);
+  ASSERT_FALSE(flags.mve);
+}
+
+TEST_P(CMSISNNFlagsNoExtensions, CheckNoFlags) {
+  CMSISNNFlags flags = GetFlagsWithCompilerAttrs(GetParam(), "");
+  ASSERT_FALSE(flags.dsp);
+  ASSERT_FALSE(flags.mve);
+}
+
+INSTANTIATE_TEST_CASE_P(CMSISNNFlags, CMSISNNFlagsMVECPUs, ::testing::ValuesIn(mveCPUs));
+INSTANTIATE_TEST_CASE_P(CMSISNNFlags, CMSISNNFlagsDSPCPUs, ::testing::ValuesIn(dspCPUs));
+INSTANTIATE_TEST_CASE_P(CMSISNNFlags, CMSISNNFlagsNoExtensions,
+                        ::testing::ValuesIn(noExtensions));
+
+} // namespace cmsisnn
+} // namespace contrib
+} // namespace relay
+} // namespace tvm
+
+#endif
diff --git a/tests/python/relay/aot/aot_test_utils.py b/tests/python/relay/aot/aot_test_utils.py
index d335528..e7ca5c6 100644
--- a/tests/python/relay/aot/aot_test_utils.py
+++ b/tests/python/relay/aot/aot_test_utils.py
@@ -145,6 +145,11 @@ AOT_CORSTONE300_RUNNER = AOTTestRunner(
""",
includes=["uart.h"],
parameters={"NPU_VARIANT": "256"},
+ pass_config={
+ "relay.ext.cmsisnn.options": {
+ "mcpu": "cortex-m55",
+ }
+ },
)