This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f942d19788 [TVMC] Fix error while compile paddle model with tvmc 
(#11730)
f942d19788 is described below

commit f942d197889ee93fc112ca346ca8366d29933fac
Author: Jason <[email protected]>
AuthorDate: Thu Jun 16 01:02:04 2022 +0800

    [TVMC] Fix error while compile paddle model with tvmc (#11730)
    
    The tvmc command throws an error when the model path passed to it does
    not exist. A PaddlePaddle model, however, consists of two files,
    model_name.pdmodel and model_name.pdiparams, and previously only their
    common prefix (e.g. inference_model/model_name) was passed, which is not
    an existing file.
    
    This PR is the same as https://github.com/apache/tvm/pull/11108.
    Since the original PR has not been updated for a long time, I am sending
    this new PR.
---
 python/tvm/driver/tvmc/frontends.py           | 14 ++++++++++--
 tests/python/driver/tvmc/conftest.py          |  2 +-
 tests/python/driver/tvmc/test_command_line.py | 33 +++++++++++++++++++++++++++
 3 files changed, 46 insertions(+), 3 deletions(-)

diff --git a/python/tvm/driver/tvmc/frontends.py 
b/python/tvm/driver/tvmc/frontends.py
index a3222782c6..cfe5a4ac7b 100644
--- a/python/tvm/driver/tvmc/frontends.py
+++ b/python/tvm/driver/tvmc/frontends.py
@@ -21,6 +21,7 @@ Frontend classes do lazy-loading of modules on purpose, to 
reduce time spent on
 loading the tool.
 """
 import logging
+import os
 import sys
 import importlib
 from abc import ABC
@@ -268,7 +269,7 @@ class PaddleFrontend(Frontend):
 
     @staticmethod
     def suffixes():
-        return ["pdmodel", "pdiparams"]
+        return ["pdmodel"]
 
     def load(self, path, shape_dict=None, **kwargs):
         # pylint: disable=C0415
@@ -277,9 +278,18 @@ class PaddleFrontend(Frontend):
         paddle.enable_static()
         paddle.disable_signal_handler()
 
+        if not os.path.exists(path):
+            raise TVMCException("File {} is not exist.".format(path))
+        if not path.endswith(".pdmodel"):
+            raise TVMCException("Path of model file should be endwith suffixes 
'.pdmodel'.")
+        prefix = "".join(path.strip().split(".")[:-1])
+        params_file_path = prefix + ".pdiparams"
+        if not os.path.exists(params_file_path):
+            raise TVMCException("File {} is not 
exist.".format(params_file_path))
+
         # pylint: disable=E1101
         exe = paddle.static.Executor(paddle.CPUPlace())
-        prog, _, _ = paddle.static.load_inference_model(path, exe)
+        prog, _, _ = paddle.static.load_inference_model(prefix, exe)
 
         return relay.frontend.from_paddle(prog, shape_dict=shape_dict, 
**kwargs)
 
diff --git a/tests/python/driver/tvmc/conftest.py 
b/tests/python/driver/tvmc/conftest.py
index efce13e38c..fcf079620e 100644
--- a/tests/python/driver/tvmc/conftest.py
+++ b/tests/python/driver/tvmc/conftest.py
@@ -160,7 +160,7 @@ def paddle_resnet50(tmpdir_factory):
     model_url = "paddle_resnet50.tar"
     model_file = download_and_untar(
         "{}/{}".format(base_url, model_url),
-        "paddle_resnet50/model",
+        "paddle_resnet50/model.pdmodel",
         temp_dir=tmpdir_factory.mktemp("data"),
     )
     return model_file
diff --git a/tests/python/driver/tvmc/test_command_line.py 
b/tests/python/driver/tvmc/test_command_line.py
index 5b15492aa4..0fddb7073f 100644
--- a/tests/python/driver/tvmc/test_command_line.py
+++ b/tests/python/driver/tvmc/test_command_line.py
@@ -20,8 +20,10 @@ import pytest
 import shutil
 
 from pytest_lazyfixture import lazy_fixture
+from unittest import mock
 from tvm.driver.tvmc.main import _main
 from tvm.driver.tvmc.model import TVMCException
+from tvm.driver.tvmc import compiler
 
 
 @pytest.mark.skipif(
@@ -155,3 +157,34 @@ def test_tvmc_tune_file_check(capsys, invalid_input):
     )
     on_assert_error = f"'tvmc tune' failed to check invalid FILE: 
{invalid_input}"
     assert captured.err == expected_err, on_assert_error
+
+
[email protected]
+def paddle_model(paddle_resnet50):
+    # If we can't import "paddle" module, skip testing paddle as the input 
model.
+    if pytest.importorskip("paddle", reason="'paddle' module not installed"):
+        return paddle_resnet50
+
+
[email protected](
+    "model",
+    [
+        lazy_fixture("paddle_model"),
+    ],
+)
+# compile_model() can take too long and is tested elsewhere, hence it's mocked 
below
[email protected](compiler, "compile_model")
+# @mock.patch.object(compiler, "compile_model")
+def test_tvmc_compile_input_model(mock_compile_model, tmpdir_factory, model):
+
+    output_dir = tmpdir_factory.mktemp("output")
+    output_file = output_dir / "model.tar"
+
+    compile_cmd = (
+        f"tvmc compile --target 'llvm' {model} --model-format paddle --output 
{output_file}"
+    )
+    run_arg = compile_cmd.split(" ")[1:]
+
+    _main(run_arg)
+
+    mock_compile_model.assert_called_once()

Reply via email to