This is an automated email from the ASF dual-hosted git repository.
echuraev pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b6502f4e27 Fix keras version problem (#15265)
b6502f4e27 is described below
commit b6502f4e278da391719155936aeefb6544115c1f
Author: Shikamaru:) <[email protected]>
AuthorDate: Fri Jul 14 18:20:16 2023 +0800
Fix keras version problem (#15265)
* Fix keras version problem
* Fix keras version problem
* Fix keras version problem
* Fix keras version problem
* Fix keras version problem
* Fix keras version problem
* Fix keras version problem
---
python/tvm/relay/frontend/keras.py | 15 +++++++++++----
1 file changed, 11 insertions(+), 4 deletions(-)
diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py
index 0d932cadcc..1913d4a268 100644
--- a/python/tvm/relay/frontend/keras.py
+++ b/python/tvm/relay/frontend/keras.py
@@ -1526,12 +1526,19 @@ def from_keras(model, shape=None, layout="NCHW"):
raise ValueError("Keras frontend currently supports tensorflow backend only.")
if keras.backend.image_data_format() != "channels_last":
raise ValueError("Keras frontend currently supports data_format = channels_last only.")
- expected_model_class = keras.engine.training.Model
- if hasattr(keras.engine, "InputLayer"):
- input_layer_class = keras.engine.InputLayer
+ try:
+ import keras.engine as E
+ except ImportError:
+ try:
+ import keras.src.engine as E
+ except ImportError:
+ raise ImportError("Cannot find Keras's engine")
+ expected_model_class = E.training.Model
+ if hasattr(E, "InputLayer"):
+ input_layer_class = E.InputLayer
else:
# TFlite >=2.6
- input_layer_class = keras.engine.input_layer.InputLayer
+ input_layer_class = E.input_layer.InputLayer
else:
# Importing from Tensorflow Keras (tf.keras)
try: