This is an automated email from the ASF dual-hosted git repository.

xiaoxiang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/nuttx-apps.git

commit 7d87768f78a716fcfb122e739bfe1be6de0831a2
Author: jihandong <jihand...@xiaomi.com>
AuthorDate: Thu Jun 6 20:26:59 2024 +0800

    ml: a cmdline tool to use tflite-micro.
    
    Signed-off-by: jihandong <jihand...@xiaomi.com>
---
 mlearning/tflite-micro/Kconfig      |  17 +++-
 mlearning/tflite-micro/Makefile     |   7 ++
 mlearning/tflite-micro/tflm_tool.cc | 150 ++++++++++++++++++++++++++++++++++++
 3 files changed, 173 insertions(+), 1 deletion(-)

diff --git a/mlearning/tflite-micro/Kconfig b/mlearning/tflite-micro/Kconfig
index 603ec4abb..4363d7c88 100644
--- a/mlearning/tflite-micro/Kconfig
+++ b/mlearning/tflite-micro/Kconfig
@@ -25,6 +25,21 @@ config TFLITEMICRO
 
 if TFLITEMICRO
        config TFLITEMICRO_DEBUG
-       bool "TFLITEMICRO_DEBUG"
+       bool "Print tflite-micro's debug message"
        default n
+
+config TFLITEMICRO_TOOL
+       bool "tflite-micro cmdline tool"
+       default n
+
+if TFLITEMICRO_TOOL
+       config TFLITEMICRO_TOOL_PRIORITY
+       int "tflite-micro tool priority"
+       default 100
+
+       config TFLITEMICRO_TOOL_STACKSIZE
+       int "tflite-micro tool stacksize"
+       default 4096
+
+endif
 endif
diff --git a/mlearning/tflite-micro/Makefile b/mlearning/tflite-micro/Makefile
index eb03470f4..b5dbd7f81 100644
--- a/mlearning/tflite-micro/Makefile
+++ b/mlearning/tflite-micro/Makefile
@@ -99,6 +99,13 @@ endif
 # extra hardware support.
 -include $(TFLM_DIR)/tensorflow/lite/micro/nuttx/Makefile
 
+ifneq ($(CONFIG_TFLITEMICRO_TOOL),)
+MAINSRC   = tflm_tool.cc
+PROGNAME  = tflm
+PRIORITY  = $(CONFIG_TFLITEMICRO_TOOL_PRIORITY)
+STACKSIZE = $(CONFIG_TFLITEMICRO_TOOL_STACKSIZE)
+endif
+
 CFLAGS   += ${COMMON_FLAGS}
 CXXFLAGS += ${COMMON_FLAGS}
 
