This is an automated email from the ASF dual-hosted git repository.

riteshghorse pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 4fbcfc57664 exception handling for loading models (#27186)
4fbcfc57664 is described below

commit 4fbcfc57664ad467dbae60e5ad54a2ba6a1334de
Author: Ritesh Ghorse <[email protected]>
AuthorDate: Wed Jun 21 09:39:07 2023 -0400

    exception handling for loading models (#27186)
---
 sdks/python/apache_beam/ml/inference/tensorflow_inference.py  | 11 ++++++++++-
 .../apache_beam/ml/inference/tensorflow_inference_test.py     |  6 ++++++
 2 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
index d1f236dc53b..991ae971d9e 100644
--- a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
+++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
@@ -56,7 +56,15 @@ class ModelType(enum.Enum):
 
 
 def _load_model(model_uri, custom_weights, load_model_args):
-  model = tf.keras.models.load_model(hub.resolve(model_uri), **load_model_args)
+  try:
+    model = tf.keras.models.load_model(
+        hub.resolve(model_uri), **load_model_args)
+  except Exception as e:
+    raise ValueError(
+        "Unable to load the TensorFlow model: {exception}. Make sure you've \
+        saved the model with TF2 format. Check out the list of TF2 Models on \
+        TensorFlow Hub - https://tfhub.dev/s?subtype=module,placeholder&tf-version=tf2."  # pylint: disable=line-too-long
+        .format(exception=e))
   if custom_weights:
     model.load_weights(custom_weights)
   return model
@@ -156,6 +164,7 @@ class TFModelHandlerNumpy(ModelHandler[numpy.ndarray,
             "Callable create_model_fn must be passed"
             "with ModelType.SAVED_WEIGHTS")
       return _load_model_from_weights(self._create_model_fn, self._model_uri)
+
     return _load_model(
         self._model_uri, self._custom_weights, self._load_model_args)
 
diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py
index 31dde594010..52c525cc0ea 100644
--- a/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py
+++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py
@@ -363,6 +363,12 @@ class TFRunInferenceTest(unittest.TestCase):
     for actual, expected in zip(inferences, expected_predictions):
       self.assertTrue(_compare_tensor_prediction_result(actual[1], expected[1]))
 
+  def test_load_model_exception(self):
+    with self.assertRaises(ValueError):
+      tensorflow_inference._load_model(
+          "https://tfhub.dev/google/imagenet/mobilenet_v1_075_192/quantops/classification/3",  # pylint: disable=line-too-long
+          None, {})
+
 
 @pytest.mark.uses_tf
 class TFRunInferenceTestWithMocks(unittest.TestCase):

Reply via email to