This is an automated email from the ASF dual-hosted git repository.

gromero pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5b586605da Better version handling for Arduino (#11043)
5b586605da is described below

commit 5b586605da10d5c7d148ca4b37aa1c1abb3de966
Author: Gavin Uberti <[email protected]>
AuthorDate: Wed Apr 20 11:08:09 2022 -0400

    Better version handling for Arduino (#11043)
    
    * Fix a bug that prevented microTVM from being used with arduino-cli
    v0.20 and above (see changes to _parse_connected_boards), and add
    relevant unit tests.
    
    * Only perform the version check when calling build or flash (the steps
    that actually require arduino-cli), and add relevant unit tests.
    
    * Only raise a warning if the installed arduino-cli version is below the
    minimum version (previously any version other than v0.18 was flagged:
    a warning, or an error when warning_as_error was set).
    
    * Change version comparison to use packaging.version, like the rest of TVM.
---
 .../template_project/microtvm_api_server.py        | 104 ++++++++++++---------
 .../tests/test_arduino_microtvm_api_server.py      |  82 +++++++++++++---
 2 files changed, 125 insertions(+), 61 deletions(-)

diff --git a/apps/microtvm/arduino/template_project/microtvm_api_server.py b/apps/microtvm/arduino/template_project/microtvm_api_server.py
index bb4b54d8fb..95f941fe34 100644
--- a/apps/microtvm/arduino/template_project/microtvm_api_server.py
+++ b/apps/microtvm/arduino/template_project/microtvm_api_server.py
@@ -33,8 +33,9 @@ import time
 from string import Template
 import re
 
-import serial
+from packaging import version  # parses/compares arduino-cli versions
 import serial.tools.list_ports
+
 from tvm.micro.project_api import server
 
 _LOG = logging.getLogger(__name__)
@@ -46,10 +47,7 @@ MODEL_LIBRARY_FORMAT_PATH = API_SERVER_DIR / MODEL_LIBRARY_FORMAT_RELPATH
 
 IS_TEMPLATE = not (API_SERVER_DIR / MODEL_LIBRARY_FORMAT_RELPATH).exists()
 
-# Used to check Arduino CLI version installed on the host.
-# We only check two levels of the version.
-ARDUINO_CLI_VERSION = 0.18
-
+MIN_ARDUINO_CLI_VERSION = version.parse("0.18.0")  # checked by Handler._check_platform_version
 
 BOARDS = API_SERVER_DIR / "boards.json"
 
@@ -113,7 +111,7 @@ PROJECT_OPTIONS = [
     ),
     server.ProjectOption(
         "warning_as_error",
-        optional=["generate_project"],
+        optional=["build", "flash"],  # read by the version check run in build/flash
         type="bool",
         help="Treat warnings as errors and raise an Exception.",
     ),
@@ -126,6 +124,7 @@ class Handler(server.ProjectAPIHandler):
         self._proc = None
         self._port = None
         self._serial = None
+        self._version = None  # cached arduino-cli version, filled by _check_platform_version
 
     def server_info_query(self, tvm_version):
         return server.ServerInfo(
@@ -314,25 +313,7 @@ class Handler(server.ProjectAPIHandler):
         # It's probably a standard C/C++ header
         return include_path
 
-    def _get_platform_version(self, arduino_cli_path: str) -> float:
-        # sample output of this command:
-        # 'arduino-cli alpha Version: 0.18.3 Commit: d710b642 Date: 2021-05-14T12:36:58Z\n'
-        version_output = subprocess.check_output([arduino_cli_path, "version"], encoding="utf-8")
-        full_version = re.findall("version: ([\.0-9]*)", version_output.lower())
-        full_version = full_version[0].split(".")
-        version = float(f"{full_version[0]}.{full_version[1]}")
-
-        return version
-
     def generate_project(self, model_library_format_path, standalone_crt_dir, project_dir, options):
-        # Check Arduino version
-        version = self._get_platform_version(self._get_arduino_cli_cmd(options))
-        if version != ARDUINO_CLI_VERSION:
-            message = f"Arduino CLI version found is not supported: found {version}, expected {ARDUINO_CLI_VERSION}."
-            if options.get("warning_as_error") is not None and options["warning_as_error"]:
-                raise server.ServerError(message=message)
-            _LOG.warning(message)
-
         # Reference key directories with pathlib
         project_dir = pathlib.Path(project_dir)
         project_dir.mkdir()
@@ -368,11 +349,45 @@ class Handler(server.ProjectAPIHandler):
         # Recursively change includes
         self._convert_includes(project_dir, source_dir)
 
+    def _get_arduino_cli_cmd(self, options: dict):
+        arduino_cli_cmd = options.get("arduino_cli_cmd", ARDUINO_CLI_CMD)
+        assert arduino_cli_cmd, "'arduino_cli_cmd' command not passed and not found by default!"
+        return arduino_cli_cmd
+
+    def _get_platform_version(self, arduino_cli_path: str) -> version.Version:
+        # sample output of this command:
+        # 'arduino-cli alpha Version: 0.18.3 Commit: d710b642 Date: 2021-05-14T12:36:58Z\n'
+        version_output = subprocess.run(
+            [arduino_cli_path, "version"], check=True, stdout=subprocess.PIPE
+        ).stdout.decode("utf-8")
+        match = re.search(r"Version: ([\.0-9]+)", version_output)
+        if match is None:
+            raise server.ServerError(message=f"Unknown arduino-cli version: {version_output!r}")
+
+        # Parse rather than use floats: as floats 0.7 > 0.21, but v0.21 is the newer version.
+        return version.parse(match.group(1))
+
+    # Called before build/flash (which invoke arduino-cli); caches the version in self._version
+    def _check_platform_version(self, options):
+        if self._version is None:
+            cli_command = self._get_arduino_cli_cmd(options)
+            self._version = self._get_platform_version(cli_command)
+
+        if self._version < MIN_ARDUINO_CLI_VERSION:
+            message = (
+                f"Arduino CLI version too old: found {self._version}, "
+                f"need at least {MIN_ARDUINO_CLI_VERSION}."
+            )
+            if options.get("warning_as_error"):
+                raise server.ServerError(message=message)
+            _LOG.warning(message)
+
     def _get_fqbn(self, options):
         o = BOARD_PROPERTIES[options["arduino_board"]]
         return f"{o['package']}:{o['architecture']}:{o['board']}"
 
     def build(self, options):
+        self._check_platform_version(options)
         BUILD_DIR.mkdir()
 
         compile_cmd = [
@@ -391,19 +406,14 @@ class Handler(server.ProjectAPIHandler):
         # Specify project to compile
         subprocess.run(compile_cmd, check=True)
 
-    BOARD_LIST_HEADERS = ("Port", "Type", "Board Name", "FQBN", "Core")
+    POSSIBLE_BOARD_LIST_HEADERS = ("Port", "Protocol", "Type", "Board Name", "FQBN", "Core")
 
-    def _get_arduino_cli_cmd(self, options: dict):
-        arduino_cli_cmd = options.get("arduino_cli_cmd", ARDUINO_CLI_CMD)
-        assert arduino_cli_cmd, "'arduino_cli_cmd' command not passed and not found by default!"
-        return arduino_cli_cmd
-
-    def _parse_boards_tabular_str(self, tabular_str):
-        """Parses the tabular output from `arduino-cli board list` into a 2D array
+    def _parse_connected_boards(self, tabular_str):
+        """Parses the tabular output from `arduino-cli board list` into one dict per device row
 
         Examples
         --------
-        >>> list(_parse_boards_tabular_str(bytes(
+        >>> list(_parse_connected_boards(bytes(
         ...     "Port         Type              Board Name FQBN                          Core               \n"
         ...     "/dev/ttyS4   Serial Port       Unknown                                                     \n"
         ...     "/dev/ttyUSB0 Serial Port (USB) Spresense  SPRESENSE:spresense:spresense SPRESENSE:spresense\n"
@@ -414,20 +424,21 @@ class Handler(server.ProjectAPIHandler):
 
         """
 
-        str_rows = tabular_str.split("\n")[:-2]
-        header = str_rows[0]
-        indices = [header.index(h) for h in self.BOARD_LIST_HEADERS] + [len(header)]
+        # Which column headers are present depends on the version of arduino-cli
+        column_regex = r"\s*|".join(self.POSSIBLE_BOARD_LIST_HEADERS) + r"\s*"
+        str_rows = tabular_str.split("\n")
+        column_headers = list(re.finditer(column_regex, str_rows[0]))
+        assert column_headers, f"Unrecognized `arduino-cli board list` header: {str_rows[0]!r}"
 
         for str_row in str_rows[1:]:
-            parsed_row = []
-            for cell_index in range(len(self.BOARD_LIST_HEADERS)):
-                start = indices[cell_index]
-                end = indices[cell_index + 1]
-                str_cell = str_row[start:end]
+            if not str_row.strip():
+                continue
+            device = {}
 
-                # Remove trailing whitespace used for padding
-                parsed_row.append(str_cell.rstrip())
-            yield parsed_row
+            for column in column_headers:  # each cell spans its header text plus padding
+                col_name = column.group(0).strip().lower()  # e.g. "board name"
+                device[col_name] = str_row[column.start() : column.end()].strip()
+            yield device
 
     def _auto_detect_port(self, options):
         list_cmd = [self._get_arduino_cli_cmd(options), "board", "list"]
@@ -436,9 +447,9 @@ class Handler(server.ProjectAPIHandler):
         ).stdout.decode("utf-8")
 
         desired_fqbn = self._get_fqbn(options)
-        for line in self._parse_boards_tabular_str(list_cmd_output):
-            if line[3] == desired_fqbn:
-                return line[0]
+        for device in self._parse_connected_boards(list_cmd_output):
+            if device["fqbn"] == desired_fqbn:  # keys are lowercased column headers
+                return device["port"]
 
         # If no compatible boards, raise an error
         raise BoardAutodetectFailed()
@@ -453,6 +464,7 @@ class Handler(server.ProjectAPIHandler):
         return self._port
 
     def flash(self, options):
+        self._check_platform_version(options)  # the upload below runs arduino-cli
         port = self._get_arduino_port(options)
 
         upload_cmd = [
diff --git a/apps/microtvm/arduino/template_project/tests/test_arduino_microtvm_api_server.py b/apps/microtvm/arduino/template_project/tests/test_arduino_microtvm_api_server.py
index 00969a5a89..34659bca56 100644
--- a/apps/microtvm/arduino/template_project/tests/test_arduino_microtvm_api_server.py
+++ b/apps/microtvm/arduino/template_project/tests/test_arduino_microtvm_api_server.py
@@ -20,8 +20,11 @@ import sys
 from pathlib import Path
 from unittest import mock
 
+from packaging import version
 import pytest
 
+from tvm.micro.project_api import server
+
 sys.path.insert(0, str(Path(__file__).parent.parent))
 import microtvm_api_server
 
@@ -63,53 +66,102 @@ class TestGenerateProject:
         )
         assert valid_output == valid_arduino_import
 
-    BOARD_CONNECTED_OUTPUT = bytes(
+    # Format for arduino-cli v0.18.2
+    BOARD_CONNECTED_V18 = (
         "Port         Type              Board Name          FQBN                        Core             \n"
         "/dev/ttyACM0 Serial Port (USB) Arduino Nano 33 BLE arduino:mbed_nano:nano33ble arduino:mbed_nano\n"
         "/dev/ttyACM1 Serial Port (USB) Arduino Nano 33     arduino:mbed_nano:nano33    arduino:mbed_nano\n"
         "/dev/ttyS4   Serial Port       Unknown                                                          \n"
-        "\n",
-        "utf-8",
+        "\n"
+    )
+    # Format for arduino-cli v0.21.1 and above
+    BOARD_CONNECTED_V21 = (
+        "Port         Protocol Type Board Name FQBN                        Core             \n"
+        "/dev/ttyACM0 serial                   arduino:mbed_nano:nano33ble arduino:mbed_nano\n"
+        "\n"
     )
-    BOARD_DISCONNECTED_OUTPUT = bytes(
-        "Port       Type        Board Name FQBN Core\n"
-        "/dev/ttyS4 Serial Port Unknown             \n"
-        "\n",
-        "utf-8",
+    BOARD_DISCONNECTED_V21 = (
+        "Port       Protocol Type        Board Name FQBN Core\n"
+        "/dev/ttyS4 serial   Serial Port Unknown\n"
+        "\n"
     )
 
+    def test_parse_connected_boards(self):
+        h = microtvm_api_server.Handler()
+        boards = h._parse_connected_boards(self.BOARD_CONNECTED_V21)
+        assert list(boards) == [
+            {
+                "port": "/dev/ttyACM0",
+                "protocol": "serial",
+                "type": "",
+                "board name": "",
+                "fqbn": "arduino:mbed_nano:nano33ble",
+                "core": "arduino:mbed_nano",
+            }
+        ]
+
     @mock.patch("subprocess.run")
-    def test_auto_detect_port(self, mock_subprocess_run):
+    def test_auto_detect_port(self, mock_run):
         process_mock = mock.Mock()
         handler = microtvm_api_server.Handler()
 
         # Test it returns the correct port when a board is connected
-        mock_subprocess_run.return_value.stdout = self.BOARD_CONNECTED_OUTPUT
+        mock_run.return_value.stdout = bytes(self.BOARD_CONNECTED_V18, "utf-8")
+        assert handler._auto_detect_port(self.DEFAULT_OPTIONS) == "/dev/ttyACM0"
+
+        # Should work with old or new arduino-cli version
+        mock_run.return_value.stdout = bytes(self.BOARD_CONNECTED_V21, "utf-8")
         assert handler._auto_detect_port(self.DEFAULT_OPTIONS) == "/dev/ttyACM0"
 
         # Test it raises an exception when no board is connected
-        mock_subprocess_run.return_value.stdout = self.BOARD_DISCONNECTED_OUTPUT
+        mock_run.return_value.stdout = bytes(self.BOARD_DISCONNECTED_V21, "utf-8")
         with pytest.raises(microtvm_api_server.BoardAutodetectFailed):
             handler._auto_detect_port(self.DEFAULT_OPTIONS)
 
         # Test that the FQBN needs to match EXACTLY
         handler._get_fqbn = mock.MagicMock(return_value="arduino:mbed_nano:nano33")
-        mock_subprocess_run.return_value.stdout = self.BOARD_CONNECTED_OUTPUT
+        mock_run.return_value.stdout = bytes(self.BOARD_CONNECTED_V18, "utf-8")
         assert (
             handler._auto_detect_port({**self.DEFAULT_OPTIONS, "arduino_board": "nano33"})
             == "/dev/ttyACM1"
         )
 
+    BAD_CLI_VERSION = "arduino-cli  Version: 0.7.1 Commit: 7668c465 Date: 2019-12-31T18:24:32Z\n"
+    GOOD_CLI_VERSION = "arduino-cli  Version: 0.21.1 Commit: 9fcbb392 Date: 2022-02-24T15:41:45Z\n"
+
+    @mock.patch("subprocess.run")
+    def test_check_platform_version(self, mock_run):
+        handler = microtvm_api_server.Handler()
+        mock_run.return_value.stdout = bytes(self.GOOD_CLI_VERSION, "utf-8")
+        handler._check_platform_version(self.DEFAULT_OPTIONS)
+        assert handler._version == version.parse("0.21.1")
+
+        handler = microtvm_api_server.Handler()
+        mock_run.return_value.stdout = bytes(self.BAD_CLI_VERSION, "utf-8")
+        with pytest.raises(server.ServerError):
+            handler._check_platform_version({"warning_as_error": True})
+
     @mock.patch("subprocess.run")
-    def test_flash(self, mock_subprocess_run):
+    def test_flash(self, mock_run):
+        mock_run.return_value.stdout = bytes(self.GOOD_CLI_VERSION, "utf-8")
+
         handler = microtvm_api_server.Handler()
         handler._port = "/dev/ttyACM0"
 
         # Test no exception thrown when command works
         handler.flash(self.DEFAULT_OPTIONS)
-        mock_subprocess_run.assert_called_once()
+
+        # Test we checked version then called upload
+        assert mock_run.call_count == 2
+        assert mock_run.call_args_list[0][0] == (["arduino-cli", "version"],)
+        assert mock_run.call_args_list[1][0][0][0:2] == ["arduino-cli", "upload"]
+        mock_run.reset_mock()
 
         # Test exception raised when `arduino-cli upload` returns error code
-        mock_subprocess_run.side_effect = subprocess.CalledProcessError(2, [])
+        mock_run.side_effect = subprocess.CalledProcessError(2, [])
         with pytest.raises(subprocess.CalledProcessError):
             handler.flash(self.DEFAULT_OPTIONS)
+
+        # Version information should be cached and not checked again
+        mock_run.assert_called_once()
+        assert mock_run.call_args[0][0][0:2] == ["arduino-cli", "upload"]

Reply via email to