diff --git a/mlearning/tflite-micro/tflm_tool.cc 
b/mlearning/tflite-micro/tflm_tool.cc
new file mode 100644
index 000000000..df71500b2
--- /dev/null
+++ b/mlearning/tflite-micro/tflm_tool.cc
@@ -0,0 +1,150 @@
+/****************************************************************************
+ * apps/mlearning/tflite-micro/tflm_tool.cc
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.  The
+ * ASF licenses this file to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance with the
+ * License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ ****************************************************************************/
+
+/****************************************************************************
+ * Included Files
+ ****************************************************************************/
+
+#include <unistd.h>
+
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <fstream>
+#include <memory>
+#include <new>
+
+#include "tensorflow/lite/micro/micro_interpreter.h"
+#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
+#include "tensorflow/lite/micro/micro_profiler.h"
+
+/****************************************************************************
+ * Private Functions
+ ****************************************************************************/
+
+/****************************************************************************
+ * Name: usage
+ *
+ * Description:
+ *   Print the command line help text of the tool to standard output.
+ *
+ ****************************************************************************/
+
+static void usage(void)
+{
+  /* One line per option; kept as a single literal so it is emitted with one
+   * call.
+   */
+
+  static const char help_text[] =
+    "\nUtility to use tflite micro on nuttx.\n"
+    "[ -C       ] Compile tflite model into c++ codes.\n"
+    "[ -E       ] Do once evaluation (for profiling).\n"
+    "[ -i <str> ] Readable model file path.\n"
+    "[ -o <str> ] Writable c++ file path.\n"
+    "[ -p <str> ] Prefix of compiled code.\n"
+    "[ -a <int> ] Arena size (mempool).\n"
+    "[ -h       ] Print this message.\n";
+
+  printf("%s", help_text);
+}
+
+/****************************************************************************
+ * Public Functions
+ ****************************************************************************/
+
+/****************************************************************************
+ * Name: main
+ *
+ * Description:
+ *   Entry point of the tflite-micro command line tool.  Loads the model
+ *   file given with -i into memory, builds a MicroInterpreter over it with
+ *   a fixed set of INT8 operators, then optionally runs one profiled
+ *   inference (-E) and/or emits the model as C++ code (-C).
+ *
+ * Input Parameters:
+ *   argc - Number of command line arguments.
+ *   argv - Command line arguments, see usage().
+ *
+ * Returned Value:
+ *   0 on success; -1 on a usage error, when help was requested, or when
+ *   loading, evaluating or compiling the model failed.
+ *
+ ****************************************************************************/
+
+extern "C" int main(int argc, FAR char* argv[])
+{
+  const char* modelFileName = nullptr;
+  const char* codeFileName = nullptr;
+  const char* prefix = "NXAI";
+  bool need_compile = false;
+  bool need_invoke = false;
+  size_t arenaSize = 1024 * 8;
+
+  /* getopt() reports the end of the options with -1. */
+
+  int ch;
+  while ((ch = getopt(argc, argv, "CEhi:o:p:a:")) != -1)
+    {
+      switch (ch)
+        {
+          case 'C':
+            need_compile = true;
+            break;
+          case 'E':
+            need_invoke = true;
+            break;
+          case 'p':
+            prefix = optarg;
+            break;
+          case 'i':
+            modelFileName = optarg;
+            break;
+          case 'o':
+            codeFileName = optarg;
+            break;
+          case 'a':
+            {
+              /* Reject non-numeric, trailing-garbage and non-positive
+               * values: the result is used as an allocation size.
+               */
+
+              char* end = nullptr;
+              long value = strtol(optarg, &end, 0);
+              if (end == optarg || *end != '\0' || value <= 0)
+                {
+                  printf("Invalid arena size: %s\n", optarg);
+                  return -1;
+                }
+
+              arenaSize = static_cast<size_t>(value);
+            }
+            break;
+          case 'h':
+          default:
+            usage();
+            return -1;
+        }
+    }
+
+  /* The model is always needed; the output file only when compiling. */
+
+  if (!modelFileName || (need_compile && !codeFileName))
+    {
+      usage();
+      return -1;
+    }
+
+  /* Read the whole model file into a heap buffer. */
+
+  std::ifstream ifs(modelFileName, std::ios::binary);
+  if (!ifs.is_open())
+    {
+      printf("Failed to open model file %s\n", modelFileName);
+      return -1;
+    }
+
+  ifs.seekg(0, std::ios::end);
+  const std::streamoff fileSize = ifs.tellg();
+  if (fileSize <= 0)
+    {
+      /* tellg() yields -1 on failure; an empty file is no model either. */
+
+      printf("Failed to get size of model file %s\n", modelFileName);
+      return -1;
+    }
+
+  const size_t modelSize = static_cast<size_t>(fileSize);
+  std::unique_ptr<uint8_t[]> pModel(new (std::nothrow) uint8_t[modelSize]);
+  if (!pModel)
+    {
+      printf("Failed to allocate memory for model file %s\n",
+             modelFileName);
+      return -1;
+    }
+
+  ifs.seekg(0, std::ios::beg);
+  ifs.read(reinterpret_cast<char*>(pModel.get()),
+           static_cast<std::streamsize>(fileSize));
+  if (!ifs)
+    {
+      printf("Failed to read model file %s\n", modelFileName);
+      return -1;
+    }
+
+  ifs.close();
+
+  /* HACK: can change operators here. */
+
+  tflite::MicroMutableOpResolver<8> resolver;
+  resolver.AddConv2D(tflite::Register_CONV_2D_INT8());
+  resolver.AddMaxPool2D(tflite::Register_MAX_POOL_2D_INT8());
+  resolver.AddQuantize(tflite::Register_QUANTIZE_FLOAT32_INT8());
+  resolver.AddDequantize(tflite::Register_DEQUANTIZE_INT8());
+  resolver.AddMean(tflite::Register_MEAN_INT8());
+  resolver.AddReshape();
+  resolver.AddFullyConnected(tflite::Register_FULLY_CONNECTED_INT8());
+  resolver.AddSoftmax(tflite::Register_SOFTMAX_INT8());
+
+  std::unique_ptr<uint8_t[]> pArena(new (std::nothrow) uint8_t[arenaSize]);
+  if (!pArena)
+    {
+      printf("Failed to allocate arena\n");
+      return -1;
+    }
+
+  /* The profiler is handed over through the ordinary derived-to-base
+   * pointer conversion instead of a reinterpret_cast.
+   */
+
+  tflite::MicroProfiler profiler;
+  tflite::MicroInterpreter interpreter(tflite::GetModel(pModel.get()),
+    resolver, pArena.get(), arenaSize, nullptr, &profiler);
+
+  /* HACK: can add testcases here. */
+
+  if (need_invoke)
+    {
+      if (interpreter.Invoke() != kTfLiteOk)
+        {
+          printf("Failed to invoke model %s\n", modelFileName);
+          return -1;
+        }
+
+      profiler.LogCsv();
+      profiler.LogTicksPerTagCsv();
+    }
+
+  if (need_compile)
+    {
+#ifdef TFLITE_MODEL_COMPILER
+      std::ofstream ofs(codeFileName);
+      if (!ofs.is_open())
+        {
+          printf("Failed to open output file %s\n", codeFileName);
+          return -1;
+        }
+
+      interpreter.Compile(ofs, prefix);
+      ofs.close();
+#else
+      printf("Not supported compiling %s.\n", prefix);
+#endif
+    }
+
+  printf("nxai done!\n");
+  return 0;
+}
\ No newline at end of file

Reply via email to