This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new f942d19788 [TVMC] Fix error while compile paddle model with tvmc
(#11730)
f942d19788 is described below
commit f942d197889ee93fc112ca346ca8366d29933fac
Author: Jason <[email protected]>
AuthorDate: Thu Jun 16 01:02:04 2022 +0800
[TVMC] Fix error while compile paddle model with tvmc (#11730)
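The tvmc command throws an error when the model path passed to it does not
exist. However, a PaddlePaddle model consists of two files, model_name.pdmodel
and model_name.pdiparams, and previously only the prefix (e.g.
inference_model/model_name) was passed. With this change, tvmc takes the path
to the .pdmodel file and derives the .pdiparams path from it.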
This PR is the same as https://github.com/apache/tvm/pull/11108.
Since the original PR had not been updated for a long time, I am sending this new PR.
---
python/tvm/driver/tvmc/frontends.py | 14 ++++++++++--
tests/python/driver/tvmc/conftest.py | 2 +-
tests/python/driver/tvmc/test_command_line.py | 33 +++++++++++++++++++++++++++
3 files changed, 46 insertions(+), 3 deletions(-)
diff --git a/python/tvm/driver/tvmc/frontends.py
b/python/tvm/driver/tvmc/frontends.py
index a3222782c6..cfe5a4ac7b 100644
--- a/python/tvm/driver/tvmc/frontends.py
+++ b/python/tvm/driver/tvmc/frontends.py
@@ -21,6 +21,7 @@ Frontend classes do lazy-loading of modules on purpose, to
reduce time spent on
loading the tool.
"""
import logging
+import os
import sys
import importlib
from abc import ABC
@@ -268,7 +269,7 @@ class PaddleFrontend(Frontend):
@staticmethod
def suffixes():
- return ["pdmodel", "pdiparams"]
+ return ["pdmodel"]
def load(self, path, shape_dict=None, **kwargs):
# pylint: disable=C0415
@@ -277,9 +278,18 @@ class PaddleFrontend(Frontend):
paddle.enable_static()
paddle.disable_signal_handler()
+ if not os.path.exists(path):
+ raise TVMCException("File {} is not exist.".format(path))
+ if not path.endswith(".pdmodel"):
+ raise TVMCException("Path of model file should be endwith suffixes
'.pdmodel'.")
+ prefix = "".join(path.strip().split(".")[:-1])
+ params_file_path = prefix + ".pdiparams"
+ if not os.path.exists(params_file_path):
+ raise TVMCException("File {} is not
exist.".format(params_file_path))
+
# pylint: disable=E1101
exe = paddle.static.Executor(paddle.CPUPlace())
- prog, _, _ = paddle.static.load_inference_model(path, exe)
+ prog, _, _ = paddle.static.load_inference_model(prefix, exe)
return relay.frontend.from_paddle(prog, shape_dict=shape_dict,
**kwargs)
diff --git a/tests/python/driver/tvmc/conftest.py
b/tests/python/driver/tvmc/conftest.py
index efce13e38c..fcf079620e 100644
--- a/tests/python/driver/tvmc/conftest.py
+++ b/tests/python/driver/tvmc/conftest.py
@@ -160,7 +160,7 @@ def paddle_resnet50(tmpdir_factory):
model_url = "paddle_resnet50.tar"
model_file = download_and_untar(
"{}/{}".format(base_url, model_url),
- "paddle_resnet50/model",
+ "paddle_resnet50/model.pdmodel",
temp_dir=tmpdir_factory.mktemp("data"),
)
return model_file
diff --git a/tests/python/driver/tvmc/test_command_line.py
b/tests/python/driver/tvmc/test_command_line.py
index 5b15492aa4..0fddb7073f 100644
--- a/tests/python/driver/tvmc/test_command_line.py
+++ b/tests/python/driver/tvmc/test_command_line.py
@@ -20,8 +20,10 @@ import pytest
import shutil
from pytest_lazyfixture import lazy_fixture
+from unittest import mock
from tvm.driver.tvmc.main import _main
from tvm.driver.tvmc.model import TVMCException
+from tvm.driver.tvmc import compiler
@pytest.mark.skipif(
@@ -155,3 +157,34 @@ def test_tvmc_tune_file_check(capsys, invalid_input):
)
on_assert_error = f"'tvmc tune' failed to check invalid FILE:
{invalid_input}"
assert captured.err == expected_err, on_assert_error
+
+
+@pytest.fixture
+def paddle_model(paddle_resnet50):
+ # If we can't import "paddle" module, skip testing paddle as the input
model.
+ if pytest.importorskip("paddle", reason="'paddle' module not installed"):
+ return paddle_resnet50
+
+
+@pytest.mark.parametrize(
+ "model",
+ [
+ lazy_fixture("paddle_model"),
+ ],
+)
+# compile_model() can take too long and is tested elsewhere, hence it's mocked
below
+@mock.patch.object(compiler, "compile_model")
+# @mock.patch.object(compiler, "compile_model")
+def test_tvmc_compile_input_model(mock_compile_model, tmpdir_factory, model):
+
+ output_dir = tmpdir_factory.mktemp("output")
+ output_file = output_dir / "model.tar"
+
+ compile_cmd = (
+ f"tvmc compile --target 'llvm' {model} --model-format paddle --output
{output_file}"
+ )
+ run_arg = compile_cmd.split(" ")[1:]
+
+ _main(run_arg)
+
+ mock_compile_model.assert_called_once()