This is an automated email from the ASF dual-hosted git repository.
riteshghorse pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 4fbcfc57664 exception handling for loading models (#27186)
4fbcfc57664 is described below
commit 4fbcfc57664ad467dbae60e5ad54a2ba6a1334de
Author: Ritesh Ghorse <[email protected]>
AuthorDate: Wed Jun 21 09:39:07 2023 -0400
exception handling for loading models (#27186)
---
sdks/python/apache_beam/ml/inference/tensorflow_inference.py | 11 ++++++++++-
.../apache_beam/ml/inference/tensorflow_inference_test.py | 6 ++++++
2 files changed, 16 insertions(+), 1 deletion(-)
diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
index d1f236dc53b..991ae971d9e 100644
--- a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
+++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
@@ -56,7 +56,15 @@ class ModelType(enum.Enum):
def _load_model(model_uri, custom_weights, load_model_args):
- model = tf.keras.models.load_model(hub.resolve(model_uri), **load_model_args)
+ try:
+ model = tf.keras.models.load_model(
+ hub.resolve(model_uri), **load_model_args)
+ except Exception as e:
+ raise ValueError( # adjacent literals: no source indentation leaks into the message
+ "Unable to load the TensorFlow model: {exception}. Make sure you've "
+ "saved the model with TF2 format. Check out the list of TF2 Models on "
+ "TensorFlow Hub - https://tfhub.dev/s?subtype=module,placeholder&tf-version=tf2." # pylint: disable=line-too-long
+ .format(exception=e)) from e # keep the underlying load error as __cause__
if custom_weights:
model.load_weights(custom_weights)
return model
@@ -156,6 +164,7 @@ class TFModelHandlerNumpy(ModelHandler[numpy.ndarray,
"Callable create_model_fn must be passed"
"with ModelType.SAVED_WEIGHTS")
return _load_model_from_weights(self._create_model_fn, self._model_uri)
+
return _load_model(
self._model_uri, self._custom_weights, self._load_model_args)
diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py
index 31dde594010..52c525cc0ea 100644
--- a/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py
+++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py
@@ -363,6 +363,12 @@ class TFRunInferenceTest(unittest.TestCase):
for actual, expected in zip(inferences, expected_predictions):
self.assertTrue(_compare_tensor_prediction_result(actual[1],
expected[1]))
+ def test_load_model_exception(self):
+ with self.assertRaises(ValueError):
+ tensorflow_inference._load_model( # NOTE(review): remote tfhub.dev handle; presumably needs network access
+ "https://tfhub.dev/google/imagenet/mobilenet_v1_075_192/quantops/classification/3", # pylint: disable=line-too-long
+ None, {})
+
@pytest.mark.uses_tf
class TFRunInferenceTestWithMocks(unittest.TestCase):