This is an automated email from the ASF dual-hosted git repository.
mousius pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6247bf48aa [CMSIS-NN] Aligned buffer sizes for Conv2D post CMSIS-NN SHA update (#11359)
6247bf48aa is described below
commit 6247bf48aaa59be9549dd8c342702c6005f16c5f
Author: Ashutosh Parkhi <[email protected]>
AuthorDate: Mon May 23 11:59:02 2022 +0100
[CMSIS-NN] Aligned buffer sizes for Conv2D post CMSIS-NN SHA update (#11359)
---
src/relay/backend/contrib/cmsisnn/buffer_size.cc | 18 ++++++----
src/relay/backend/contrib/cmsisnn/buffer_size.h | 3 +-
src/relay/backend/contrib/cmsisnn/relay_to_tir.cc | 6 ++--
.../backend/contrib/cmsisnn/buffer_size_test.cc | 41 ++++++++++++----------
tests/python/relay/aot/test_crt_aot.py | 2 +-
5 files changed, 40 insertions(+), 30 deletions(-)
diff --git a/src/relay/backend/contrib/cmsisnn/buffer_size.cc b/src/relay/backend/contrib/cmsisnn/buffer_size.cc
index 2502a09e75..b6b98c0fc3 100644
--- a/src/relay/backend/contrib/cmsisnn/buffer_size.cc
+++ b/src/relay/backend/contrib/cmsisnn/buffer_size.cc
@@ -29,24 +29,30 @@ namespace cmsisnn {
int Conv2dBufferSize(CMSISNNFlags flags, int32_t padding_w, int32_t padding_h,
int32_t input_n,
int32_t input_h, int32_t input_c, int32_t output_h,
int32_t output_w,
- int32_t stride_w, int32_t stride_h, int32_t filter_w,
int32_t filter_h) {
+ int32_t stride_w, int32_t stride_h, int32_t dilation_w,
int32_t dilation_h,
+ int32_t filter_w, int32_t filter_h) {
bool is1x1 = (padding_w == 0) && (padding_h == 0) && (input_c % 4 == 0) &&
(stride_w == 1) &&
- (stride_h == 1) && (filter_w == 1) && (filter_h == 1);
- bool is1xN =
- (output_h == 1) && (input_h == 1) && (filter_h == 1) && (output_w % 4 ==
0) && (input_n == 1);
+ (stride_h == 1) && (filter_w == 1) && (filter_h == 1) &&
(dilation_w == 1) &&
+ (dilation_h == 1);
+ bool is1xN = (output_h == 1) && (input_h == 1) && (filter_h == 1) &&
(output_w % 4 == 0) &&
+ (input_n == 1) && (dilation_w == 1) && (dilation_h == 1);
if (is1x1) {
return 0;
}
if (is1xN) {
- if (flags.dsp && !flags.mve) {
+ if (!flags.mve) {
return (2 * input_c * filter_w * filter_h) * (int32_t)sizeof(int16_t);
}
return 0;
}
- if (flags.dsp) {
+ if (flags.mve) {
+ int32_t col_length = input_c * filter_w * filter_h;
+ col_length = (col_length + 7) / 8;
+ return 4 * col_length * 8 * (int32_t)sizeof(int8_t);
+ } else {
return (2 * input_c * filter_w * filter_h) * (int32_t)sizeof(int16_t);
}
return 0;
diff --git a/src/relay/backend/contrib/cmsisnn/buffer_size.h b/src/relay/backend/contrib/cmsisnn/buffer_size.h
index dec3c3eafc..e89763fd5a 100644
--- a/src/relay/backend/contrib/cmsisnn/buffer_size.h
+++ b/src/relay/backend/contrib/cmsisnn/buffer_size.h
@@ -56,7 +56,8 @@ namespace cmsisnn {
*/
int Conv2dBufferSize(CMSISNNFlags flags, int32_t padding_w, int32_t padding_h,
int32_t input_n,
int32_t input_h, int32_t input_c, int32_t output_h,
int32_t output_w,
- int32_t stride_w, int32_t stride_h, int32_t filter_w,
int32_t filter_h);
+ int32_t stride_w, int32_t stride_h, int32_t dilation_w,
int32_t dilation_h,
+ int32_t filter_w, int32_t filter_h);
/*!
* \brief Calculates the appropriate buffer size for CMSIS-NN Depthwise
Convolutions
diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
index 210175817f..dc5537ee90 100644
--- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
+++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
@@ -238,9 +238,9 @@ class RelayToTIRVisitor : public MixedModeMutator {
context_buffer_size =
DepthwiseConv2dBufferSize(flags, input_n, input_c, output_c,
filter_w, filter_h);
} else {
- context_buffer_size =
- Conv2dBufferSize(flags, padding_w, padding_h, input_n, input_h,
input_c, output_h,
- output_w, stride_w, stride_h, filter_w, filter_h);
+ context_buffer_size = Conv2dBufferSize(flags, padding_w, padding_h,
input_n, input_h, input_c,
+ output_h, output_w, stride_w,
stride_h, dilation_w,
+ dilation_h, filter_w, filter_h);
}
if (context_buffer_size) {
diff --git a/tests/cpp/relay/backend/contrib/cmsisnn/buffer_size_test.cc b/tests/cpp/relay/backend/contrib/cmsisnn/buffer_size_test.cc
index 7b8047a3b2..b7458858d4 100644
--- a/tests/cpp/relay/backend/contrib/cmsisnn/buffer_size_test.cc
+++ b/tests/cpp/relay/backend/contrib/cmsisnn/buffer_size_test.cc
@@ -44,7 +44,7 @@ class CMSISNNCalculatedBufferSize : public
testing::TestWithParam<std::array<int
TEST(CMSISNNConv2dBufferSize, Conv1x1) {
int32_t any = fake_parameters(gen);
auto conv2d_1x1 = [=](CMSISNNFlags flags, int32_t input_c) {
- return Conv2dBufferSize(flags, 0, 0, any, any, input_c, any, any, 1, 1, 1,
1);
+ return Conv2dBufferSize(flags, 0, 0, any, any, input_c, any, any, 1, 1, 1,
1, 1, 1);
};
ASSERT_EQ(conv2d_1x1(kNoExt, 4), 0);
@@ -74,15 +74,15 @@ TEST(CMSISNNConv2dBufferSize, Conv1xN) {
int32_t calculated_buffer = (2 * input_c * filter_w * filter_h) *
(int32_t)sizeof(int16_t);
auto conv2d_1xn = [=](CMSISNNFlags flags, int32_t output_w) {
- return Conv2dBufferSize(flags, any, any, 1, 1, input_c, 1, output_w, any,
any, filter_w,
+ return Conv2dBufferSize(flags, any, any, 1, 1, input_c, 1, output_w, any,
any, 1, 1, filter_w,
filter_h);
};
- ASSERT_EQ(conv2d_1xn(kNoExt, 4), 0);
- ASSERT_EQ(conv2d_1xn(kNoExt, 8), 0);
- ASSERT_EQ(conv2d_1xn(kNoExt, 12), 0);
- ASSERT_EQ(conv2d_1xn(kNoExt, 16), 0);
- ASSERT_EQ(conv2d_1xn(kNoExt, 32), 0);
+ ASSERT_EQ(conv2d_1xn(kNoExt, 4), calculated_buffer);
+ ASSERT_EQ(conv2d_1xn(kNoExt, 8), calculated_buffer);
+ ASSERT_EQ(conv2d_1xn(kNoExt, 12), calculated_buffer);
+ ASSERT_EQ(conv2d_1xn(kNoExt, 16), calculated_buffer);
+ ASSERT_EQ(conv2d_1xn(kNoExt, 32), calculated_buffer);
ASSERT_EQ(conv2d_1xn(kHasDSP, 4), calculated_buffer);
ASSERT_EQ(conv2d_1xn(kHasDSP, 8), calculated_buffer);
@@ -104,17 +104,20 @@ TEST(CMSISNNConv2dBufferSize, Default) {
int32_t filter_w = fake_parameters(gen);
int32_t filter_h = fake_parameters(gen);
int32_t calculated_buffer = (2 * input_c * filter_w * filter_h) *
(int32_t)sizeof(int16_t);
+ int32_t col_length = input_c * filter_w * filter_h;
+ col_length = (col_length + 7) / 8;
+ int32_t calculated_buffer_mve = 4 * col_length * 8 * (int32_t)sizeof(int8_t);
auto conv2d = [=](CMSISNNFlags flags, int32_t output_w) {
- return Conv2dBufferSize(flags, any, any, 1, 1, input_c, 1, output_w, any,
any, filter_w,
- filter_h);
+ return Conv2dBufferSize(flags, any, any, 1, 1, input_c, 1, output_w, any,
any, any, any,
+ filter_w, filter_h);
};
- ASSERT_EQ(conv2d(kNoExt, 4), 0);
- ASSERT_EQ(conv2d(kNoExt, 8), 0);
- ASSERT_EQ(conv2d(kNoExt, 12), 0);
- ASSERT_EQ(conv2d(kNoExt, 16), 0);
- ASSERT_EQ(conv2d(kNoExt, 32), 0);
+ ASSERT_EQ(conv2d(kNoExt, 4), calculated_buffer);
+ ASSERT_EQ(conv2d(kNoExt, 8), calculated_buffer);
+ ASSERT_EQ(conv2d(kNoExt, 12), calculated_buffer);
+ ASSERT_EQ(conv2d(kNoExt, 16), calculated_buffer);
+ ASSERT_EQ(conv2d(kNoExt, 32), calculated_buffer);
ASSERT_EQ(conv2d(kHasDSP, 4), calculated_buffer);
ASSERT_EQ(conv2d(kHasDSP, 8), calculated_buffer);
@@ -122,11 +125,11 @@ TEST(CMSISNNConv2dBufferSize, Default) {
ASSERT_EQ(conv2d(kHasDSP, 16), calculated_buffer);
ASSERT_EQ(conv2d(kHasDSP, 32), calculated_buffer);
- ASSERT_EQ(conv2d(kHasMVE, 4), calculated_buffer);
- ASSERT_EQ(conv2d(kHasMVE, 8), calculated_buffer);
- ASSERT_EQ(conv2d(kHasMVE, 12), calculated_buffer);
- ASSERT_EQ(conv2d(kHasMVE, 16), calculated_buffer);
- ASSERT_EQ(conv2d(kHasMVE, 32), calculated_buffer);
+ ASSERT_EQ(conv2d(kHasMVE, 4), calculated_buffer_mve);
+ ASSERT_EQ(conv2d(kHasMVE, 8), calculated_buffer_mve);
+ ASSERT_EQ(conv2d(kHasMVE, 12), calculated_buffer_mve);
+ ASSERT_EQ(conv2d(kHasMVE, 16), calculated_buffer_mve);
+ ASSERT_EQ(conv2d(kHasMVE, 32), calculated_buffer_mve);
}
TEST(CMSISNNDepthwiseConv2dBufferSize, UnEvenChannels) {
diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py
index d1d80d434b..ffae70d0cf 100644
--- a/tests/python/relay/aot/test_crt_aot.py
+++ b/tests/python/relay/aot/test_crt_aot.py
@@ -992,7 +992,7 @@ def test_workspace_calculation_cmsis_nn():
):
lib = tvm.relay.build(mod, target, executor=executor, runtime=runtime,
params=params)
mlf_memory_map = mlf._build_function_memory_map(lib.function_metadata)
- assert mlf_memory_map["main"][0]["workspace_size_bytes"] == 9904
+ assert mlf_memory_map["main"][0]["workspace_size_bytes"] == 14384
def test_aot_codegen_checks_returns():