jwfromm commented on a change in pull request #10722:
URL: https://github.com/apache/tvm/pull/10722#discussion_r839009864
##########
File path: python/tvm/driver/tvmc/model.py
##########
@@ -281,7 +320,9 @@ def export_package(
if output_format == "mlf" and cross:
raise TVMCException("Specifying the MLF output and a cross
compiler is not supported.")
- if output_format in ["so", "tar"]:
+ if use_vm:
Review comment:
Can we check the type of `executor_factory` instead of having a `use_vm`
input for this function?
##########
File path: python/tvm/driver/tvmc/compiler.py
##########
@@ -319,19 +336,18 @@ def compile_model(
dump_code = [dump_code]
dumps = {}
for source_type in dump_code:
- lib = graph_module.get_lib()
+ if use_vm:
+ _, lib = graph_module.save()
Review comment:
I think we can instead do `lib = graph_module.lib()` to be a little more
efficient.
##########
File path: tests/python/driver/tvmc/test_runner.py
##########
@@ -79,11 +79,47 @@ def test_run_tflite_module__with_profile__valid_input(
pytest.importorskip("tflite")
inputs = np.load(imagenet_cat)
+ input_dict = {"input": inputs["input"].astype("uint8")}
tflite_compiled_model = tflite_compile_model(tflite_mobilenet_v1_1_quant)
result = tvmc.run(
tflite_compiled_model,
- inputs=inputs,
+ inputs=input_dict,
+ hostname=None,
+ device="cpu",
+ profile=True,
+ )
+
+ # collect the top 5 results
+ top_5_results = get_top_results(result, 5)
+ top_5_ids = top_5_results[0]
+
+ # IDs were collected from this reference:
+ # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/
+ # java/demo/app/src/main/assets/labels_mobilenet_quant_v1_224.txt
+ tiger_cat_mobilenet_id = 283
+
+ assert (
+ tiger_cat_mobilenet_id in top_5_ids
+ ), "tiger cat is expected in the top-5 for mobilenet v1"
+ assert type(result.outputs) is dict
+ assert type(result.times) is BenchmarkResult
+ assert "output_0" in result.outputs.keys()
+
+
+def test_run_tflite_module__with_profile_vm__valid_input(
+ tflite_mobilenet_v1_1_quant, tflite_compile_model, imagenet_cat
+):
+ # some CI environments won't offer TFLite, so skip in case it is not present
+ pytest.importorskip("tflite")
+
+ inputs = np.load(imagenet_cat)
+ input_dict = {"input": inputs["input"].astype("uint8")}
+
+ tflite_compiled_model = tflite_compile_model(tflite_mobilenet_v1_1_quant,
use_vm=True)
Review comment:
If the only difference between these two tests is `use_vm=True`, we
should make a new function called
`verify_run_tflite_module__with_profile__valid_input` that takes a `use_vm`
argument, and then just have something like this:
```
def test_run_tflite_module__with_profile__valid_input(
tflite_mobilenet_v1_1_quant, tflite_compile_model, imagenet_cat):
verify_run_tflite_module__with_profile__valid_input(tflite_mobilenet_v1_1_quant,
tflite_compile_model, imagenet_cat, use_vm=False)
verify_run_tflite_module__with_profile__valid_input(tflite_mobilenet_v1_1_quant,
tflite_compile_model, imagenet_cat, use_vm=True)
```
##########
File path: tests/python/driver/tvmc/test_model.py
##########
@@ -45,6 +46,31 @@ def test_tvmc_workflow(keras_simple):
assert "output_0" in result.outputs.keys()
+def test_tvmc_workflow_use_vm(keras_simple):
+ pytest.importorskip("tensorflow")
+ import tensorflow as tf
+
+ # Reset so the input name remains consistent across unit test runs
+ tf.keras.backend.clear_session()
+
+ tvmc_model = tvmc.load(keras_simple)
+ tuning_records = tvmc.tune(tvmc_model, target="llvm",
enable_autoscheduler=True, trials=2)
+ tvmc_package = tvmc.compile(
+ tvmc_model, tuning_records=tuning_records, target="llvm", use_vm=True
+ )
+
+ input_dict = {"input_1": np.random.uniform(size=(1, 32, 32,
3)).astype("float32")}
+ result = tvmc.run(tvmc_package, device="cpu", end_to_end=True,
inputs=input_dict)
+
+ assert type(tvmc_model) is TVMCModel
+ assert type(tvmc_package) is TVMCPackage
+ assert type(result) is TVMCResult
+ assert path.exists(tuning_records)
+ assert type(result.outputs) is dict
+ assert type(result.times) is BenchmarkResult
+ assert "output_0" in result.outputs.keys()
+
+
Review comment:
I think we should also modify `test_save_load_model` to make sure that
saving and loading with `use_vm=True` works as expected.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]