areusch commented on a change in pull request #9074:
URL: https://github.com/apache/tvm/pull/9074#discussion_r728244569
##########
File path: python/tvm/driver/tvmc/frontends.py
##########
@@ -88,10 +96,54 @@ def import_keras():
from tensorflow import keras
return tf, keras
+ except ImportError:
Review comment:
aside from the sys.stderr redirection above and the keras/tf dual-import below, i almost think these `import_` functions are similar enough that we should consolidate them into one and remove `create_import_error_string` (i.e. build the error message in that one place). what do you think about doing this? i could be persuaded either way, but it seems like `import_frontend_package("onnx")` could be implemented like:
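```
import importlib
import os
import sys
from types import ModuleType
from typing import Optional


def import_frontend_package(
    pkg_name: str, from_pkg_name: Optional[str] = None, hide_stderr: bool = False
) -> ModuleType:
    # e.g. pkg_name="keras", from_pkg_name="tensorflow" imports "tensorflow.keras"
    full_name = f"{from_pkg_name}.{pkg_name}" if from_pkg_name else pkg_name
    stderr = sys.stderr
    if hide_stderr:
        sys.stderr = open(os.devnull, "w")
    try:
        return importlib.import_module(full_name)
    except ImportError as err:
        # TVMCException is already available in frontends.py
        raise TVMCException(f"Error importing required frontend package {full_name}") from err
    finally:
        if hide_stderr:
            sys.stderr.close()
            sys.stderr = stderr
```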
another property i like about this is that it's very clear (once you know the pattern) exactly what's being passed to `importlib.import_module()`, whereas with the current structure those details could drift between e.g. `import_keras`, `import_onnx`, etc.
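One detail worth pinning down for `from_pkg_name`: `importlib.import_module` has no `from` argument (and `from` is a reserved word, so it can't be a parameter name anyway). Its second parameter, `package`, only anchors relative names. A quick check, with the stdlib's `os.path` standing in for a `from tensorflow import keras`-style import:

```python
import importlib
import os

# `package` only anchors *relative* names: ".path" is resolved against "os".
relative = importlib.import_module(".path", package="os")

# An absolute dotted name expresses `from os import path` directly.
absolute = importlib.import_module("os.path")

print(relative is absolute is os.path)  # True
```

So `from_pkg_name` is probably best turned into an absolute dotted name (or a leading-dot name plus `package=`) rather than passed through as-is.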
##########
File path: python/tvm/driver/tvmc/frontends.py
##########
@@ -88,10 +96,54 @@ def import_keras():
from tensorflow import keras
return tf, keras
+ except ImportError:
+ raise TVMCException(create_import_error_string("Tensorflow", "tensorflow"))
finally:
sys.stderr = stderr
+def import_onnx():
+ """Lazy import function for onnx"""
+ try:
+ # pylint: disable=C0415
Review comment:
i think (but am not 100% sure) that including this on its own line disables that check for the whole file. is it possible to move it onto the import line itself?
```suggestion
import onnx as _onnx # pylint: disable=C0415
```
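For reference, `C0415` is pylint's `import-outside-toplevel` check. As we understand pylint's message-control rules, a trailing pragma applies only to the line it sits on, while a pragma on its own line applies from there to the end of the enclosing block. A minimal sketch of the trailing form, using the stdlib `json` module as a stand-in for `onnx` (the function `encode_config` is hypothetical):

```python
def encode_config(config: dict) -> str:
    """Serialize a config dict, importing json only when actually needed."""
    # Trailing pragma: silences C0415 (import-outside-toplevel) for this line only.
    import json  # pylint: disable=import-outside-toplevel

    return json.dumps(config, sort_keys=True)


print(encode_config({"b": 2, "a": 1}))  # {"a": 1, "b": 2}
```

The symbolic name also makes the intent readable without looking up the code.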
##########
File path: tests/python/driver/tvmc/test_frontends.py
##########
@@ -211,3 +225,49 @@ def test_load_model___wrong_language__to_pytorch(tflite_mobilenet_v1_1_quant):
model_format="pytorch",
shape_dict={"input": [1, 3, 224, 224]},
)
+
+
+def test_import_keras_friendly_message(keras_resnet50, monkeypatch):
+ # some CI environments wont offer TFLite, so skip in case it is not present
Review comment:
can you fix the TFLite part of the comment? this test is about Keras, not TFLite.
##########
File path: python/tvm/driver/tvmc/frontends.py
##########
@@ -88,10 +96,54 @@ def import_keras():
from tensorflow import keras
return tf, keras
+ except ImportError:
+ raise TVMCException(create_import_error_string("Tensorflow", "tensorflow"))
finally:
sys.stderr = stderr
+def import_onnx():
+ """Lazy import function for onnx"""
+ try:
+ # pylint: disable=C0415
+ import onnx as _onnx
Review comment:
wondering if there is a reason to keep `as _onnx` (and the other `as` aliases) here? did removing them raise some linter issue?
##########
File path: python/tvm/driver/tvmc/frontends.py
##########
@@ -88,10 +96,54 @@ def import_keras():
from tensorflow import keras
return tf, keras
+ except ImportError:
+ raise TVMCException(create_import_error_string("Tensorflow", "tensorflow"))
finally:
sys.stderr = stderr
+def import_onnx():
+ """Lazy import function for onnx"""
+ try:
+ # pylint: disable=C0415
+ import onnx as _onnx
+ except ImportError:
+ raise TVMCException(create_import_error_string("ONNX", "onnx"))
+ return _onnx
+
+
+def import_tensorflow():
+ """Lazy import function for tensorflow"""
+ try:
+ # pylint: disable=C0415
+ import tensorflow as tf
+ except ImportError:
+ raise TVMCException(create_import_error_string("Tensorflow", "tensorflow"))
+ return tf
+
+
+def import_torch():
+ """Lazy import function for torch"""
+ try:
+ # pylint: disable=C0415
+ import torch as tc
+ except ImportError:
+ raise TVMCException(create_import_error_string("Torch", "torch"))
Review comment:
i'd suggest we either use the name below or eliminate the first parameter here.
```suggestion
raise TVMCException(create_import_error_string("PyTorch", "torch"))
```
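If we go with eliminating the first parameter, the display name could come from a small lookup keyed by the importable package name. A sketch only: `import_error_message` and `_DISPLAY_NAMES` are hypothetical names, and the message wording is illustrative rather than tvmc's actual text.

```python
# Hypothetical lookup: importable package name -> human-readable framework name.
_DISPLAY_NAMES = {
    "onnx": "ONNX",
    "tensorflow": "TensorFlow",
    "torch": "PyTorch",
}


def import_error_message(pkg_name: str) -> str:
    """Build a friendly import-error message from the package name alone."""
    # Fall back to the raw package name for frontends not in the table.
    display_name = _DISPLAY_NAMES.get(pkg_name, pkg_name)
    return (
        f"{display_name} is required for this frontend but could not be imported. "
        f"Try `pip install {pkg_name}`."
    )
```

This keeps the "Torch" vs "PyTorch" spelling decision in one table instead of at every call site.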
##########
File path: tests/python/driver/tvmc/test_frontends.py
##########
@@ -211,3 +212,43 @@ def test_load_model___wrong_language__to_pytorch(tflite_mobilenet_v1_1_quant):
model_format="pytorch",
shape_dict={"input": [1, 3, 224, 224]},
)
+
+
+@mock.patch("tvm.driver.tvmc.frontends.import_keras")
+def test_import_keras_friendly_message(import_keras_mock, keras_resnet50):
Review comment:
@ophirfrish could you look at this one? i think we don't need
importorskip after all.
--
This is an automated message from the Apache Git Service.