This is an automated email from the ASF dual-hosted git repository.

mousius pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6247bf48aa [CMSIS-NN] Aligned buffer sizes for Conv2D post CMSIS-NN SHA update (#11359)
6247bf48aa is described below

commit 6247bf48aaa59be9549dd8c342702c6005f16c5f
Author: Ashutosh Parkhi <[email protected]>
AuthorDate: Mon May 23 11:59:02 2022 +0100

    [CMSIS-NN] Aligned buffer sizes for Conv2D post CMSIS-NN SHA update (#11359)
---
 src/relay/backend/contrib/cmsisnn/buffer_size.cc   | 18 ++++++----
 src/relay/backend/contrib/cmsisnn/buffer_size.h    |  3 +-
 src/relay/backend/contrib/cmsisnn/relay_to_tir.cc  |  6 ++--
 .../backend/contrib/cmsisnn/buffer_size_test.cc    | 41 ++++++++++++----------
 tests/python/relay/aot/test_crt_aot.py             |  2 +-
 5 files changed, 40 insertions(+), 30 deletions(-)

diff --git a/src/relay/backend/contrib/cmsisnn/buffer_size.cc 
b/src/relay/backend/contrib/cmsisnn/buffer_size.cc
index 2502a09e75..b6b98c0fc3 100644
--- a/src/relay/backend/contrib/cmsisnn/buffer_size.cc
+++ b/src/relay/backend/contrib/cmsisnn/buffer_size.cc
@@ -29,24 +29,30 @@ namespace cmsisnn {
 
 int Conv2dBufferSize(CMSISNNFlags flags, int32_t padding_w, int32_t padding_h, 
int32_t input_n,
                      int32_t input_h, int32_t input_c, int32_t output_h, 
int32_t output_w,
-                     int32_t stride_w, int32_t stride_h, int32_t filter_w, 
int32_t filter_h) {
+                     int32_t stride_w, int32_t stride_h, int32_t dilation_w, 
int32_t dilation_h,
+                     int32_t filter_w, int32_t filter_h) {
   bool is1x1 = (padding_w == 0) && (padding_h == 0) && (input_c % 4 == 0) && 
(stride_w == 1) &&
-               (stride_h == 1) && (filter_w == 1) && (filter_h == 1);
-  bool is1xN =
-      (output_h == 1) && (input_h == 1) && (filter_h == 1) && (output_w % 4 == 
0) && (input_n == 1);
+               (stride_h == 1) && (filter_w == 1) && (filter_h == 1) && 
(dilation_w == 1) &&
+               (dilation_h == 1);
+  bool is1xN = (output_h == 1) && (input_h == 1) && (filter_h == 1) && 
(output_w % 4 == 0) &&
+               (input_n == 1) && (dilation_w == 1) && (dilation_h == 1);
 
   if (is1x1) {
     return 0;
   }
 
   if (is1xN) {
-    if (flags.dsp && !flags.mve) {
+    if (!flags.mve) {
       return (2 * input_c * filter_w * filter_h) * (int32_t)sizeof(int16_t);
     }
     return 0;
   }
 
-  if (flags.dsp) {
+  if (flags.mve) {
+    int32_t col_length = input_c * filter_w * filter_h;
+    col_length = (col_length + 7) / 8;
+    return 4 * col_length * 8 * (int32_t)sizeof(int8_t);
+  } else {
     return (2 * input_c * filter_w * filter_h) * (int32_t)sizeof(int16_t);
   }
   return 0;
diff --git a/src/relay/backend/contrib/cmsisnn/buffer_size.h 
b/src/relay/backend/contrib/cmsisnn/buffer_size.h
index dec3c3eafc..e89763fd5a 100644
--- a/src/relay/backend/contrib/cmsisnn/buffer_size.h
+++ b/src/relay/backend/contrib/cmsisnn/buffer_size.h
@@ -56,7 +56,8 @@ namespace cmsisnn {
  */
 int Conv2dBufferSize(CMSISNNFlags flags, int32_t padding_w, int32_t padding_h, 
int32_t input_n,
                      int32_t input_h, int32_t input_c, int32_t output_h, 
int32_t output_w,
-                     int32_t stride_w, int32_t stride_h, int32_t filter_w, 
int32_t filter_h);
+                     int32_t stride_w, int32_t stride_h, int32_t dilation_w, 
int32_t dilation_h,
+                     int32_t filter_w, int32_t filter_h);
 
 /*!
  * \brief Calculates the appropriate buffer size for CMSIS-NN Depthwise 
Convolutions
diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc 
b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
index 210175817f..dc5537ee90 100644
--- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
+++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
@@ -238,9 +238,9 @@ class RelayToTIRVisitor : public MixedModeMutator {
       context_buffer_size =
           DepthwiseConv2dBufferSize(flags, input_n, input_c, output_c, 
filter_w, filter_h);
     } else {
-      context_buffer_size =
-          Conv2dBufferSize(flags, padding_w, padding_h, input_n, input_h, 
input_c, output_h,
-                           output_w, stride_w, stride_h, filter_w, filter_h);
+      context_buffer_size = Conv2dBufferSize(flags, padding_w, padding_h, 
input_n, input_h, input_c,
+                                             output_h, output_w, stride_w, 
stride_h, dilation_w,
+                                             dilation_h, filter_w, filter_h);
     }
 
     if (context_buffer_size) {
diff --git a/tests/cpp/relay/backend/contrib/cmsisnn/buffer_size_test.cc 
b/tests/cpp/relay/backend/contrib/cmsisnn/buffer_size_test.cc
index 7b8047a3b2..b7458858d4 100644
--- a/tests/cpp/relay/backend/contrib/cmsisnn/buffer_size_test.cc
+++ b/tests/cpp/relay/backend/contrib/cmsisnn/buffer_size_test.cc
@@ -44,7 +44,7 @@ class CMSISNNCalculatedBufferSize : public 
testing::TestWithParam<std::array<int
 TEST(CMSISNNConv2dBufferSize, Conv1x1) {
   int32_t any = fake_parameters(gen);
   auto conv2d_1x1 = [=](CMSISNNFlags flags, int32_t input_c) {
-    return Conv2dBufferSize(flags, 0, 0, any, any, input_c, any, any, 1, 1, 1, 
1);
+    return Conv2dBufferSize(flags, 0, 0, any, any, input_c, any, any, 1, 1, 1, 
1, 1, 1);
   };
 
   ASSERT_EQ(conv2d_1x1(kNoExt, 4), 0);
@@ -74,15 +74,15 @@ TEST(CMSISNNConv2dBufferSize, Conv1xN) {
   int32_t calculated_buffer = (2 * input_c * filter_w * filter_h) * 
(int32_t)sizeof(int16_t);
 
   auto conv2d_1xn = [=](CMSISNNFlags flags, int32_t output_w) {
-    return Conv2dBufferSize(flags, any, any, 1, 1, input_c, 1, output_w, any, 
any, filter_w,
+    return Conv2dBufferSize(flags, any, any, 1, 1, input_c, 1, output_w, any, 
any, 1, 1, filter_w,
                             filter_h);
   };
 
-  ASSERT_EQ(conv2d_1xn(kNoExt, 4), 0);
-  ASSERT_EQ(conv2d_1xn(kNoExt, 8), 0);
-  ASSERT_EQ(conv2d_1xn(kNoExt, 12), 0);
-  ASSERT_EQ(conv2d_1xn(kNoExt, 16), 0);
-  ASSERT_EQ(conv2d_1xn(kNoExt, 32), 0);
+  ASSERT_EQ(conv2d_1xn(kNoExt, 4), calculated_buffer);
+  ASSERT_EQ(conv2d_1xn(kNoExt, 8), calculated_buffer);
+  ASSERT_EQ(conv2d_1xn(kNoExt, 12), calculated_buffer);
+  ASSERT_EQ(conv2d_1xn(kNoExt, 16), calculated_buffer);
+  ASSERT_EQ(conv2d_1xn(kNoExt, 32), calculated_buffer);
 
   ASSERT_EQ(conv2d_1xn(kHasDSP, 4), calculated_buffer);
   ASSERT_EQ(conv2d_1xn(kHasDSP, 8), calculated_buffer);
@@ -104,17 +104,20 @@ TEST(CMSISNNConv2dBufferSize, Default) {
   int32_t filter_w = fake_parameters(gen);
   int32_t filter_h = fake_parameters(gen);
   int32_t calculated_buffer = (2 * input_c * filter_w * filter_h) * 
(int32_t)sizeof(int16_t);
+  int32_t col_length = input_c * filter_w * filter_h;
+  col_length = (col_length + 7) / 8;
+  int32_t calculated_buffer_mve = 4 * col_length * 8 * (int32_t)sizeof(int8_t);
 
   auto conv2d = [=](CMSISNNFlags flags, int32_t output_w) {
-    return Conv2dBufferSize(flags, any, any, 1, 1, input_c, 1, output_w, any, 
any, filter_w,
-                            filter_h);
+    return Conv2dBufferSize(flags, any, any, 1, 1, input_c, 1, output_w, any, 
any, any, any,
+                            filter_w, filter_h);
   };
 
-  ASSERT_EQ(conv2d(kNoExt, 4), 0);
-  ASSERT_EQ(conv2d(kNoExt, 8), 0);
-  ASSERT_EQ(conv2d(kNoExt, 12), 0);
-  ASSERT_EQ(conv2d(kNoExt, 16), 0);
-  ASSERT_EQ(conv2d(kNoExt, 32), 0);
+  ASSERT_EQ(conv2d(kNoExt, 4), calculated_buffer);
+  ASSERT_EQ(conv2d(kNoExt, 8), calculated_buffer);
+  ASSERT_EQ(conv2d(kNoExt, 12), calculated_buffer);
+  ASSERT_EQ(conv2d(kNoExt, 16), calculated_buffer);
+  ASSERT_EQ(conv2d(kNoExt, 32), calculated_buffer);
 
   ASSERT_EQ(conv2d(kHasDSP, 4), calculated_buffer);
   ASSERT_EQ(conv2d(kHasDSP, 8), calculated_buffer);
@@ -122,11 +125,11 @@ TEST(CMSISNNConv2dBufferSize, Default) {
   ASSERT_EQ(conv2d(kHasDSP, 16), calculated_buffer);
   ASSERT_EQ(conv2d(kHasDSP, 32), calculated_buffer);
 
-  ASSERT_EQ(conv2d(kHasMVE, 4), calculated_buffer);
-  ASSERT_EQ(conv2d(kHasMVE, 8), calculated_buffer);
-  ASSERT_EQ(conv2d(kHasMVE, 12), calculated_buffer);
-  ASSERT_EQ(conv2d(kHasMVE, 16), calculated_buffer);
-  ASSERT_EQ(conv2d(kHasMVE, 32), calculated_buffer);
+  ASSERT_EQ(conv2d(kHasMVE, 4), calculated_buffer_mve);
+  ASSERT_EQ(conv2d(kHasMVE, 8), calculated_buffer_mve);
+  ASSERT_EQ(conv2d(kHasMVE, 12), calculated_buffer_mve);
+  ASSERT_EQ(conv2d(kHasMVE, 16), calculated_buffer_mve);
+  ASSERT_EQ(conv2d(kHasMVE, 32), calculated_buffer_mve);
 }
 
 TEST(CMSISNNDepthwiseConv2dBufferSize, UnEvenChannels) {
diff --git a/tests/python/relay/aot/test_crt_aot.py 
b/tests/python/relay/aot/test_crt_aot.py
index d1d80d434b..ffae70d0cf 100644
--- a/tests/python/relay/aot/test_crt_aot.py
+++ b/tests/python/relay/aot/test_crt_aot.py
@@ -992,7 +992,7 @@ def test_workspace_calculation_cmsis_nn():
     ):
         lib = tvm.relay.build(mod, target, executor=executor, runtime=runtime, 
params=params)
     mlf_memory_map = mlf._build_function_memory_map(lib.function_metadata)
-    assert mlf_memory_map["main"][0]["workspace_size_bytes"] == 9904
+    assert mlf_memory_map["main"][0]["workspace_size_bytes"] == 14384
 
 
 def test_aot_codegen_checks_returns():

Reply via email to