This is an automated email from the ASF dual-hosted git repository.
riteshghorse pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 6ea4c758fa0 [Python] Add setup method to PostProcessor DoFn (#26177)
6ea4c758fa0 is described below
commit 6ea4c758fa035dbccd36120be86ec2f3c0c88d00
Author: Ritesh Ghorse <[email protected]>
AuthorDate: Mon Apr 10 08:24:16 2023 -0400
[Python] Add setup method to PostProcessor DoFn (#26177)
* add setup method to PostProcessor
* add setup method to tf notebook as well
* rm unused args
---
.../run_inference_with_tensorflow_hub.ipynb | 55 ++++++++++++++--------
.../inference/tensorflow_imagenet_segmentation.py | 18 +++++--
2 files changed, 48 insertions(+), 25 deletions(-)
diff --git a/examples/notebooks/beam-ml/run_inference_with_tensorflow_hub.ipynb b/examples/notebooks/beam-ml/run_inference_with_tensorflow_hub.ipynb
index e477e86d9d0..c2447477dfc 100644
--- a/examples/notebooks/beam-ml/run_inference_with_tensorflow_hub.ipynb
+++ b/examples/notebooks/beam-ml/run_inference_with_tensorflow_hub.ipynb
@@ -123,7 +123,7 @@
"metadata": {
"id": "H4-ZvkcTv7MO"
},
- "execution_count": 4,
+ "execution_count": 3,
"outputs": []
},
{
@@ -135,7 +135,7 @@
"metadata": {
"id": "n3M6FNaUwBbl"
},
- "execution_count": 16,
+ "execution_count": 4,
"outputs": []
},
{
@@ -152,23 +152,31 @@
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
- "height": 241
+ "height": 276
},
"id": "q23ip_HkwL3G",
- "outputId": "76417420-788d-4b2b-bd94-782973cfb4b8"
+ "outputId": "051bfa77-ce1f-4ee2-abcd-26ae7f037180"
},
- "execution_count": 17,
+ "execution_count": 5,
"outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Downloading data from https://storage.googleapis.com/apache-beam-samples/image_captioning/Cat-with-beanie.jpg\n",
+ "1812110/1812110 [==============================] - 0s 0us/step\n"
+ ]
+ },
{
"output_type": "execute_result",
"data": {
"text/plain": [
- "<PIL.Image.Image image mode=RGB size=224x224 at 0x7F1F51067850>"
+ "<PIL.Image.Image image mode=RGB size=224x224 at 0x7F6563C90CA0>"
],
"image/png":
"iVBORw0KGgoAAAANSUhEUgAAAOAAAADgCAIAAACVT/22AAAKMWlDQ1BJQ0MgUHJvZmlsZQAAeJydlndUU9kWh8+9N71QkhCKlNBraFICSA29SJEuKjEJEErAkAAiNkRUcERRkaYIMijggKNDkbEiioUBUbHrBBlE1HFwFBuWSWStGd+8ee/Nm98f935rn73P3Wfvfda6AJD8gwXCTFgJgAyhWBTh58WIjYtnYAcBDPAAA2wA4HCzs0IW+EYCmQJ82IxsmRP4F726DiD5+yrTP4zBAP+flLlZIjEAUJiM5/L42VwZF8k4PVecJbdPyZi2NE3OMErOIlmCMlaTc/IsW3z2mWUPOfMyhDwZy3PO4mXw5Nwn4405Er6MkWAZF+cI+LkyviZjg3RJhkDGb+SxGXxONgAoktwu5nNTZGwtY5IoMoIt43kA4EjJX/DSL1jMzxPLD8XOzFouEiSniB
[...]
},
"metadata": {},
- "execution_count": 17
+ "execution_count": 5
}
]
},
@@ -182,7 +190,7 @@
"metadata": {
"id": "QLFCEisBwPlz"
},
- "execution_count": 18,
+ "execution_count": 6,
"outputs": []
},
{
@@ -199,31 +207,38 @@
" \"\"\"Process the PredictionResult to get the predicted label.\n",
" Returns predicted label.\n",
" \"\"\"\n",
- " def process(self, element: PredictionResult) -> Iterable[str]:\n",
- " predicted_class = np.argmax(element.inference)\n",
- " labels_path = tf.keras.utils.get_file(\n",
+ " def __init__(self):\n",
+ " super().__init__()\n",
+ " self._labels_path = None\n",
+ " self._imagenet_labels = None\n",
+ "\n",
+ " def setup(self):\n",
+ " self._labels_path = tf.keras.utils.get_file(\n",
" 'ImageNetLabels.txt',\n",
" 'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt'\n",
" )\n",
- " imagenet_labels = np.array(open(labels_path).read().splitlines())\n",
- " predicted_class_name = imagenet_labels[predicted_class]\n",
+ " self._imagenet_labels = np.array(open(self._labels_path).read().splitlines())\n",
+ "\n",
+ " def process(self, element: PredictionResult) -> Iterable[str]:\n",
+ " predicted_class = np.argmax(element.inference)\n",
+ " predicted_class_name = self._imagenet_labels[predicted_class]\n",
" yield \"Predicted Label: {}\".format(predicted_class_name.title())\n",
"\n",
"with beam.Pipeline() as p:\n",
- " _ = (p | beam.Create([img_tensor])\n",
- " | RunInference(model_handler)\n",
- " | beam.ParDo(PostProcessor())\n",
- " | beam.Map(print)\n",
- " )"
+ " _ = (p\n",
+ " | \"Create PCollection\" >> beam.Create([img_tensor])\n",
+ " | \"Perform inference\" >> RunInference(model_handler)\n",
+ " | \"Post Processing\" >> beam.ParDo(PostProcessor())\n",
+ " | \"Print\" >> beam.Map(print))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "MSKMc_s6wSUH",
- "outputId": "17f86150-7299-4257-c3b7-22a1725020cf"
+ "outputId": "73695bae-75ae-4f02-a0a5-750531b8f90b"
},
- "execution_count": 19,
+ "execution_count": 8,
"outputs": [
{
"output_type": "stream",
diff --git a/sdks/python/apache_beam/examples/inference/tensorflow_imagenet_segmentation.py b/sdks/python/apache_beam/examples/inference/tensorflow_imagenet_segmentation.py
index 7b9e4aba6aa..52863152538 100644
--- a/sdks/python/apache_beam/examples/inference/tensorflow_imagenet_segmentation.py
+++ b/sdks/python/apache_beam/examples/inference/tensorflow_imagenet_segmentation.py
@@ -37,14 +37,22 @@ class PostProcessor(beam.DoFn):
"""Process the PredictionResult to get the predicted label.
Returns predicted label.
"""
- def process(self, element: PredictionResult) -> Iterable[str]:
- predicted_class = numpy.argmax(element.inference, axis=-1)
- labels_path = tf.keras.utils.get_file(
+ def __init__(self):
+ super().__init__()
+ self._imagenet_labels = None
+ self._labels_path = None
+
+ def setup(self):
+ self._labels_path = tf.keras.utils.get_file(
'ImageNetLabels.txt',
'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt'  # pylint: disable=line-too-long
)
- imagenet_labels = numpy.array(open(labels_path).read().splitlines())
- predicted_class_name = imagenet_labels[predicted_class]
+ self._imagenet_labels = numpy.array(
+ open(self._labels_path).read().splitlines())
+
+ def process(self, element: PredictionResult) -> Iterable[str]:
+ predicted_class = numpy.argmax(element.inference, axis=-1)
+ predicted_class_name = self._imagenet_labels[predicted_class]
yield predicted_class_name.title()