damccorm commented on code in PR #28777:
URL: https://github.com/apache/beam/pull/28777#discussion_r1346315930
##########
examples/notebooks/beam-ml/automatic_model_refresh.ipynb:
##########
@@ -1,605 +1,668 @@
{
- "nbformat": 4,
- "nbformat_minor": 0,
- "metadata": {
- "colab": {
- "provenance": []
- },
- "kernelspec": {
- "name": "python3",
- "display_name": "Python 3"
- },
- "language_info": {
- "name": "python"
- }
- },
- "cells": [{
- "cell_type": "code",
- "source": [
- "# @title ###### Licensed to the Apache
Software Foundation (ASF), Version 2.0 (the \"License\")\n",
- "\n",
- "# Licensed to the Apache Software Foundation
(ASF) under one\n",
- "# or more contributor license agreements. See
the NOTICE file\n",
- "# distributed with this work for additional
information\n",
- "# regarding copyright ownership. The ASF
licenses this file\n",
- "# to you under the Apache License, Version 2.0
(the\n",
- "# \"License\"); you may not use this file
except in compliance\n",
- "# with the License. You may obtain a copy of
the License at\n",
- "#\n",
- "#
http://www.apache.org/licenses/LICENSE-2.0\n",
- "#\n",
- "# Unless required by applicable law or agreed
to in writing,\n",
- "# software distributed under the License is
distributed on an\n",
- "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR
CONDITIONS OF ANY\n",
- "# KIND, either express or implied. See the
License for the\n",
- "# specific language governing permissions and
limitations\n",
- "# under the License"
- ],
- "metadata": {
- "cellView": "form",
- "id": "OsFaZscKSPvo"
- },
- "execution_count": null,
- "outputs": [{
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "\n"
- ]
- }]
- },
- {
- "cell_type": "markdown",
- "source": [
- "# Update ML models in running pipelines\n",
- "\n",
- "<table align=\"left\">\n",
- " <td>\n",
- " <a target=\"_blank\"
href=\"https://colab.sandbox.google.com/github/apache/beam/blob/master/examples/notebooks/beam-ml/automatic_model_refresh.ipynb\"><img
src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/colab_32px.png\"
/>Run in Google Colab</a>\n",
- " </td>\n",
- " <td>\n",
- " <a target=\"_blank\"
href=\"https://github.com/apache/beam/blob/master/examples/notebooks/beam-ml/automatic_model_refresh.ipynb\"><img
src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/github_32px.png\"
/>View source on GitHub</a>\n",
- " </td>\n",
- "</table>\n"
- ],
- "metadata": {
- "id": "ZUSiAR62SgO8"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "This notebook demonstrates how to perform
automatic model updates without stopping your Apache Beam pipeline.\n",
- "You can use side inputs to update your model
in real time, even while the Apache Beam pipeline is running. The side input is
passed in a `ModelHandler` configuration object. You can update the model
either by leveraging one of Apache Beam's provided patterns, such as the
`WatchFilePattern`, or by configuring a custom side input `PCollection` that
defines the logic for the model update.\n",
- "\n",
- "The pipeline in this notebook uses a
RunInference `PTransform` with TensorFlow machine learning (ML) models to run
inference on images. To update the model, it uses a side input `PCollection`
that emits `ModelMetadata`.\n",
- "For more information about side inputs, see
the [Side
inputs](https://beam.apache.org/documentation/programming-guide/#side-inputs)
section in the Apache Beam Programming Guide.\n",
- "\n",
- "This example uses `WatchFilePattern` as a side
input. `WatchFilePattern` is used to watch for file updates that match the
`file_pattern` based on timestamps. It emits the latest `ModelMetadata`, which
is used in the RunInference `PTransform` to automatically update the ML model
without stopping the Apache Beam pipeline.\n"
- ],
- "metadata": {
- "id": "tBtqF5UpKJNZ"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Before you begin\n",
- "Install the dependencies required to run this
notebook.\n",
- "\n",
- "To use RunInference with side inputs for
automatic model updates, use Apache Beam version 2.46.0 or later."
- ],
- "metadata": {
- "id": "SPuXFowiTpWx"
- }
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {
- "id": "1RyTYsFEIOlA",
- "outputId":
"0e6b88a7-82d8-4d94-951c-046a9b8b7abb",
- "colab": {
- "base_uri": "https://localhost:8080/"
- }
- },
- "outputs": [{
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "\n"
- ]
- }],
- "source": [
- "!pip install apache_beam[gcp]>=2.46.0
--quiet\n",
- "!pip install tensorflow\n",
- "!pip install tensorflow_hub"
- ]
- },
- {
- "cell_type": "code",
- "source": [
- "# Imports required for the notebook.\n",
- "import logging\n",
- "import time\n",
- "from typing import Iterable\n",
- "from typing import Tuple\n",
- "\n",
- "import apache_beam as beam\n",
- "from
apache_beam.examples.inference.tensorflow_imagenet_segmentation import
PostProcessor\n",
- "from
apache_beam.examples.inference.tensorflow_imagenet_segmentation import
read_image\n",
- "from apache_beam.ml.inference.base import
PredictionResult\n",
- "from apache_beam.ml.inference.base import
RunInference\n",
- "from
apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor\n",
- "from apache_beam.ml.inference.utils import
WatchFilePattern\n",
- "from apache_beam.options.pipeline_options
import GoogleCloudOptions\n",
- "from apache_beam.options.pipeline_options
import PipelineOptions\n",
- "from apache_beam.options.pipeline_options
import SetupOptions\n",
- "from apache_beam.options.pipeline_options
import StandardOptions\n",
- "from apache_beam.transforms.periodicsequence
import PeriodicImpulse\n",
- "import numpy\n",
- "from PIL import Image\n",
- "import tensorflow as tf"
- ],
- "metadata": {
- "id": "Rs4cwwNrIV9H"
- },
- "execution_count": 2,
- "outputs": [{
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "\n"
- ]
- }]
- },
- {
- "cell_type": "code",
- "source": [
- "# Authenticate to your Google Cloud
account.\n",
- "from google.colab import auth\n",
- "auth.authenticate_user()"
- ],
- "metadata": {
- "id": "jAKpPcmmGm03"
- },
- "execution_count": 3,
- "outputs": [{
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "\n"
- ]
- }]
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Configure the runner\n",
- "\n",
- "This pipeline uses the Dataflow Runner. To run
the pipeline, you need to complete the following tasks:\n",
- "\n",
- "* Ensure that you have all the required
permissions to run the pipeline on Dataflow.\n",
- "* Configure the pipeline options for the
pipeline to run on Dataflow. Make sure the pipeline is using streaming mode.\n",
- "\n",
- "In the following code, replace `BUCKET_NAME`
with the the name of your Cloud Storage bucket."
- ],
- "metadata": {
- "id": "ORYNKhH3WQyP"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "options = PipelineOptions()\n",
- "options.view_as(StandardOptions).streaming =
True\n",
- "\n",
- "# Provide required pipeline options for the
Dataflow Runner.\n",
- "options.view_as(StandardOptions).runner =
\"DataflowRunner\"\n",
- "\n",
- "# Set the project to the default project in
your current Google Cloud environment.\n",
- "options.view_as(GoogleCloudOptions).project =
'your-project'\n",
- "\n",
- "# Set the Google Cloud region that you want to
run Dataflow in.\n",
- "options.view_as(GoogleCloudOptions).region =
'us-central1'\n",
- "\n",
- "# IMPORTANT: Replace BUCKET_NAME with the the
name of your Cloud Storage bucket.\n",
- "dataflow_gcs_location =
\"gs://BUCKET_NAME/tmp/\"\n",
- "\n",
- "# The Dataflow staging location. This location
is used to stage the Dataflow pipeline and the SDK binary.\n",
-
"options.view_as(GoogleCloudOptions).staging_location = '%s/staging' %
dataflow_gcs_location\n",
- "\n",
- "# The Dataflow temp location. This location is
used to store temporary files or intermediate results before outputting to the
sink.\n",
-
"options.view_as(GoogleCloudOptions).temp_location = '%s/temp' %
dataflow_gcs_location\n",
- "\n"
- ],
- "metadata": {
- "id": "wWjbnq6X-4uE"
- },
- "execution_count": 4,
- "outputs": [{
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "\n"
- ]
- }]
- },
- {
- "cell_type": "markdown",
- "source": [
- "Install the `tensorflow` and `tensorflow_hub`
dependencies on Dataflow. Use the `requirements_file` pipeline option to pass
these dependencies."
- ],
- "metadata": {
- "id": "HTJV8pO2Wcw4"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "# In a requirements file, define the
dependencies required for the pipeline.\n",
- "deps_required_for_pipeline =
['tensorflow>=2.12.0', 'tensorflow-hub>=0.10.0', 'Pillow>=9.0.0']\n",
- "requirements_file_path =
'./requirements.txt'\n",
- "# Write the dependencies to the requirements
file.\n",
- "with open(requirements_file_path, 'w') as
f:\n",
- " for dep in deps_required_for_pipeline:\n",
- " f.write(dep + '\\n')\n",
- "\n",
- "# Install the pipeline dependencies on
Dataflow.\n",
-
"options.view_as(SetupOptions).requirements_file = requirements_file_path"
- ],
- "metadata": {
- "id": "lEy4PkluWbdm"
- },
- "execution_count": 5,
- "outputs": [{
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "\n"
- ]
- }]
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Use the TensorFlow model handler\n",
- " This example uses `TFModelHandlerTensor` as
the model handler and the `resnet_101` model trained on
[ImageNet](https://www.image-net.org/).\n",
- "\n",
- " Download the model from [Google Cloud
Storage](https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet101_weights_tf_dim_ordering_tf_kernels.h5)
(link downloads the model), and place it in the directory that you want to use
to update your model.\n",
- "\n",
- "In the following code, replace `BUCKET_NAME`
with the the name of your Cloud Storage bucket."
- ],
- "metadata": {
- "id": "_AUNH_GJk_NE"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "model_handler = TFModelHandlerTensor(\n",
- "
model_uri=\"gs://BUCKET_NAME/resnet101_weights_tf_dim_ordering_tf_kernels.h5\")"
- ],
- "metadata": {
- "id": "kkSnsxwUk-Sp"
- },
- "execution_count": 6,
- "outputs": [{
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "\n"
- ]
- }]
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Preprocess images\n",
- "\n",
- "Use `preprocess_image` to run the inference,
read the image, and convert the image to a TensorFlow tensor."
- ],
- "metadata": {
- "id": "tZH0r0sL-if5"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "def preprocess_image(image_name,
image_dir):\n",
- " img = tf.keras.utils.get_file(image_name,
image_dir + image_name)\n",
- " img = Image.open(img).resize((224, 224))\n",
- " img = numpy.array(img) / 255.0\n",
- " img_tensor =
tf.cast(tf.convert_to_tensor(img[...]), dtype=tf.float32)\n",
- " return img_tensor"
- ],
- "metadata": {
- "id": "dU5imgTt-8Ne"
- },
- "execution_count": 7,
- "outputs": [{
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "\n"
- ]
- }]
- },
- {
- "cell_type": "code",
- "source": [
- "class PostProcessor(beam.DoFn):\n",
- " \"\"\"Process the PredictionResult to get
the predicted label.\n",
- " Returns predicted label.\n",
- " \"\"\"\n",
- " def process(self, element: PredictionResult)
-> Iterable[Tuple[str, str]]:\n",
- " predicted_class =
numpy.argmax(element.inference, axis=-1)\n",
- " labels_path = tf.keras.utils.get_file(\n",
- " 'ImageNetLabels.txt',\n",
- "
'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt'
# pylint: disable=line-too-long\n",
- " )\n",
- " imagenet_labels =
numpy.array(open(labels_path).read().splitlines())\n",
- " predicted_class_name =
imagenet_labels[predicted_class]\n",
- " yield predicted_class_name.title(),
element.model_id"
- ],
- "metadata": {
- "id": "6V5tJxO6-gyt"
- },
- "execution_count": 8,
- "outputs": [{
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "\n"
- ]
- }]
- },
- {
- "cell_type": "code",
- "source": [
- "# Define the pipeline object.\n",
- "pipeline = beam.Pipeline(options=options)"
- ],
- "metadata": {
- "id": "GpdKk72O_NXT",
- "outputId":
"bcbaa8a6-0408-427a-de9e-78a6a7eefd7b",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 400
- }
- },
- "execution_count": 9,
- "outputs": [{
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "\n"
- ]
- }]
- },
- {
- "cell_type": "markdown",
- "source": [
- "Next, review the pipeline steps and examine
the code.\n",
- "\n",
- "### Pipeline steps\n"
- ],
- "metadata": {
- "id": "elZ53uxc_9Hv"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "1. Create a `PeriodicImpulse` transform, which
emits output every `n` seconds. The `PeriodicImpulse` transform generates an
infinite sequence of elements with a given runtime interval.\n",
- "\n",
- " In this example, `PeriodicImpulse` mimics
the Pub/Sub source. Because the inputs in a streaming pipeline arrive in
intervals, use `PeriodicImpulse` to output elements at `m` intervals.\n",
- "To learn more about `PeriodicImpulse`, see the
[`PeriodicImpulse`
code](https://github.com/apache/beam/blob/9c52e0594d6f0e59cd17ee005acfb41da508e0d5/sdks/python/apache_beam/transforms/periodicsequence.py#L150)."
- ],
- "metadata": {
- "id": "305tkV2sAD-S"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "start_timestamp = time.time() # start
timestamp of the periodic impulse\n",
- "end_timestamp = start_timestamp + 60 * 20 #
end timestamp of the periodic impulse (will run for 20 minutes).\n",
- "main_input_fire_interval = 60 # interval in
seconds at which the main input PCollection is emitted.\n",
- "side_input_fire_interval = 60 # interval in
seconds at which the side input PCollection is emitted.\n",
- "\n",
- "periodic_impulse = (\n",
- " pipeline\n",
- " | \"MainInputPcoll\" >>
PeriodicImpulse(\n",
- " start_timestamp=start_timestamp,\n",
- " stop_timestamp=end_timestamp,\n",
- "
fire_interval=main_input_fire_interval))"
- ],
- "metadata": {
- "id": "vUFStz66_Tbb",
- "outputId":
"39f2704b-021e-4d41-fce3-a2fac90a5bad",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 133
- }
- },
- "execution_count": 10,
- "outputs": [{
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "\n"
- ]
- }]
- },
- {
- "cell_type": "markdown",
- "source": [
- "2. To read and preprocess the images, use the
`read_image` function. This example uses `Cat-with-beanie.jpg` for all
inferences.\n",
- "\n",
- " **Note**: Image used for prediction is
licensed in CC-BY. The creator is listed in the
[LICENSE.txt](https://storage.googleapis.com/apache-beam-samples/image_captioning/LICENSE.txt)
file."
- ],
- "metadata": {
- "id": "8-sal2rFAxP2"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
-
""
- ],
- "metadata": {
- "id": "gW4cE8bhXS-d"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "image_data = (periodic_impulse |
beam.Map(lambda x: \"Cat-with-beanie.jpg\")\n",
- " | \"ReadImage\" >> beam.Map(lambda
image_name: read_image(\n",
- " image_name=image_name,
image_dir='https://storage.googleapis.com/apache-beam-samples/image_captioning/')))"
- ],
- "metadata": {
- "id": "dGg11TpV_aV6",
- "outputId":
"a57e8197-6756-4fd8-a664-f51ef2fea730",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 204
- }
- },
- "execution_count": 11,
- "outputs": [{
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "\n"
- ]
- }]
- },
- {
- "cell_type": "markdown",
- "source": [
- "3. Pass the images to the RunInference
`PTransform`. RunInference takes `model_handler` and `model_metadata_pcoll` as
input parameters.\n",
- " * `model_metadata_pcoll` is a side input
`PCollection` to the RunInference `PTransform`. This side input is used to
update the `model_uri` in the `model_handler` without needing to stop the
Apache Beam pipeline\n",
- " * Use `WatchFilePattern` as side input to
watch a `file_pattern` matching `.h5` files. In this case, the `file_pattern`
is `'gs://BUCKET_NAME/*.h5'`.\n",
- "\n"
- ],
- "metadata": {
- "id": "eB0-ewd-BCKE"
- }
- },
- {
- "cell_type": "code",
- "source": [
- " # The side input used to watch for the .h5
file and update the model_uri of the TFModelHandlerTensor.\n",
- "file_pattern = 'gs://BUCKET_NAME/*.h5'\n",
- "side_input_pcoll = (\n",
- " pipeline\n",
- " | \"WatchFilePattern\" >>
WatchFilePattern(file_pattern=file_pattern,\n",
- "
interval=side_input_fire_interval,\n",
- "
stop_timestamp=end_timestamp))\n",
- "inferences = (\n",
- " image_data\n",
- " | \"ApplyWindowing\" >>
beam.WindowInto(beam.window.FixedWindows(10))\n",
- " | \"RunInference\" >>
RunInference(model_handler=model_handler,\n",
- "
model_metadata_pcoll=side_input_pcoll))"
- ],
- "metadata": {
- "id": "_AjvvexJ_hUq",
- "outputId":
"291fcc38-0abb-4b11-f840-4a850097a56f",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 133
- }
- },
- "execution_count": 12,
- "outputs": [{
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "\n"
- ]
- }]
- },
- {
- "cell_type": "markdown",
- "source": [
- "4. Post-process the `PredictionResult`
object.\n",
- "When the inference is complete, RunInference
outputs a `PredictionResult` object that contains the fields `example`,
`inference`, and `model_id`. The `model_id` field identifies the model used to
run the inference. The `PostProcessor` returns the predicted label and the
model ID used to run the inference on the predicted label."
- ],
- "metadata": {
- "id": "lTA4wRWNDVis"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "post_processor = (\n",
- " inferences\n",
- " | \"PostProcessResults\" >>
beam.ParDo(PostProcessor())\n",
- " | \"LogResults\" >>
beam.Map(logging.info))"
- ],
- "metadata": {
- "id": "9TB76fo-_vZJ",
- "outputId":
"3e12d482-1bdf-4136-fbf7-9d5bb4bb62c3",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 222
- }
- },
- "execution_count": 13,
- "outputs": [{
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "\n"
- ]
- }]
- },
- {
- "cell_type": "markdown",
- "source": [
- "### Watch for the model update\n",
- "\n",
- "After the pipeline starts processing data and
when you see output emitted from the RunInference `PTransform`, upload a
`resnet152` model saved in `.h5` format to a Google Cloud Storage bucket
location that matches the `file_pattern` you defined earlier. You can [download
a copy of the
model](https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet152_weights_tf_dim_ordering_tf_kernels.h5)
(link downloads the model). RunInference uses `WatchFilePattern` as a side
input to update the `model_uri` of `TFModelHandlerTensor`."
- ],
- "metadata": {
- "id": "wYp-mBHHjOjA"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Run the pipeline\n",
- "\n",
- "Use the following code to run the pipeline."
- ],
- "metadata": {
- "id": "_ty03jDnKdKR"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "# Run the pipeline.\n",
- "result = pipeline.run().wait_until_finish()"
- ],
- "metadata": {
- "id": "wd0VJLeLEWBU",
- "outputId":
"3489c891-05d2-4739-d693-1899cfe78859",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 186
- }
- },
- "execution_count": 14,
- "outputs": [{
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "\n"
- ]
- }]
- }
- ]
-}
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "code",
+ "source": [
+ "# @title ###### Licensed to the Apache Software Foundation (ASF),
Version 2.0 (the \"License\")\n",
+ "\n",
+ "# Licensed to the Apache Software Foundation (ASF) under one\n",
+ "# or more contributor license agreements. See the NOTICE file\n",
+ "# distributed with this work for additional information\n",
+ "# regarding copyright ownership. The ASF licenses this file\n",
+ "# to you under the Apache License, Version 2.0 (the\n",
+ "# \"License\"); you may not use this file except in compliance\n",
+ "# with the License. You may obtain a copy of the License at\n",
+ "#\n",
+ "# http://www.apache.org/licenses/LICENSE-2.0\n",
+ "#\n",
+ "# Unless required by applicable law or agreed to in writing,\n",
+ "# software distributed under the License is distributed on an\n",
+ "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
+ "# KIND, either express or implied. See the License for the\n",
+ "# specific language governing permissions and limitations\n",
+ "# under the License"
+ ],
+ "metadata": {
+ "cellView": "form",
+ "id": "OsFaZscKSPvo",
+ "outputId": "f9903a54-13d4-403c-a705-a212be050fed"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Update ML models in running pipelines\n",
+ "\n",
+ "<table align=\"left\">\n",
+ " <td>\n",
+ " <a target=\"_blank\"
href=\"https://colab.sandbox.google.com/github/apache/beam/blob/master/examples/notebooks/beam-ml/automatic_model_refresh.ipynb\"><img
src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/colab_32px.png\"
/>Run in Google Colab</a>\n",
+ " </td>\n",
+ " <td>\n",
+ " <a target=\"_blank\"
href=\"https://github.com/apache/beam/blob/master/examples/notebooks/beam-ml/automatic_model_refresh.ipynb\"><img
src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/github_32px.png\"
/>View source on GitHub</a>\n",
+ " </td>\n",
+ "</table>\n"
+ ],
+ "metadata": {
+ "id": "ZUSiAR62SgO8"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "This notebook demonstrates how to perform automatic model updates
without stopping your Apache Beam pipeline.\n",
+ "You can use side inputs to update your model in real time, even while
the Apache Beam pipeline is running. The side input is passed in a
`ModelHandler` configuration object. You can update the model either by
leveraging one of Apache Beam's provided patterns, such as the
`WatchFilePattern`, or by configuring a custom side input `PCollection` that
defines the logic for the model update.\n",
+ "\n",
+ "The pipeline in this notebook uses a RunInference `PTransform` with
TensorFlow machine learning (ML) models to run inference on images. To update
the model, it uses a side input `PCollection` that emits `ModelMetadata`.\n",
+ "For more information about side inputs, see the [Side
inputs](https://beam.apache.org/documentation/programming-guide/#side-inputs)
section in the Apache Beam Programming Guide.\n",
+ "\n",
+ "This example uses `WatchFilePattern` as a side input.
`WatchFilePattern` is used to watch for file updates that match the
`file_pattern` based on timestamps. It emits the latest `ModelMetadata`, which
is used in the RunInference `PTransform` to automatically update the ML model
without stopping the Apache Beam pipeline.\n"
+ ],
+ "metadata": {
+ "id": "tBtqF5UpKJNZ"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Before you begin\n",
+ "Install the dependencies required to run this notebook.\n",
+ "\n",
+ "To use RunInference with side inputs for automatic model updates, use
Apache Beam version 2.46.0 or later."
+ ],
+ "metadata": {
+ "id": "SPuXFowiTpWx"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "1RyTYsFEIOlA",
+ "outputId": "0e6b88a7-82d8-4d94-951c-046a9b8b7abb",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ }
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "!pip install apache_beam[gcp]>=2.46.0 --quiet\n",
+ "!pip install tensorflow\n",
+ "!pip install tensorflow_hub"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Imports required for the notebook.\n",
+ "import logging\n",
+ "import time\n",
+ "from typing import Iterable\n",
+ "from typing import Tuple\n",
+ "\n",
+ "import apache_beam as beam\n",
+ "from apache_beam.ml.inference.base import PredictionResult\n",
+ "from apache_beam.ml.inference.base import RunInference\n",
+ "from apache_beam.ml.inference.tensorflow_inference import
TFModelHandlerTensor\n",
+ "from apache_beam.ml.inference.utils import WatchFilePattern\n",
+ "from apache_beam.options.pipeline_options import
GoogleCloudOptions\n",
+ "from apache_beam.options.pipeline_options import PipelineOptions\n",
+ "from apache_beam.options.pipeline_options import SetupOptions\n",
+ "from apache_beam.options.pipeline_options import StandardOptions\n",
+ "from apache_beam.options.pipeline_options import WorkerOptions\n",
+ "from apache_beam.transforms.periodicsequence import
PeriodicImpulse\n",
+ "import numpy\n",
+ "from PIL import Image\n",
+ "import tensorflow as tf"
+ ],
+ "metadata": {
+ "id": "Rs4cwwNrIV9H"
+ },
+ "execution_count": 4,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Authenticate to your Google Cloud account.\n",
+ "def auth_to_colab():\n",
+ " from google.colab import auth\n",
+ " auth.authenticate_user()\n",
+ "\n",
+ "auth_to_colab()"
+ ],
+ "metadata": {
+ "id": "jAKpPcmmGm03",
+ "outputId": "8776c778-54f5-497c-d929-15b7bca98595"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Configure the runner\n",
+ "\n",
+ "This pipeline uses the Dataflow Runner. To run the pipeline, you need
to complete the following tasks:\n",
+ "\n",
+ "* Ensure that you have all the required permissions to run the
pipeline on Dataflow.\n",
+ "* Configure the pipeline options for the pipeline to run on Dataflow.
Make sure the pipeline is using streaming mode.\n",
+ "\n",
+ "In the following code, replace `BUCKET_NAME` with the the name of
your Cloud Storage bucket."
+ ],
+ "metadata": {
+ "id": "ORYNKhH3WQyP"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "options = PipelineOptions()\n",
+ "options.view_as(StandardOptions).streaming = True\n",
+ "\n",
+ "# Provide required pipeline options for the Dataflow Runner.\n",
+ "options.view_as(StandardOptions).runner = \"DataflowRunner\"\n",
+ "\n",
+ "# Set the project to the default project in your current Google Cloud
environment.\n",
+ "options.view_as(GoogleCloudOptions).project = 'your-project'\n",
+ "\n",
+ "# Set the Google Cloud region that you want to run Dataflow in.\n",
+ "options.view_as(GoogleCloudOptions).region = 'us-central1'\n",
+ "\n",
+ "# IMPORTANT: Replace BUCKET_NAME with the the name of your Cloud
Storage bucket.\n",
+ "dataflow_gcs_location = \"gs://BUCKET_NAME/tmp/\"\n",
+ "\n",
+ "# The Dataflow staging location. This location is used to stage the
Dataflow pipeline and the SDK binary.\n",
+ "options.view_as(GoogleCloudOptions).staging_location = '%s/staging' %
dataflow_gcs_location\n",
+ "\n",
+ "# The Dataflow temp location. This location is used to store
temporary files or intermediate results before outputting to the sink.\n",
+ "options.view_as(GoogleCloudOptions).temp_location = '%s/temp' %
dataflow_gcs_location\n",
+ "\n",
+ "options.view_as(SetupOptions).save_main_session = True\n",
+ "\n",
+ "# Launching Dataflow with only one worker might result in processing
delays due to\n",
+ "# initial input processing. This could further postpone the side
input model updates.\n",
+ "# To expedite the model update process, it's recommended to set
num_workers>1.\n",
+ "# https://github.com/apache/beam/issues/28776\n",
+ "options.view_as(WorkerOptions).num_workers = 5"
+ ],
+ "metadata": {
+ "id": "wWjbnq6X-4uE",
+ "outputId": "2125c017-dfd4-4402-f02a-8469f67409a8"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Install the `tensorflow` and `tensorflow_hub` dependencies on
Dataflow. Use the `requirements_file` pipeline option to pass these
dependencies."
+ ],
+ "metadata": {
+ "id": "HTJV8pO2Wcw4"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# In a requirements file, define the dependencies required for the
pipeline.\n",
+ "!printf 'tensorflow>=2.12.0\\ntensorflow_hub>=0.10.0\\nPillow>=9.0.0'
> ./requirements.txt\n",
+ "# Install the pipeline dependencies on Dataflow.\n",
+ "options.view_as(SetupOptions).requirements_file =
'./requirements.txt'"
+ ],
+ "metadata": {
+ "id": "lEy4PkluWbdm"
+ },
+ "execution_count": 7,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Use the TensorFlow model handler\n",
+ " This example uses `TFModelHandlerTensor` as the model handler and
the `resnet_101` model trained on [ImageNet](https://www.image-net.org/).\n",
+ "\n",
+ "\n",
+ "For DataflowRunner, the model needs to be stored remote location
accessible by the Beam pipeline. So we will download `ResNet101` model and
upload it to the GCS location.\n"
+ ],
+ "metadata": {
+ "id": "_AUNH_GJk_NE"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "model = tf.keras.applications.resnet.ResNet101()\n",
Review Comment:
We should have something similar below to get the ResNet152 model.
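
For example, a sketch along these lines could work (assumptions: `BUCKET_NAME` is a placeholder for the bucket referenced by the pipeline's `file_pattern`, and `tf.io.gfile` is just one way to do the upload):

```python
# Sketch: mirror the ResNet101 cell above for the model-update step.
import tensorflow as tf

# Download ResNet152 and save it in .h5 format.
model = tf.keras.applications.resnet.ResNet152()
model.save('resnet152_weights_tf_dim_ordering_tf_kernels.h5')

# Copy the saved file to a path matching the pipeline's file_pattern
# ('gs://BUCKET_NAME/*.h5') so WatchFilePattern emits it as new ModelMetadata.
tf.io.gfile.copy(
    'resnet152_weights_tf_dim_ordering_tf_kernels.h5',
    'gs://BUCKET_NAME/resnet152_weights_tf_dim_ordering_tf_kernels.h5',
    overwrite=True)
```

As the notebook notes, uploading the file only after the pipeline starts emitting inferences is what triggers the automatic model swap.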
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]