AnandInguva commented on code in PR #25368:
URL: https://github.com/apache/beam/pull/25368#discussion_r1106245269


##########
sdks/python/apache_beam/examples/inference/tensorflow_imagenet_segmentation.py:
##########
@@ -0,0 +1,129 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import argparse
+import logging
+from typing import Iterable
+from typing import Iterator
+
+import numpy
+
+import apache_beam as beam
+import tensorflow as tf
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor
+from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options import SetupOptions
+from apache_beam.runners.runner import PipelineResult
+
+
+class PostProcessor(beam.DoFn):
+  """Process the PredictionResult to get the predicted label.
+  Returns predicted label.
+  """
+  def process(self, element: PredictionResult) -> Iterable[str]:
+    print("prediction result---->: %", element)
+    predicted_class = numpy.argmax(element.inference[0], axis=-1)
+    labels_path = tf.keras.utils.get_file(
+        'ImageNetLabels.txt',
+        
'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt'
  # pylint: disable=line-too-long

Review Comment:
   Can we use parentheses instead of suppressing the line-too-long warning?



##########
sdks/python/apache_beam/ml/inference/tensorflow_inference_it_test.py:
##########
@@ -0,0 +1,112 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""End-to-End test for Tensorflow Inference"""
+
+import logging
+import unittest
+import uuid
+
+import pytest
+
+from apache_beam.io.filesystems import FileSystems
+from apache_beam.testing.test_pipeline import TestPipeline
+
+# pylint: disable=ungrouped-imports
+try:
+  import tensorflow as tf
+  from apache_beam.examples.inference import tensorflow_imagenet_segmentation
+  from apache_beam.examples.inference import tensorflow_mnist_classification
+except ImportError as e:
+  tf = None
+
+
+def process_outputs(filepath):
+  with FileSystems().open(filepath) as f:
+    lines = f.readlines()
+  lines = [l.decode('utf-8').strip('\n') for l in lines]
+  return lines
+
+
+@unittest.skipIf(
+    tf is None, 'Missing dependencies. '
+    'Test depends on tensorflow')
+@pytest.mark.uses_tf
+@pytest.mark.it_postcommit
+class TensorflowInference(unittest.TestCase):
+  def test_tf_mnist_classification(self):
+    test_pipeline = TestPipeline(is_integration_test=True)
+    input_file = 
'gs://clouddfe-riteshghorse/tf/mnist/dataset/testing_inputs_it_mnist_data.csv'  
# pylint: disable=line-too-long
+    output_file_dir = 'gs://clouddfe-riteshghorse/tf/mnist/output/'
+    output_file = '/'.join([output_file_dir, str(uuid.uuid4()), 'result.txt'])
+    model_path = 'gs://clouddfe-riteshghorse/tf/mnist/model/'
+    extra_opts = {
+        'input': input_file,
+        'output': output_file,
+        'model_path': model_path,
+    }
+    tensorflow_mnist_classification.run(
+        test_pipeline.get_full_options_as_args(**extra_opts),
+        save_main_session=False)
+    self.assertEqual(FileSystems().exists(output_file), True)
+
+    expected_output_filepath = 
'gs://clouddfe-riteshghorse/tf/mnist/output/testing_expected_outputs_test_sklearn_mnist_classification_actuals.txt'
  # pylint: disable=line-too-long
+    expected_outputs = process_outputs(expected_output_filepath)
+
+    predicted_outputs = process_outputs(output_file)
+    self.assertEqual(len(expected_outputs), len(predicted_outputs))
+
+    predictions_dict = {}
+    for i in range(len(predicted_outputs)):
+      true_label, prediction = predicted_outputs[i].split(',')
+      predictions_dict[true_label] = prediction
+
+    for i in range(len(expected_outputs)):
+      true_label, expected_prediction = expected_outputs[i].split(',')
+      self.assertEqual(predictions_dict[true_label], expected_prediction)
+
+  def test_tf_imagenet_image_classification(self):
+    test_pipeline = TestPipeline(is_integration_test=True)
+    input_file = 
'gs://clouddfe-riteshghorse/tf/imagenet/input/input_labels.txt'  # pylint: 
disable=line-too-long

Review Comment:
   Similar comment as above: use parentheses instead of suppressing the warning.
   ```suggestion
       input_file = (
           'gs://clouddfe-riteshghorse/tf/imagenet/input/input_labels.txt')
   ```
   May need formatting. 



##########
sdks/python/apache_beam/examples/inference/tensorflow_imagenet_segmentation.py:
##########
@@ -0,0 +1,129 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import argparse
+import logging
+from typing import Iterable
+from typing import Iterator
+
+import numpy
+
+import apache_beam as beam
+import tensorflow as tf
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor
+from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options import SetupOptions
+from apache_beam.runners.runner import PipelineResult
+
+
+class PostProcessor(beam.DoFn):
+  """Process the PredictionResult to get the predicted label.
+  Returns predicted label.
+  """
+  def process(self, element: PredictionResult) -> Iterable[str]:
+    print("prediction result---->: %", element)
+    predicted_class = numpy.argmax(element.inference[0], axis=-1)
+    labels_path = tf.keras.utils.get_file(
+        'ImageNetLabels.txt',
+        
'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt'
  # pylint: disable=line-too-long
+    )
+    imagenet_labels = numpy.array(open(labels_path).read().splitlines())
+    predicted_class_name = imagenet_labels[predicted_class]
+    return predicted_class_name.title()
+
+
+def parse_known_args(argv):
+  """Parses args for the workflow."""
+  parser = argparse.ArgumentParser()
+  parser.add_argument(
+      '--input',
+      dest='input',
+      required=True,
+      help='Path to the text file containing image names.')
+  parser.add_argument(
+      '--output',
+      dest='output',
+      required=True,
+      help='Path to save output predictions.')
+  parser.add_argument(
+      '--model_path',
+      dest='model_path',
+      required=True,
+      help='Path to load the Tensorflow model for Inference.')
+  parser.add_argument(
+      '--image_dir', help='Path to the directory where images are stored.')
+  return parser.parse_known_args(argv)
+
+
+def filter_empty_lines(text: str) -> Iterator[str]:
+  if len(text.strip()) > 0:
+    yield text
+
+
+def read_image(image_name, image_dir):
+  from PIL import Image

Review Comment:
   Why is this import local to the function?



##########
sdks/python/apache_beam/examples/inference/tensorflow_imagenet_segmentation.py:
##########
@@ -0,0 +1,129 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import argparse
+import logging
+from typing import Iterable
+from typing import Iterator
+
+import numpy
+
+import apache_beam as beam
+import tensorflow as tf
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor
+from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options import SetupOptions
+from apache_beam.runners.runner import PipelineResult
+
+
+class PostProcessor(beam.DoFn):
+  """Process the PredictionResult to get the predicted label.
+  Returns predicted label.
+  """
+  def process(self, element: PredictionResult) -> Iterable[str]:
+    print("prediction result---->: %", element)

Review Comment:
   nit: remove the `print` or use `logging.info` instead? (Note that the `%` in the message is never interpolated here.)



##########
sdks/python/apache_beam/ml/inference/tensorflow_inference.py:
##########
@@ -0,0 +1,253 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pytype: skip-file
+
+import enum
+import sys
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Iterable
+from typing import Optional
+from typing import Sequence
+from typing import Union
+
+import numpy
+
+import tensorflow as tf
+import tensorflow_hub as hub
+from apache_beam.ml.inference import utils
+from apache_beam.ml.inference.base import ModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+
+__all__ = [
+    'TFModelHandlerNumpy',
+    'TFModelHandlerTensor',
+]
+
+TensorInferenceFn = Callable[[
+    tf.Module,
+    Sequence[Union[numpy.ndarray, tf.Tensor]],
+    Optional[Dict[str, Any]],
+    Optional[str]
+],
+                             Iterable[PredictionResult]]
+
+
+class ModelType(enum.Enum):
+  """Defines how a model file should be loaded."""
+  SAVED_MODEL = 1
+
+
+def _load_model(model_uri, model_type):
+  if model_type == ModelType.SAVED_MODEL:
+    return tf.keras.models.load_model(hub.resolve(model_uri))
+  else:
+    raise AssertionError('Unsupported model type for loading.')
+
+
+def default_numpy_inference_fn(
+    model: tf.Module,
+    batch: Sequence[numpy.ndarray],
+    inference_args: Optional[Dict[str, Any]] = None,
+    model_id: Optional[str] = None) -> Iterable[PredictionResult]:
+  vectorized_batch = numpy.stack(batch, axis=0)
+  if inference_args:
+    predictions = model(vectorized_batch, **inference_args)
+  else:
+    predictions = model(vectorized_batch)
+  return utils._convert_to_result(batch, predictions, model_id)
+
+
+def default_tensor_inference_fn(
+    model: tf.Module,
+    batch: Sequence[tf.Tensor],
+    inference_args: Optional[Dict[str, Any]] = None,
+    model_id: Optional[str] = None) -> Iterable[PredictionResult]:
+  vectorized_batch = tf.stack(batch, axis=0)
+  if inference_args:
+    predictions = model(vectorized_batch, **inference_args)
+  else:
+    predictions = model(vectorized_batch)
+  return utils._convert_to_result(batch, predictions, model_id)
+
+
+class TFModelHandlerNumpy(ModelHandler[numpy.ndarray,
+                                       PredictionResult,
+                                       tf.Module]):
+  def __init__(
+      self,
+      model_uri: str,
+      model_type: ModelType = ModelType.SAVED_MODEL,
+      *,
+      inference_fn: TensorInferenceFn = default_numpy_inference_fn):
+    """Implementation of the ModelHandler interface for Tensorflow.
+
+    Example Usage::
+
+      pcoll | RunInference(TFModelHandlerNumpy(model_uri="my_uri"))
+
+    See https://www.tensorflow.org/tutorials/keras/save_and_load for details.
+
+    Args:
+        model_uri (str): path to the trained model.
+        model_type (ModelType): type of model to be loaded.
+          Defaults to SAVED_MODEL.
+        inference_fn (TensorInferenceFn, Optional): inference function to use
+          during RunInference. Defaults to default_numpy_inference_fn.
+
+    **Supported Versions:** RunInference APIs in Apache Beam have been tested
+    with Tensorflow 2.11.

Review Comment:
   Since we added these two versions (Tensorflow 2.9 and 2.10) to the tox tests, the docstring should list them as tested versions too.



##########
sdks/python/apache_beam/ml/inference/tensorflow_inference.py:
##########
@@ -0,0 +1,253 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pytype: skip-file
+
+import enum
+import sys
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Iterable
+from typing import Optional
+from typing import Sequence
+from typing import Union
+
+import numpy
+
+import tensorflow as tf
+import tensorflow_hub as hub
+from apache_beam.ml.inference import utils
+from apache_beam.ml.inference.base import ModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+
+__all__ = [
+    'TFModelHandlerNumpy',
+    'TFModelHandlerTensor',
+]
+
+TensorInferenceFn = Callable[[
+    tf.Module,
+    Sequence[Union[numpy.ndarray, tf.Tensor]],
+    Optional[Dict[str, Any]],
+    Optional[str]
+],
+                             Iterable[PredictionResult]]
+
+
+class ModelType(enum.Enum):
+  """Defines how a model file should be loaded."""
+  SAVED_MODEL = 1
+
+
+def _load_model(model_uri, model_type):
+  if model_type == ModelType.SAVED_MODEL:
+    return tf.keras.models.load_model(hub.resolve(model_uri))
+  else:
+    raise AssertionError('Unsupported model type for loading.')
+
+
+def default_numpy_inference_fn(
+    model: tf.Module,
+    batch: Sequence[numpy.ndarray],
+    inference_args: Optional[Dict[str, Any]] = None,
+    model_id: Optional[str] = None) -> Iterable[PredictionResult]:
+  vectorized_batch = numpy.stack(batch, axis=0)
+  if inference_args:
+    predictions = model(vectorized_batch, **inference_args)
+  else:
+    predictions = model(vectorized_batch)
+  return utils._convert_to_result(batch, predictions, model_id)
+
+
+def default_tensor_inference_fn(
+    model: tf.Module,
+    batch: Sequence[tf.Tensor],
+    inference_args: Optional[Dict[str, Any]] = None,
+    model_id: Optional[str] = None) -> Iterable[PredictionResult]:
+  vectorized_batch = tf.stack(batch, axis=0)
+  if inference_args:
+    predictions = model(vectorized_batch, **inference_args)
+  else:
+    predictions = model(vectorized_batch)
+  return utils._convert_to_result(batch, predictions, model_id)
+
+
+class TFModelHandlerNumpy(ModelHandler[numpy.ndarray,
+                                       PredictionResult,
+                                       tf.Module]):
+  def __init__(
+      self,
+      model_uri: str,
+      model_type: ModelType = ModelType.SAVED_MODEL,
+      *,
+      inference_fn: TensorInferenceFn = default_numpy_inference_fn):
+    """Implementation of the ModelHandler interface for Tensorflow.
+
+    Example Usage::
+
+      pcoll | RunInference(TFModelHandlerNumpy(model_uri="my_uri"))
+
+    See https://www.tensorflow.org/tutorials/keras/save_and_load for details.
+
+    Args:
+        model_uri (str): path to the trained model.
+        model_type (ModelType): type of model to be loaded.
+          Defaults to SAVED_MODEL.
+        inference_fn (TensorInferenceFn, Optional): inference function to use
+          during RunInference. Defaults to default_numpy_inference_fn.
+
+    **Supported Versions:** RunInference APIs in Apache Beam have been tested
+    with Tensorflow 2.11.

Review Comment:
   ```suggestion
       with Tensorflow 2.9, 2.10, and 2.11.
   ```



##########
sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py:
##########
@@ -0,0 +1,147 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pytype: skip-file
+
+import unittest
+
+import numpy
+import pytest
+
+try:
+  import tensorflow as tf
+  from apache_beam.ml.inference.sklearn_inference_test import 
_compare_prediction_result
+  from apache_beam.ml.inference.base import KeyedModelHandler, PredictionResult
+  from apache_beam.ml.inference.tensorflow_inference import 
TFModelHandlerNumpy, TFModelHandlerTensor
+except ImportError:
+  raise unittest.SkipTest('Tensorflow dependencies are not installed')
+
+
+class FakeTFNumpyModel:
+  def predict(self, input: numpy.ndarray):
+    return numpy.multiply(input, 10)
+
+
+class FakeTFTensorModel:
+  def predict(self, input: tf.Tensor, add=False):
+    if add:
+      return tf.math.add(tf.math.multiply(input, 10), 10)
+    return tf.math.multiply(input, 10)
+
+
+def _compare_tensor_prediction_result(x, y):
+  return tf.math.equal(x.inference, y.inference)
+
+
+@pytest.mark.uses_tf
+class TFRunInferenceTest(unittest.TestCase):
+  def test_predict_numpy(self):
+    fake_model = FakeTFNumpyModel()
+    inference_runner = TFModelHandlerNumpy(model_uri='unused')
+    batched_examples = [numpy.array([1]), numpy.array([10]), 
numpy.array([100])]
+    expected_predictions = [
+        PredictionResult(numpy.array([1]), 10),
+        PredictionResult(numpy.array([10]), 100),
+        PredictionResult(numpy.array([100]), 1000)
+    ]
+    inferences = inference_runner.run_inference(batched_examples, fake_model)
+    for actual, expected in zip(inferences, expected_predictions):
+      self.assertTrue(_compare_prediction_result(actual, expected))
+
+  def test_predict_tensor(self):
+    fake_model = FakeTFTensorModel()
+    inference_runner = TFModelHandlerTensor(model_uri='unused')
+    batched_examples = [
+        tf.convert_to_tensor(numpy.array([1])),
+        tf.convert_to_tensor(numpy.array([10])),
+        tf.convert_to_tensor(numpy.array([100])),
+    ]
+    expected_predictions = [
+        PredictionResult(ex, pred) for ex,
+        pred in zip(
+            batched_examples,
+            [tf.math.multiply(n, 10) for n in batched_examples])
+    ]
+
+    inferences = inference_runner.run_inference(batched_examples, fake_model)
+    for actual, expected in zip(inferences, expected_predictions):
+      self.assertTrue(_compare_tensor_prediction_result(actual, expected))
+
+  def test_predict_tensor_with_args(self):
+    fake_model = FakeTFTensorModel()
+    inference_runner = TFModelHandlerTensor(model_uri='unused')
+    batched_examples = [
+        tf.convert_to_tensor(numpy.array([1])),
+        tf.convert_to_tensor(numpy.array([10])),
+        tf.convert_to_tensor(numpy.array([100])),
+    ]
+    expected_predictions = [
+        PredictionResult(ex, pred) for ex,
+        pred in zip(
+            batched_examples, [
+                tf.math.add(tf.math.multiply(n, 10), 10)
+                for n in batched_examples
+            ])
+    ]
+
+    inferences = inference_runner.run_inference(
+        batched_examples, fake_model, inference_args={"add": True})
+    for actual, expected in zip(inferences, expected_predictions):
+      self.assertTrue(_compare_tensor_prediction_result(actual, expected))
+
+  def test_predict_keyed_numpy(self):
+    fake_model = FakeTFNumpyModel()
+    inference_runner = KeyedModelHandler(
+        TFModelHandlerNumpy(model_uri='unused'))
+    batched_examples = [
+        ('k1', numpy.array([1], dtype=numpy.int64)),
+        ('k2', numpy.array([10], dtype=numpy.int64)),
+        ('k3', numpy.array([100], dtype=numpy.int64)),
+    ]
+    expected_predictions = [
+        (ex[0], PredictionResult(ex[1], pred)) for ex,
+        pred in zip(
+            batched_examples,
+            [numpy.multiply(n[1], 10) for n in batched_examples])
+    ]
+    inferences = inference_runner.run_inference(batched_examples, fake_model)
+    for actual, expected in zip(inferences, expected_predictions):
+      self.assertTrue(_compare_prediction_result(actual[1], expected[1]))
+
+  @pytest.mark.uses_tf

Review Comment:
   nit: remove this marker since it is already present on the class definition.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to