damccorm commented on code in PR #28327:
URL: https://github.com/apache/beam/pull/28327#discussion_r1320229732


##########
examples/notebooks/beam-ml/per_key_models.ipynb:
##########
@@ -0,0 +1,597 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "provenance": []
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    },
+    "language_info": {
+      "name": "python"
+    }
+  },
+  "cells": [
+    {
+      "cell_type": "code",
+      "source": [
+        "# @title ###### Licensed to the Apache Software Foundation (ASF), 
Version 2.0 (the \"License\")\n",
+        "\n",
+        "# Licensed to the Apache Software Foundation (ASF) under one\n",
+        "# or more contributor license agreements. See the NOTICE file\n",
+        "# distributed with this work for additional information\n",
+        "# regarding copyright ownership. The ASF licenses this file\n",
+        "# to you under the Apache License, Version 2.0 (the\n",
+        "# \"License\"); you may not use this file except in compliance\n",
+        "# with the License. You may obtain a copy of the License at\n",
+        "#\n",
+        "#   http://www.apache.org/licenses/LICENSE-2.0\n";,
+        "#\n",
+        "# Unless required by applicable law or agreed to in writing,\n",
+        "# software distributed under the License is distributed on an\n",
+        "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
+        "# KIND, either express or implied. See the License for the\n",
+        "# specific language governing permissions and limitations\n",
+        "# under the License"
+      ],
+      "metadata": {
+        "id": "OsFaZscKSPvo"
+      },
+      "execution_count": 1,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "# Run ML Inference with Different Models Per Key\n",
+        "\n",
+        "<table align=\"left\">\n",
+        "  <td>\n",
+        "    <a target=\"_blank\" 
href=\"https://colab.sandbox.google.com/github/apache/beam/blob/master/examples/notebooks/beam-ml/per_key_models.ipynb\";><img
 
src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/colab_32px.png\";
 />Run in Google Colab</a>\n",
+        "  </td>\n",
+        "  <td>\n",
+        "    <a target=\"_blank\" 
href=\"https://github.com/apache/beam/blob/master/examples/notebooks/beam-ml/per_key_models.ipynb\";><img
 
src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/github_32px.png\";
 />View source on GitHub</a>\n",
+        "  </td>\n",
+        "</table>\n"
+      ],
+      "metadata": {
+        "id": "ZUSiAR62SgO8"
+      }
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Often users desire to run inference with many differently trained 
models performing the same task. This can be helpful if you are comparing the 
performance of multiple different models, or if you have models trained on 
different datasets which you would like to conditionally use based on 
additional metadata.\n",
+        "\n",
+        "In Apache Beam, the recommended way to run inference is with the 
`RunInference` transform. Using a `KeyedModelHandler`, you can efficiently run 
inference with O(100s) of models without worrying about managing memory 
yourself.\n",
+        "\n",
+        "This notebook demonstrates how you can use a `KeyedModelHandler` to 
run inference in a Beam model with multiple different models on a per key 
basis. This notebook uses pretrained pipelines pulled from Hugging Face. It is 
recommended that you walk through the [beginner RunInference 
notebook](https://colab.sandbox.google.com/github/apache/beam/blob/master/examples/notebooks/beam-ml/run_inference_pytorch_tensorflow_sklearn.ipynb)
 before continuing with this notebook."
+      ],
+      "metadata": {
+        "id": "ZAVOrrW2An1n"
+      }
+    },
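For orientation, the pipeline this notebook builds toward has roughly the following shape. This is a minimal sketch, not a cell from the notebook; `keyed_examples` and `keyed_model_handler` are placeholder names for the keyed inputs and the `KeyedModelHandler` defined in later cells.

    import apache_beam as beam
    from apache_beam.ml.inference.base import RunInference

    # keyed_examples: an iterable of (key, example) tuples.
    # keyed_model_handler: a KeyedModelHandler that maps each key to a model.
    with beam.Pipeline() as p:
        _ = (
            p
            | beam.Create(keyed_examples)
            | RunInference(keyed_model_handler)  # emits (key, PredictionResult)
            | beam.Map(print)
        )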
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Install Dependencies\n",
+        "\n",
+        "We will first install Beam and some dependencies needed by Hugging 
Face"
+      ],
+      "metadata": {
+        "id": "_fNyheQoDgGt"
+      }
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 11,
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/";
+        },
+        "id": "B-ENznuJqArA",
+        "outputId": "f72963fc-82db-4d0d-9225-07f6b501e256"
+      },
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            ""
+          ]
+        }
+      ],
+      "source": [
+        "!pip install apache_beam[gcp]>=2.51.0 --quiet\n",
+        "!pip install torch --quiet\n",
+        "!pip install transformers --quiet\n",
+        "\n",
+        "# To use the newly installed versions, restart the runtime.\n",
+        "exit()"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "from typing import Dict\n",
+        "from typing import Iterable\n",
+        "from typing import Tuple\n",
+        "\n",
+        "from transformers import pipeline\n",
+        "\n",
+        "import apache_beam as beam\n",
+        "from apache_beam.ml.inference.base import KeyedModelHandler\n",
+        "from apache_beam.ml.inference.base import KeyModelMapping\n",
+        "from apache_beam.ml.inference.base import PredictionResult\n",
+        "from apache_beam.ml.inference.huggingface_inference import 
HuggingFacePipelineModelHandler\n",
+        "from apache_beam.ml.inference.base import RunInference"
+      ],
+      "metadata": {
+        "id": "wUmBEglvsOYW"
+      },
+      "execution_count": 1,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Define Configuration for our Models\n",
+        "\n",
+        "A `ModelHandler` is Beam's method for defining the configuration 
needed to load and invoke your model. Since we want to use multiple models, we 
will define 2 ModelHandlers, one for each model we're using in this example. 
Since both models being used are incapsulated by Hugging Face pipelines, we 
will use `HuggingFacePipelineModelHandler`.\n",
+        "\n",
+        "In this notebook, we will also load the models using Hugging Face and 
run them against an example. Note that they produce different outputs."
+      ],
+      "metadata": {
+        "id": "uEqljVgCD7hx"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "distilbert_mh = 
HuggingFacePipelineModelHandler('text-classification', 
model=\"distilbert-base-uncased-finetuned-sst-2-english\")\n",
+        "roberta_mh = HuggingFacePipelineModelHandler('text-classification', 
model=\"roberta-large-mnli\")\n",
+        "\n",
+        "distilbert_pipe = pipeline('text-classification', 
model=\"distilbert-base-uncased-finetuned-sst-2-english\")\n",
+        "roberta_large_pipe = pipeline(model=\"roberta-large-mnli\")"
+      ],
+      "metadata": {
+        "id": "v2NJT5ZcxgH5",
+        "outputId": "3924d72e-5c49-477d-c50f-6d9098f5a4b2"
+      },
+      "execution_count": 2,
+      "outputs": [
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              "Downloading (…)lve/main/config.json:   0%|          | 0.00/629 
[00:00<?, ?B/s]"
+            ],
+            "application/vnd.jupyter.widget-view+json": {
+              "version_major": 2,
+              "version_minor": 0,
+              "model_id": "b7cb51663677434ca42de6b5e6f37420"
+            }
+          },
+          "metadata": {}
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              "Downloading model.safetensors:   0%|          | 0.00/268M 
[00:00<?, ?B/s]"
+            ],
+            "application/vnd.jupyter.widget-view+json": {
+              "version_major": 2,
+              "version_minor": 0,
+              "model_id": "3702756019854683a9dea9f8af0a29d0"
+            }
+          },
+          "metadata": {}
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              "Downloading (…)okenizer_config.json:   0%|          | 0.00/48.0 
[00:00<?, ?B/s]"
+            ],
+            "application/vnd.jupyter.widget-view+json": {
+              "version_major": 2,
+              "version_minor": 0,
+              "model_id": "52b9fdb51d514c2e8b9fa5813972ab01"
+            }
+          },
+          "metadata": {}
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              "Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k 
[00:00<?, ?B/s]"
+            ],
+            "application/vnd.jupyter.widget-view+json": {
+              "version_major": 2,
+              "version_minor": 0,
+              "model_id": "eca24b7b7b1847c1aed6aa59a44ed63a"
+            }
+          },
+          "metadata": {}
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              "Downloading (…)lve/main/config.json:   0%|          | 0.00/688 
[00:00<?, ?B/s]"
+            ],
+            "application/vnd.jupyter.widget-view+json": {
+              "version_major": 2,
+              "version_minor": 0,
+              "model_id": "4d4cfe9a0ca54897aa991420bee01ff9"
+            }
+          },
+          "metadata": {}
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              "Downloading model.safetensors:   0%|          | 0.00/1.43G 
[00:00<?, ?B/s]"
+            ],
+            "application/vnd.jupyter.widget-view+json": {
+              "version_major": 2,
+              "version_minor": 0,
+              "model_id": "aee85cd919d24125acff1663fba0b47c"
+            }
+          },
+          "metadata": {}
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              "Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k 
[00:00<?, ?B/s]"
+            ],
+            "application/vnd.jupyter.widget-view+json": {
+              "version_major": 2,
+              "version_minor": 0,
+              "model_id": "0af8ad4eed2d49878fa83b5828d58a96"
+            }
+          },
+          "metadata": {}
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              "Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k 
[00:00<?, ?B/s]"
+            ],
+            "application/vnd.jupyter.widget-view+json": {
+              "version_major": 2,
+              "version_minor": 0,
+              "model_id": "1ed943a51c704ab7a72101b5b6182772"
+            }
+          },
+          "metadata": {}
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              "Downloading (…)/main/tokenizer.json:   0%|          | 
0.00/1.36M [00:00<?, ?B/s]"
+            ],
+            "application/vnd.jupyter.widget-view+json": {
+              "version_major": 2,
+              "version_minor": 0,
+              "model_id": "5b1dcbb533894267b184fd591d8ccdbc"
+            }
+          },
+          "metadata": {}
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "distilbert_pipe(\"This restaurant is awesome\")"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/";
+        },
+        "id": "H3nYX9thy8ec",
+        "outputId": "826e3285-24b9-47a8-d2a6-835543fdcae7"
+      },
+      "execution_count": 3,
+      "outputs": [
+        {
+          "output_type": "execute_result",
+          "data": {
+            "text/plain": [
+              "[{'label': 'POSITIVE', 'score': 0.9998743534088135}]"
+            ]
+          },
+          "metadata": {},
+          "execution_count": 3
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "roberta_large_pipe(\"This restaurant is awesome\")\n"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/";
+        },
+        "id": "IIfc94ODyjUg",
+        "outputId": "94ec8afb-ebfb-47ce-9813-48358741bc6b"
+      },
+      "execution_count": 4,
+      "outputs": [
+        {
+          "output_type": "execute_result",
+          "data": {
+            "text/plain": [
+              "[{'label': 'NEUTRAL', 'score': 0.7313134670257568}]"
+            ]
+          },
+          "metadata": {},
+          "execution_count": 4
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Define our Examples\n",
+        "\n",
+        "Next, we will define some examples that we can input into our 
pipeline, along with their correct classifications."
+      ],
+      "metadata": {
+        "id": "yd92MC7YEsTf"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "examples = [\n",
+        "    (\"This restaurant is awesome\", \"positive\"),\n",
+        "    (\"This restaurant is bad\", \"negative\"),\n",
+        "    (\"I feel fine\", \"neutral\"),\n",
+        "    (\"I love chocolate\", \"positive\"),\n",
+        "]"
+      ],
+      "metadata": {
+        "id": "5HAziWEavQws"
+      },
+      "execution_count": 5,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "To feed our examples into RunInference, we need to have distinct keys 
that can easily map to our model. In this case, we will define keys of the form 
`<model_name>-<actual_sentiment>` so that we can extract the actual sentiment 
of the example later."
+      ],
+      "metadata": {
+        "id": "r6GXL5PLFBY7"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "class FormatExamples(beam.DoFn):\n",
+        "  \"\"\"\n",
+        "  Map each example to a tuple of ('<model_name>-<actual_sentiment>', 
'example').\n",
+        "  We will use these keyes to map our elements to the correct 
models.\n",
+        "  \"\"\"\n",
+        "  def process(self, element: Tuple[str, str]) -> Iterable[Tuple[str, 
str]]:\n",
+        "    yield (f'distilbert-{element[1]}', element[0])\n",
+        "    yield (f'roberta-{element[1]}', element[0])"
+      ],
+      "metadata": {
+        "id": "p2uVwws8zRpg"
+      },
+      "execution_count": 6,
+      "outputs": []
+    },
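As a quick sanity check (not part of the original notebook), `FormatExamples` fans each example out to one element per model, so a single input produces two keyed elements:

    fn = FormatExamples()
    print(list(fn.process(("I love chocolate", "positive"))))
    # [('distilbert-positive', 'I love chocolate'),
    #  ('roberta-positive', 'I love chocolate')]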
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Using the formatted keys, we will define a `KeyedModelHandler` which 
maps keys to the ModelHandler we should use for those keys. `KeyedModelHandler` 
also allows you to define an optional `max_models_per_worker_hint` which will 
limit the number of models that can be held in a single worker process at once. 
This is useful if you are worried about your worker running out of memory. See 
https://beam.apache.org/documentation/sdks/python-machine-learning/index.html#use-a-keyed-modelhandler
 for more info on managing memory."
+      ],
+      "metadata": {
+        "id": "IP65_5nNGIb8"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "per_key_mhs = [\n",
+        "    KeyModelMapping(['distilbert-positive', 'distilbert-neutral', 
'distilbert-negative'], distilbert_mh),\n",
+        "    KeyModelMapping(['roberta-positive', 'roberta-neutral', 
'roberta-negative'], roberta_mh)\n",
+        "]\n",
+        "mh = KeyedModelHandler(per_key_mhs, max_models_per_worker_hint=2)"
+      ],
+      "metadata": {
+        "id": "DZpfjeGL2hMG"
+      },
+      "execution_count": 7,
+      "outputs": []
+    },
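The cell that wires this handler into a pipeline is not shown in this diff. A hedged sketch of that wiring, reusing `examples`, `FormatExamples`, and `mh` as defined above:

    with beam.Pipeline() as p:
      _ = (
          p
          | beam.Create(examples)
          | beam.ParDo(FormatExamples())
          | RunInference(mh)
          | beam.Map(print)
      )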
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Postprocess our results\n",
+        "\n",
+        "The `RunInference` transform returns a Tuple of the original key and 
a `PredictionResult` object that contains the original example and the 
inference. From that, we will extract the data we care about. We will then 
group this data by the original example in order to compare each model's 
prediction."
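The postprocessing cell itself is truncated in this diff. A plausible sketch of the extraction step it describes, assuming the `<model_name>-<actual_sentiment>` key format defined earlier (`ExtractResults` is a hypothetical name):

    class ExtractResults(beam.DoFn):
      """Extract the example, model name, actual sentiment, and inference
      from each (key, PredictionResult) tuple emitted by RunInference."""
      def process(self, element: Tuple[str, PredictionResult]) -> Iterable[Tuple[str, Dict]]:
        key, prediction_result = element
        # Keys were built as f'<model_name>-<actual_sentiment>' above.
        model_name, actual_sentiment = key.split('-', 1)
        yield (prediction_result.example, {
            'model': model_name,
            'actual_sentiment': actual_sentiment,
            'inference': prediction_result.inference,
        })

The results can then be grouped by the original example (for example with `beam.GroupByKey`) to compare each model's prediction.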

Review Comment:
   This is the relevant diff - 
https://github.com/apache/beam/pull/28327/commits/056d6d2ebc509938686497e1a0baf8946cdcf136


