This is an automated email from the ASF dual-hosted git repository.

jwfromm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8775a80  [TVMC] Support compiling and running with VM  (#10722)
8775a80 is described below

commit 8775a805d17d9e46b1cdcadd00f22d3471270ef1
Author: Margaret Qian <[email protected]>
AuthorDate: Thu Mar 31 12:30:09 2022 -0400

    [TVMC] Support compiling and running with VM  (#10722)
    
    * introduce vm compile path
    
    * support vm in tvmc
    
    * cleanup + lint
    
    * add profiler + simplify vm case in tvmcpackage
    
    * address comments + parametrize tests
    
    Co-authored-by: Margaret Qian <[email protected]>
---
 python/tvm/driver/tvmc/compiler.py        |  77 +++++++++++++++---
 python/tvm/driver/tvmc/model.py           |  94 ++++++++++++++++++----
 python/tvm/driver/tvmc/runner.py          | 127 +++++++++++++++++++-----------
 tests/python/driver/tvmc/test_compiler.py |  50 ++++++------
 tests/python/driver/tvmc/test_model.py    |  21 +++--
 tests/python/driver/tvmc/test_runner.py   |   8 +-
 6 files changed, 274 insertions(+), 103 deletions(-)
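
As a quick orientation before the diff: the tests below drive the new flag
through the Python API. A minimal sketch of the workflow, where the model
file name and the input tensor name/shape are placeholders rather than part
of this commit:

    import numpy as np
    from tvm.driver import tvmc

    # "my_model.tflite" and the "input" tensor name/shape are assumptions.
    tvmc_model = tvmc.load("my_model.tflite")
    # use_vm=True selects relay.vm.compile instead of relay.build.
    package = tvmc.compile(tvmc_model, target="llvm", use_vm=True)
    # The VM run path requires inputs to be passed explicitly as a dict.
    inputs = {"input": np.zeros((1, 224, 224, 3), dtype="float32")}
    result = tvmc.run(package, device="cpu", inputs=inputs)
    print(result.outputs["output_0"])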

diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py
index 66c7b35..8f24dd4 100644
--- a/python/tvm/driver/tvmc/compiler.py
+++ b/python/tvm/driver/tvmc/compiler.py
@@ -201,6 +201,7 @@ def compile_model(
     disabled_pass: Optional[str] = None,
     pass_context_configs: Optional[List[str]] = None,
     additional_target_options: Optional[Dict[str, Dict[str, Any]]] = None,
+    use_vm: bool = False,
 ):
     """Compile a model from a supported framework into a TVM module.
 
@@ -248,7 +249,8 @@ def compile_model(
         PassContext.
     additional_target_options: Optional[Dict[str, Dict[str, Any]]]
         Additional target options in a dictionary to combine with initial Target arguments
-
+    use_vm: bool
+        Whether to use the VM to compile the model as opposed to the graph executor
 
     Returns
     -------
@@ -291,8 +293,13 @@ def compile_model(
                     opt_level=opt_level, config=config, disabled_pass=disabled_pass
                 ):
                     logger.debug("building relay graph with autoscheduler")
-                    graph_module = relay.build(
-                        mod, target=tvm_target, executor=executor, runtime=runtime, params=params
+                    graph_module = build(
+                        mod,
+                        tvm_target=tvm_target,
+                        executor=executor,
+                        runtime=runtime,
+                        params=params,
+                        use_vm=use_vm,
                     )
         else:
             with autotvm.apply_history_best(tuning_records):
@@ -300,16 +307,26 @@ def compile_model(
                     opt_level=opt_level, config=config, disabled_pass=disabled_pass
                 ):
                     logger.debug("building relay graph with tuning records")
-                    graph_module = relay.build(
-                        mod, target=tvm_target, executor=executor, runtime=runtime, params=params
+                    graph_module = build(
+                        mod,
+                        tvm_target=tvm_target,
+                        executor=executor,
+                        runtime=runtime,
+                        params=params,
+                        use_vm=use_vm,
                     )
     else:
         with tvm.transform.PassContext(
             opt_level=opt_level, config=config, disabled_pass=disabled_pass
         ):
             logger.debug("building relay graph (no tuning records provided)")
-            graph_module = relay.build(
-                mod, target=tvm_target, executor=executor, runtime=runtime, params=params
+            graph_module = build(
+                mod,
+                tvm_target=tvm_target,
+                executor=executor,
+                runtime=runtime,
+                params=params,
+                use_vm=use_vm,
             )
 
     # Generate output dump files with sources
@@ -319,7 +336,10 @@ def compile_model(
         dump_code = [dump_code]
     dumps = {}
     for source_type in dump_code:
-        lib = graph_module.get_lib()
+        if use_vm:
+            lib = graph_module.lib
+        else:
+            lib = graph_module.get_lib()
         # TODO lib.get_source call has inconsistent behavior for unsupported
         #      formats (@leandron).
         source = str(mod) if source_type == "relay" else lib.get_source(source_type)
@@ -327,11 +347,7 @@ def compile_model(
 
     # Create a new tvmc model package object from the graph definition.
     package_path = tvmc_model.export_package(
-        graph_module,
-        package_path,
-        cross,
-        cross_options,
-        output_format,
+        graph_module, package_path, cross, cross_options, output_format
     )
 
     # Write dumps to file.
@@ -341,6 +357,41 @@ def compile_model(
     return TVMCPackage(package_path)
 
 
+def build(
+    mod: tvm.IRModule,
+    tvm_target: str,
+    executor: Executor,
+    runtime: Runtime,
+    params: Dict[str, tvm.nd.NDArray],
+    use_vm: bool,
+):
+    """
+    Builds the model with the provided executor.
+
+    Parameters
+    ----------
+    mod : tvm.IRModule
+        The relay module corresponding to this model.
+    tvm_target : str
+        The target for which to compile. Can be a plain string or
+        a path.
+    executor : Executor
+        The graph executor with which to build the model if use_vm is not True
+    runtime : Runtime
+        The runtime configuration.
+    params : dict
+        A parameter dictionary for the model.
+    use_vm: bool
+        Whether to use the VM to compile the model as opposed to the graph executor
+
+    """
+    if use_vm:
+        logger.debug("building with vm compile")
+        return relay.vm.compile(mod, target=tvm_target, params=params)
+    logger.debug("building with relay build")
+    return relay.build(mod, target=tvm_target, executor=executor, runtime=runtime, params=params)
+
+
 def save_dumps(module_name: str, dumps: Dict[str, str], dump_root: str = "."):
     """
     Serialize dump files to the disk.
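
The two build paths above return different artifact types, which is why the
dump-code handling branches on use_vm. A small self-contained sketch of the
distinction (the toy identity network is an assumption, not from this patch):

    import tvm
    from tvm import relay

    # A trivial network just to have something to compile.
    x = relay.var("x", shape=(1, 4), dtype="float32")
    mod = tvm.IRModule.from_expr(relay.Function([x], relay.nn.relu(x)))

    # relay.vm.compile returns a vm.Executable, which exposes .lib directly.
    vm_exec = relay.vm.compile(mod, target="llvm")
    lib_vm = vm_exec.lib
    # relay.build returns a GraphExecutorFactoryModule, which uses a getter.
    factory = relay.build(mod, target="llvm")
    lib_classic = factory.get_lib()
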
diff --git a/python/tvm/driver/tvmc/model.py b/python/tvm/driver/tvmc/model.py
index 9a2617f..93ca27c 100644
--- a/python/tvm/driver/tvmc/model.py
+++ b/python/tvm/driver/tvmc/model.py
@@ -57,6 +57,8 @@ from tvm.contrib import utils
 from tvm.driver.tvmc import TVMCException
 from tvm.relay.backend.executor_factory import GraphExecutorFactoryModule
 from tvm.runtime.module import BenchmarkResult
+from tvm.runtime.vm import Executable
+
 
 try:
     from tvm.micro import export_model_library_format
@@ -182,6 +184,42 @@ class TVMCModel(object):
         """
         return self._tmp_dir.relpath("model_package.tar")
 
+    def export_vm_format(
+        self,
+        vm_exec: Executable,
+        package_path: Optional[str] = None,
+        lib_format: str = "so",
+    ):
+        """Save this TVMCModel compiled via vm to file.
+        Parameters
+        ----------
+        vm_exec : vm.Executable
+            The VM Executable containing the compiled artifacts needed to run this model.
+        package_path : str, None
+            Where the model should be saved. Note that it will be packaged as a .tar file.
+            If not provided, the package will be saved to a generically named file in tmp.
+        lib_format : str
+            How to export the module's function library. Must be one of "so" or "tar".
+
+        Returns
+        -------
+        package_path : str
+            The path that the package was saved to.
+        """
+        lib_name = "lib." + lib_format
+        temp = self._tmp_dir
+        if package_path is None:
+            package_path = self.default_package_path()
+
+        path_lib = temp.relpath(lib_name)
+        vm_exec.mod.export_library(path_lib)
+        self.lib_path = path_lib
+        # Package up all the temp files into a tar file.
+        with tarfile.open(package_path, "w") as tar:
+            tar.add(path_lib, lib_name)
+
+        return package_path
+
     def export_classic_format(
         self,
         executor_factory: GraphExecutorFactoryModule,
@@ -248,7 +286,7 @@ class TVMCModel(object):
 
     def export_package(
         self,
-        executor_factory: GraphExecutorFactoryModule,
+        executor_factory: Union[GraphExecutorFactoryModule, Executable],
         package_path: Optional[str] = None,
         cross: Optional[Union[str, Callable]] = None,
         cross_options: Optional[str] = None,
@@ -281,7 +319,9 @@ class TVMCModel(object):
         if output_format == "mlf" and cross:
             raise TVMCException("Specifying the MLF output and a cross 
compiler is not supported.")
 
-        if output_format in ["so", "tar"]:
+        if isinstance(executor_factory, Executable):
+            package_path = self.export_vm_format(executor_factory, package_path, output_format)
+        elif output_format in ["so", "tar"]:
             package_path = self.export_classic_format(
                 executor_factory, package_path, cross, cross_options, output_format
             )
@@ -314,9 +354,16 @@ class TVMCPackage(object):
 
     project_dir : Path, str
         If given and loading a MLF file, the path to the project directory that contains the file.
+
+    use_vm : bool
+        Whether the graph module was compiled with vm or not.
     """
 
-    def __init__(self, package_path: str, project_dir: Optional[Union[Path, str]] = None):
+    def __init__(
+        self,
+        package_path: str,
+        project_dir: Optional[Union[Path, str]] = None,
+    ):
         self._tmp_dir = utils.tempdir()
         self.package_path = package_path
         self.import_package(self.package_path)
@@ -351,23 +398,40 @@ class TVMCPackage(object):
             self.type = "mlf"
         else:
             # Classic format
-            lib_name_so = "mod.so"
-            lib_name_tar = "mod.tar"
-            if os.path.exists(temp.relpath(lib_name_so)):
-                self.lib_name = lib_name_so
-            elif os.path.exists(temp.relpath(lib_name_tar)):
-                self.lib_name = lib_name_tar
+            classic_lib_name_so = "mod.so"
+            classic_lib_name_tar = "mod.tar"
+
+            # VM format
+            vm_lib_name_so = "lib.so"
+            vm_lib_name_tar = "lib.tar"
+
+            if os.path.exists(temp.relpath(classic_lib_name_so)):
+                self.lib_name = classic_lib_name_so
+                self.type = "classic"
+            elif os.path.exists(temp.relpath(classic_lib_name_tar)):
+                self.lib_name = classic_lib_name_tar
+                self.type = "classic"
+            elif os.path.exists(temp.relpath(vm_lib_name_so)):
+                self.lib_name = vm_lib_name_so
+                self.type = "vm"
+            elif os.path.exists(temp.relpath(vm_lib_name_tar)):
+                self.lib_name = vm_lib_name_tar
+                self.type = "vm"
             else:
                 raise TVMCException("Couldn't find exported library in the 
package.")
-            self.lib_path = temp.relpath(self.lib_name)
 
-            graph = temp.relpath("mod.json")
-            params = temp.relpath("mod.params")
+            self.lib_path = temp.relpath(self.lib_name)
 
-            self.type = "classic"
+            graph, params = None, None
+            if self.type == "classic":
+                graph = temp.relpath("mod.json")
+                params = temp.relpath("mod.params")
 
-        with open(params, "rb") as param_file:
-            self.params = bytearray(param_file.read())
+        if params is not None:
+            with open(params, "rb") as param_file:
+                self.params = bytearray(param_file.read())
+        else:
+            self.params = None
 
         if graph is not None:
             with open(graph) as graph_file:
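
After this change a package tarball can hold either layout: the classic
format ships mod.so (or mod.tar) together with mod.json and mod.params,
while the VM format ships only lib.so (or lib.tar). A quick, hedged way to
inspect which kind a given package is ("module.tar" is a placeholder path):

    import tarfile

    with tarfile.open("module.tar") as tar:
        names = set(tar.getnames())
    # VM packages carry the library alone; classic packages also carry the
    # graph JSON and the params blob.
    kind = "vm" if names & {"lib.so", "lib.tar"} else "classic"
    print(kind, sorted(names))
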
diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py
index 8db1272..1b6d823 100644
--- a/python/tvm/driver/tvmc/runner.py
+++ b/python/tvm/driver/tvmc/runner.py
@@ -28,9 +28,11 @@ import numpy as np
 
 import tvm
 from tvm import rpc
+from tvm.runtime import vm
 from tvm.autotvm.measure import request_remote
 from tvm.contrib import graph_executor as executor
 from tvm.contrib.debugger import debug_executor
+from tvm.runtime import profiler_vm
 from . import TVMCException
 from .arguments import TVMCSuppressedArgumentParser
 from .project import (
@@ -530,58 +532,93 @@ def run_module(
             assert device == "cpu"
             dev = session.cpu()
 
-        # TODO(gromero): Adjust for micro targets.
-        if profile:
-            logger.debug("Creating executor with profiling enabled.")
-            module = debug_executor.create(tvmc_package.graph, lib, dev, dump_root="./prof")
+        if tvmc_package.type == "vm":
+            assert inputs is not None, "vm runner requires inputs to be provided as a dict"
+
+            input_tensor = {}
+            for e, i in inputs.items():
+                input_tensor[e] = tvm.nd.array(i, dev)
+
+            if profile:
+                logger.debug("Creating vm with profile enabled.")
+                exe = profiler_vm.VirtualMachineProfiler(lib, dev)
+                res = exe.profile(**input_tensor, func_name="main")
+                # This print is intentional
+                print(res)
+            else:
+                exe = vm.VirtualMachine(lib, dev)
+
+            exe_outputs = exe.invoke("main", **input_tensor)
+            times = exe.benchmark(
+                dev,
+                **input_tensor,
+                func_name="main",
+                repeat=repeat,
+                number=number,
+                end_to_end=end_to_end,
+            )
+
+            # Special handling if the output only has a single value
+            if not isinstance(exe_outputs, list):
+                exe_outputs = [exe_outputs]
+
+            outputs = {}
+            for i, val in enumerate(exe_outputs):
+                output_name = "output_{}".format(i)
+                outputs[output_name] = val.numpy()
         else:
-            if device == "micro":
-                logger.debug("Creating executor (micro) with profiling 
disabled.")
-                module = tvm.micro.create_local_graph_executor(tvmc_package.graph, lib, dev)
+            # TODO(gromero): Adjust for micro targets.
+            if profile:
+                logger.debug("Creating runtime with profiling enabled.")
+                module = debug_executor.create(tvmc_package.graph, lib, dev, dump_root="./prof")
             else:
-                logger.debug("Creating executor with profiling disabled.")
-                module = executor.create(tvmc_package.graph, lib, dev)
+                if device == "micro":
+                    logger.debug("Creating runtime (micro) with profiling 
disabled.")
+                    module = tvm.micro.create_local_graph_executor(tvmc_package.graph, lib, dev)
+                else:
+                    logger.debug("Creating runtime with profiling disabled.")
+                    module = executor.create(tvmc_package.graph, lib, dev)
 
-        logger.debug("Loading params into the runtime module.")
-        module.load_params(tvmc_package.params)
+            logger.debug("Loading params into the runtime module.")
+            module.load_params(tvmc_package.params)
 
-        logger.debug("Collecting graph input shape and type:")
-        shape_dict, dtype_dict = module.get_input_info()
-        logger.debug("Graph input shape: %s", shape_dict)
-        logger.debug("Graph input type: %s", dtype_dict)
+            logger.debug("Collecting graph input shape and type:")
+            shape_dict, dtype_dict = module.get_input_info()
+            logger.debug("Graph input shape: %s", shape_dict)
+            logger.debug("Graph input type: %s", dtype_dict)
 
-        inputs_dict = make_inputs_dict(shape_dict, dtype_dict, inputs, fill_mode)
+            inputs_dict = make_inputs_dict(shape_dict, dtype_dict, inputs, fill_mode)
 
-        logger.debug("Setting inputs to the module.")
-        module.set_input(**inputs_dict)
+            logger.debug("Setting inputs to the module.")
+            module.set_input(**inputs_dict)
 
-        # Run must be called explicitly if profiling
-        if profile:
-            logger.info("Running the module with profiling enabled.")
-            report = module.profile()
-            # This print is intentional
-            print(report)
+            # Run must be called explicitly if profiling
+            if profile:
+                logger.info("Running the module with profiling enabled.")
+                report = module.profile()
+                # This print is intentional
+                print(report)
 
-        if device == "micro":
-            # TODO(gromero): Fix time_evaluator() for micro targets. Once it's
-            # fixed module.benchmark() can be used instead and this if/else can
-            # be removed.
-            module.run()
-            times = []
-        else:
-            # Call the benchmarking function of the executor.
-            # Optionally measure e2e data transfers from the
-            # CPU to device memory overheads (e.g. PCIE
-            # overheads if the device is a discrete GPU).
-            if end_to_end:
-                dev = session.cpu()
-            times = module.benchmark(dev, number=number, repeat=repeat, end_to_end=end_to_end)
-
-        logger.debug("Collecting the output tensors.")
-        num_outputs = module.get_num_outputs()
-        outputs = {}
-        for i in range(num_outputs):
-            output_name = "output_{}".format(i)
-            outputs[output_name] = module.get_output(i).numpy()
+            if device == "micro":
+                # TODO(gromero): Fix time_evaluator() for micro targets. Once it's
+                # fixed module.benchmark() can be used instead and this if/else can
+                # be removed.
+                module.run()
+                times = []
+            else:
+                # Call the benchmarking function of the executor.
+                # Optionally measure e2e data transfers from the
+                # CPU to device memory overheads (e.g. PCIE
+                # overheads if the device is a discrete GPU).
+                if end_to_end:
+                    dev = session.cpu()
+                times = module.benchmark(dev, number=number, repeat=repeat, end_to_end=end_to_end)
+
+            logger.debug("Collecting the output tensors.")
+            num_outputs = module.get_num_outputs()
+            outputs = {}
+            for i in range(num_outputs):
+                output_name = "output_{}".format(i)
+                outputs[output_name] = module.get_output(i).numpy()
 
         return TVMCResult(outputs, times)
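
On the run side, a VM package is wrapped in vm.VirtualMachine (or
profiler_vm.VirtualMachineProfiler when profiling), "main" is invoked with
the user-supplied tensors, and the results are keyed output_0, output_1,
and so on. A sketch of driving that path through tvmc.run, where the
package path, input name, and shape are assumptions:

    import numpy as np
    from tvm.driver import tvmc
    from tvm.driver.tvmc.model import TVMCPackage

    package = TVMCPackage("module.tar")  # assumed VM-format package
    inputs = {"input_1": np.random.uniform(size=(1, 32, 32, 3)).astype("float32")}
    result = tvmc.run(package, device="cpu", inputs=inputs)
    print(result.outputs["output_0"])
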
diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py
index 4b21f4e..bc836de 100644
--- a/tests/python/driver/tvmc/test_compiler.py
+++ b/tests/python/driver/tvmc/test_compiler.py
@@ -49,21 +49,32 @@ def test_save_dumps(tmpdir_factory):
 # End to end tests for compilation
 
 
-def verify_compile_tflite_module(model, shape_dict=None):
-    pytest.importorskip("tflite")
-    tvmc_model = tvmc.load(model, shape_dict=shape_dict)
-    tvmc_package = tvmc.compile(tvmc_model, target="llvm", dump_code="ll", desired_layout="NCHW")
-    dumps_path = tvmc_package.package_path + ".ll"
-
+def verify_tvmc_package(tvmc_package, dumps_path, use_vm=False):
     # check for output types
     assert type(tvmc_package) is TVMCPackage
-    assert type(tvmc_package.graph) is str
-    assert type(tvmc_package.lib_path) is str
-    assert type(tvmc_package.params) is bytearray
     assert os.path.exists(dumps_path)
+    assert type(tvmc_package.lib_path) is str
+
+    if use_vm:
+        assert tvmc_package.graph is None
+        assert tvmc_package.params is None
+    else:
+        assert type(tvmc_package.graph) is str
+        assert type(tvmc_package.params) is bytearray
+
+
+def verify_compile_tflite_module(model, shape_dict=None, use_vm=False):
+    pytest.importorskip("tflite")
+    tvmc_model = tvmc.load(model, shape_dict=shape_dict)
+    tvmc_package = tvmc.compile(
+        tvmc_model, target="llvm", dump_code="ll", desired_layout="NCHW", use_vm=use_vm
+    )
+    dumps_path = tvmc_package.package_path + ".ll"
+    verify_tvmc_package(tvmc_package, dumps_path, use_vm=use_vm)
 
 
-def test_compile_tflite_module(tflite_mobilenet_v1_1_quant):
+@pytest.mark.parametrize("use_vm", [True, False])
+def test_compile_tflite_module(use_vm, tflite_mobilenet_v1_1_quant):
     # some CI environments won't offer tflite, so skip in case it is not present
     pytest.importorskip("tflite")
     # Check default compilation.
@@ -71,7 +82,7 @@ def test_compile_tflite_module(tflite_mobilenet_v1_1_quant):
     # Check with manual shape override
     shape_string = "input:[1,224,224,3]"
     shape_dict = tvmc.shape_parser.parse_shape_string(shape_string)
-    verify_compile_tflite_module(tflite_mobilenet_v1_1_quant, shape_dict)
+    verify_compile_tflite_module(tflite_mobilenet_v1_1_quant, shape_dict, use_vm=use_vm)
 
 
 # This test will be skipped if the AArch64 cross-compilation toolchain is not installed.
@@ -198,28 +209,23 @@ def test_cross_compile_options_aarch64_keras_module(keras_resnet50):
     assert os.path.exists(dumps_path)
 
 
-def verify_compile_onnx_module(model, shape_dict=None):
+def verify_compile_onnx_module(model, shape_dict=None, use_vm=False):
     # some CI environments won't offer onnx, so skip in case it is not present
     pytest.importorskip("onnx")
     tvmc_model = tvmc.load(model, shape_dict=shape_dict)
-    tvmc_package = tvmc.compile(tvmc_model, target="llvm", dump_code="ll")
+    tvmc_package = tvmc.compile(tvmc_model, target="llvm", dump_code="ll", use_vm=use_vm)
     dumps_path = tvmc_package.package_path + ".ll"
-
-    # check for output types
-    assert type(tvmc_package) is TVMCPackage
-    assert type(tvmc_package.graph) is str
-    assert type(tvmc_package.lib_path) is str
-    assert type(tvmc_package.params) is bytearray
-    assert os.path.exists(dumps_path)
+    verify_tvmc_package(tvmc_package, dumps_path, use_vm=use_vm)
 
 
-def test_compile_onnx_module(onnx_resnet50):
+@pytest.mark.parametrize("use_vm", [True, False])
+def test_compile_onnx_module(use_vm, onnx_resnet50):
     # Test default compilation
     verify_compile_onnx_module(onnx_resnet50)
     # Test with manual shape dict
     shape_string = "data:[1,3,200,200]"
     shape_dict = tvmc.shape_parser.parse_shape_string(shape_string)
-    verify_compile_onnx_module(onnx_resnet50, shape_dict)
+    verify_compile_onnx_module(onnx_resnet50, shape_dict, use_vm=use_vm)
 
 
 # This test will be skipped if the AArch64 cross-compilation toolchain is not installed.
diff --git a/tests/python/driver/tvmc/test_model.py b/tests/python/driver/tvmc/test_model.py
index 5fccfea..74c1c4d 100644
--- a/tests/python/driver/tvmc/test_model.py
+++ b/tests/python/driver/tvmc/test_model.py
@@ -17,6 +17,7 @@
 import platform
 import pytest
 import os
+import numpy as np
 
 from os import path
 
@@ -29,13 +30,22 @@ from tvm.runtime.module import BenchmarkResult
     platform.machine() == "aarch64",
     reason="Currently failing on AArch64 - see 
https://github.com/apache/tvm/issues/10673";,
 )
-def test_tvmc_workflow(keras_simple):
+@pytest.mark.parametrize("use_vm", [True, False])
+def test_tvmc_workflow(use_vm, keras_simple):
     pytest.importorskip("tensorflow")
+    import tensorflow as tf
+
+    # Reset so the input name remains consistent across unit test runs
+    tf.keras.backend.clear_session()
 
     tvmc_model = tvmc.load(keras_simple)
     tuning_records = tvmc.tune(tvmc_model, target="llvm", enable_autoscheduler=True, trials=2)
-    tvmc_package = tvmc.compile(tvmc_model, tuning_records=tuning_records, target="llvm")
-    result = tvmc.run(tvmc_package, device="cpu", end_to_end=True)
+    tvmc_package = tvmc.compile(
+        tvmc_model, tuning_records=tuning_records, target="llvm", use_vm=use_vm
+    )
+    input_dict = {"input_1": np.random.uniform(size=(1, 32, 32, 3)).astype("float32")}
+
+    result = tvmc.run(tvmc_package, device="cpu", end_to_end=True, inputs=input_dict)
     assert type(tvmc_model) is TVMCModel
     assert type(tvmc_package) is TVMCPackage
     assert type(result) is TVMCResult
@@ -45,7 +55,8 @@ def test_tvmc_workflow(keras_simple):
     assert "output_0" in result.outputs.keys()
 
 
-def test_save_load_model(keras_simple, tmpdir_factory):
+@pytest.mark.parametrize("use_vm", [True, False])
+def test_save_load_model(use_vm, keras_simple, tmpdir_factory):
     pytest.importorskip("onnx")
 
     tmpdir = tmpdir_factory.mktemp("data")
@@ -55,7 +66,7 @@ def test_save_load_model(keras_simple, tmpdir_factory):
     tvmc.tune(tvmc_model, target="llvm", trials=2)
 
     # Create package artifacts
-    tvmc.compile(tvmc_model, target="llvm")
+    tvmc.compile(tvmc_model, target="llvm", use_vm=use_vm)
 
     # Save the model to disk
     model_path = os.path.join(tmpdir, "saved_model.tar")
diff --git a/tests/python/driver/tvmc/test_runner.py b/tests/python/driver/tvmc/test_runner.py
index 30ce2c6..3f4ab11 100644
--- a/tests/python/driver/tvmc/test_runner.py
+++ b/tests/python/driver/tvmc/test_runner.py
@@ -72,18 +72,20 @@ def test_get_top_results_keep_results():
     assert len(sut[1]) == expected_number_of_results_per_line
 
 
+@pytest.mark.parametrize("use_vm", [True, False])
 def test_run_tflite_module__with_profile__valid_input(
-    tflite_mobilenet_v1_1_quant, tflite_compile_model, imagenet_cat
+    use_vm, tflite_mobilenet_v1_1_quant, tflite_compile_model, imagenet_cat
 ):
     # some CI environments won't offer TFLite, so skip in case it is not present
     pytest.importorskip("tflite")
 
     inputs = np.load(imagenet_cat)
+    input_dict = {"input": inputs["input"].astype("uint8")}
 
-    tflite_compiled_model = tflite_compile_model(tflite_mobilenet_v1_1_quant)
+    tflite_compiled_model = tflite_compile_model(tflite_mobilenet_v1_1_quant, use_vm=use_vm)
     result = tvmc.run(
         tflite_compiled_model,
-        inputs=inputs,
+        inputs=input_dict,
         hostname=None,
         device="cpu",
         profile=True,
