This is an automated email from the ASF dual-hosted git repository.

jrmccluskey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new dbd719ba144 [WIP] Gemma Sentiment and Summarization Example Notebook (#32172)
dbd719ba144 is described below

commit dbd719ba1448c13bc97237d1b701e2978c0e29d4
Author: Jack McCluskey <[email protected]>
AuthorDate: Wed Aug 14 09:50:17 2024 -0400

    [WIP] Gemma Sentiment and Summarization Example Notebook (#32172)
---
 .../gemma_2_sentiment_and_summarization.ipynb      | 625 +++++++++++++++++++++
 1 file changed, 625 insertions(+)

diff --git a/examples/notebooks/beam-ml/gemma_2_sentiment_and_summarization.ipynb b/examples/notebooks/beam-ml/gemma_2_sentiment_and_summarization.ipynb
new file mode 100644
index 00000000000..b45d9d7aea9
--- /dev/null
+++ b/examples/notebooks/beam-ml/gemma_2_sentiment_and_summarization.ipynb
@@ -0,0 +1,625 @@
+{
+  "cells": [
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "id": "BrKf6TQ98qIJ",
+      "metadata": {
+        "id": "BrKf6TQ98qIJ"
+      },
+      "outputs": [],
+      "source": [
+        "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n",
+        "\n",
+        "# Licensed to the Apache Software Foundation (ASF) under one\n",
+        "# or more contributor license agreements. See the NOTICE file\n",
+        "# distributed with this work for additional information\n",
+        "# regarding copyright ownership. The ASF licenses this file\n",
+        "# to you under the Apache License, Version 2.0 (the\n",
+        "# \"License\"); you may not use this file except in compliance\n",
+        "# with the License. You may obtain a copy of the License at\n",
+        "#\n",
+        "#   http://www.apache.org/licenses/LICENSE-2.0\n",
+        "#\n",
+        "# Unless required by applicable law or agreed to in writing,\n",
+        "# software distributed under the License is distributed on an\n",
+        "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
+        "# KIND, either express or implied. See the License for the\n",
+        "# specific language governing permissions and limitations\n",
+        "# under the License."
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "id": "hHg4SoUr8qIK",
+      "metadata": {
+        "id": "hHg4SoUr8qIK"
+      },
+      "source": [
+        "# Use Gemma to gauge sentiment and summarize conversations\n",
+        "\n",
+        "<table align=\"left\">\n",
+        "  <td>\n",
+        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/apache/beam/blob/master/examples/notebooks/beam-ml/gemma_2_sentiment_and_summarization.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/colab_32px.png\" />Run in Google Colab</a>\n",
+        "  </td>\n",
+        "  <td>\n",
+        "    <a target=\"_blank\" href=\"https://github.com/apache/beam/blob/master/examples/notebooks/beam-ml/gemma_2_sentiment_and_summarization.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/github_32px.png\" />View source on GitHub</a>\n",
+        "  </td>\n",
+        "</table>"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "id": "yOs5SCyPdYNi",
+      "metadata": {
+        "id": "yOs5SCyPdYNi"
+      },
+      "source": [
+        "Gemma is a family of lightweight, state-of-the-art open models built from research and technology used to create the Gemini models. You can use Gemma models in your Apache Beam inference pipelines.\n",
+        "\n",
+        "Because large language models (LLMs) like Gemma are versatile, you can integrate them into business processes. The example in this notebook demonstrates how to use Gemma to gauge the sentiment of a conversation, summarize that conversation's content, and draft a reply for a difficult conversation. The system allows a person to review the reply before it's sent to customers. For more information, see the blog post [Gemma for Streaming ML with Dataflow](https://developers.googlebl [...]
+        "\n",
+        "A requirement of this work is that customers who express a negative sentiment receive a reply in near real time. As a result, the workflow needs to use a streaming data pipeline with an LLM that has minimal latency.\n",
+        "\n",
+        "## Use case\n",
+        "\n",
+        "An example use case is a busy food chain that needs to analyze and store a high volume of customer support requests. Customer interactions include both chats generated by automated chatbots and nuanced conversations that require the attention of live support staff.\n",
+        "\n",
+        "### Requirements\n",
+        "\n",
+        "To address both types of interactions, the workflow has the following requirements:\n",
+        "\n",
+        "- It must efficiently manage and store chat data by summarizing positive interactions for easy reference and future analysis.\n",
+        "\n",
+        "- It must detect and resolve issues in real time.\n",
+        "\n",
+        "- It must use sentiment analysis to identify dissatisfied customers and generate tailored responses to address their concerns.\n",
+        "\n",
+        "### Workflow\n",
+        "\n",
+        "To meet these requirements, the pipeline processes completed chat messages in near real time. First, the pipeline uses Gemma to gauge the sentiment of each customer chat and to summarize it. Chats with a positive or neutral sentiment, together with their summaries, are then sent directly to a data platform, such as BigQuery, by using the available Dataflow I/Os.\n",
+        "\n",
+        "For chats that have a negative sentiment, the Gemma model crafts a contextually appropriate response for the customer. This response is sent to a human for review so that they can refine the message before it reaches the customer.\n",
+        "\n",
+        "This example addresses important complexities inherent in using an LLM within a pipeline. For example, processing the responses in code is challenging because the generated text is non-deterministic. In this example, the workflow requires the LLM to generate JSON responses, which is not the default format. The workflow needs to parse and validate the response, a process similar to processing data from sources that don't always have correctly structured data.\n",
+        "\n",
+        "This workflow allows businesses to respond to customers faster and to provide personalized responses when needed.\n",
+        "\n",
+        "- Automating the summarization of positive chats lets support staff focus on more complex interactions.\n",
+        "- Because the system scales, it can handle increasing chat volumes without compromising response quality.\n",
+        "\n",
+        "You can also analyze the stored chat data in depth to inform business decisions."
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "id": "tGZfleinj3xM",
+      "metadata": {
+        "id": "tGZfleinj3xM"
+      },
+      "source": [
+        "## The data processing pipeline"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "id": "G-VpT7-FjcSu",
+      "metadata": {
+        "id": "G-VpT7-FjcSu"
+      },
+      "source": [
+        "![Screenshot 2024-08-08 at 11.15.41.png](data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAACVwAAAVECAYAAAAlH0OMAAABXmlDQ1BJQ0MgUHJvZmlsZQAAKJF1kEFLAlEQx/+WpahEQYcOBXZZKExEjU4dzCQChdWKskO0rpsKqz3ebkS3PkT0EaJvIEGHOlTXICjwVETQoU7BXkq2eW6lFs0w/H/8mXlvGKAnoDCmuwFUaybPLcwF1/LrQc8zvJQ+TCOsqAZLyHKaWvCt3WHdwiX0Zkq81cifX0izT1Im03/lzY69/O3vCl9RM1TSD6q4yrgJuCLE8q7JBO8TD3NaivhAcMnhY8EFh09bPcu5JPE18aBaVorE98ShQodf6uCqvqN+7SC2D2i1lSXSEapRzCOFNGUQMmKIUsaQReqfmXhrJoltMOyBo4ISyjBpOkEOgw6NeBE1qAgjRBxF [...]
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "id": "mXtb21lMj_rU",
+      "metadata": {
+        "id": "mXtb21lMj_rU"
+      },
+      "source": [
+        "At a high level, the pipeline has the following steps:\n",
+        "\n",
+        "1. Read the review data from Pub/Sub, the event messaging source. This data contains the chat ID and the chat history as a JSON payload. This payload is processed in the pipeline.\n",
+        "1. Pass the text from the messages to Gemma with a prompt.\n",
+        "1. The pipeline requests that the model complete the following two tasks:\n",
+        "   *  Attach a sentiment score to the message, by using one of the following three values: `1` for a positive chat, `0` for a neutral chat, and `-1` for a negative chat.\n",
+        "   *  Provide a one-sentence summary of the chat.\n",
+        "1. The pipeline branches, depending on the sentiment score:\n",
+        "   * If the score is `1` or `0`, the chat and its summary are sent to a data analytics system for storage and future analysis.\n",
+        "   * If the score is `-1`, the Gemma model drafts a response. This response and the chat information are sent to an event messaging system that connects the pipeline and other applications. This step allows a person to review the content of the response."
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "id": "nETbaxwZk7us",
+      "metadata": {
+        "id": "nETbaxwZk7us"
+      },
+      "source": [
+        "## Build the pipeline\n",
+        "\n",
+        "This section provides the code needed to run the pipeline.\n",
+        "\n",
+        "### Before you begin\n",
+        "\n",
+        "Although you can use CPUs for testing and development, for a production Dataflow ML system we recommend that you use GPUs. When you use GPUs with Dataflow, we recommend that you use a custom container. For more information about configuring GPUs and custom containers with Dataflow, see [Best practices for working with Dataflow GPUs](https://cloud.google.com/dataflow/docs/gpu/develop-with-gpus). To facilitate rapid testing of the pipeline, follow the guide [Run a pipeline with GPU [...]
+        "\n",
+        "After you configure your environment, download the model [gemma2_instruct_2b_en](https://www.kaggle.com/models/google/gemma-2/keras) into a folder. In this example, the folder is named `gemma2`."
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "id": "jMrjYGW9spFG",
+      "metadata": {
+        "id": "jMrjYGW9spFG"
+      },
+      "source": [
+        "### Build the base image\n",
+        "\n",
+        "Add the following Dockerfile to your folder, and then build the base image. Use this image as your development environment while you create the `pipeline.py` file. The build is split into two images, a base image and a pipeline image, to facilitate testing and development."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "id": "29dOdG_kxzTv",
+      "metadata": {
+        "id": "29dOdG_kxzTv"
+      },
+      "outputs": [],
+      "source": [
+        "ARG SERVING_BUILD_IMAGE=tensorflow/tensorflow:2.16.1-gpu\n",
+        "\n",
+        "FROM ${SERVING_BUILD_IMAGE}\n",
+        "WORKDIR /workspace\n",
+        "\n",
+        "COPY gemma2  gemma2\n",
+        "RUN apt-get update -y && apt-get install -y cmake vim"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "id": "3eWt8AatpEuG",
+      "metadata": {
+        "id": "3eWt8AatpEuG"
+      },
+      "source": [
+        "When you test the pipeline code and when you launch the job on Dataflow, do both from inside the container. This practice prevents dependency mismatches when the pipeline runs on Dataflow."
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "id": "lyS0uYpsoeOW",
+      "metadata": {
+        "id": "lyS0uYpsoeOW"
+      },
+      "source": [
+        "The `requirements.txt` file contains the following dependencies:  "
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "id": "K4gJQ0e9pCR4",
+      "metadata": {
+        "id": "K4gJQ0e9pCR4"
+      },
+      "outputs": [],
+      "source": [
+        "apache_beam[gcp]==2.54.0\n",
+        "keras_nlp==0.14.3\n",
+        "keras==3.4.1\n",
+        "jax[cuda12]"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "id": "k9gGtkqvn6Ar",
+      "metadata": {
+        "id": "k9gGtkqvn6Ar"
+      },
+      "source": [
+        "The next step adds the files needed to construct the pipeline. The contents of the `pipeline.py` file are provided in a later section of this notebook."
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "id": "aqPS_p3Pp37b",
+      "metadata": {
+        "id": "aqPS_p3Pp37b"
+      },
+      "source": [
+        "Replace `<DOCKERFILE_IMAGE>` with the base image that you built by using the first Dockerfile."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "id": "EdUxmUjqx58r",
+      "metadata": {
+        "id": "EdUxmUjqx58r"
+      },
+      "outputs": [],
+      "source": [
+        "FROM <DOCKERFILE_IMAGE>\n",
+        "\n",
+        "WORKDIR /workspace\n",
+        "\n",
+        "# Copy files from the official SDK image, including the script and dependencies.\n",
+        "COPY --from=apache/beam_python3.11_sdk:2.54.0 /opt/apache/beam /opt/apache/beam\n",
+        "\n",
+        "COPY requirements.txt requirements.txt\n",
+        "RUN pip install --upgrade --no-cache-dir pip \\\n",
+        "    && pip install --no-cache-dir -r requirements.txt\n",
+        "\n",
+        "# Copy the pipeline code. The model directory is already in the base image.\n",
+        "COPY pipeline.py pipeline.py\n",
+        "\n",
+        "# The notebook was tested with the JAX backend. Set the backend in the image\n",
+        "# environment so that the Dataflow workers pick it up.\n",
+        "ENV KERAS_BACKEND=\"jax\"\n",
+        "ENV XLA_PYTHON_CLIENT_MEM_FRACTION=\"0.9\"\n",
+        "\n",
+        "\n",
+        "# Set the entrypoint to the Apache Beam SDK launcher.\n",
+        "ENTRYPOINT [\"/opt/apache/beam/boot\"]"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "id": "i63FUxXwsSsO",
+      "metadata": {
+        "id": "i63FUxXwsSsO"
+      },
+      "source": [
+        "### Run the pipeline\n",
+        "\n",
+        "The following code creates and runs the pipeline.\n",
+        "\n",
+        "- The `pip install` steps are needed to run the code in the notebook, but aren't needed when running the code in your container.\n",
+        "\n",
+        "- Without a GPU, the inference takes a long time to complete."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "id": "ebb44382-ee7b-4cec-af67-1fe220cfb40d",
+      "metadata": {
+        "id": "ebb44382-ee7b-4cec-af67-1fe220cfb40d",
+        "tags": []
+      },
+      "outputs": [],
+      "source": [
+        "%pip install \"apache_beam[gcp]==2.54.0\" \"keras_nlp==0.14.3\" \"keras>=3\" \"jax[cuda12]\""
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "id": "oPgRBScKThZg",
+      "metadata": {
+        "id": "oPgRBScKThZg"
+      },
+      "outputs": [],
+      "source": [
+        "import os\n",
+        "\n",
+        "os.environ[\"KERAS_BACKEND\"] = \"jax\"  # Or \"tensorflow\" or \"torch\".\n",
+        "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = \"0.9\"\n",
+        "\n",
+        "import keras\n",
+        "import keras_nlp\n",
+        "import numpy as np\n",
+        "import json\n",
+        "import ast\n",
+        "import re\n",
+        "import logging\n",
+        "\n",
+        "import apache_beam as beam\n",
+        "from apache_beam.ml.inference import utils\n",
+        "from apache_beam.ml.inference.base import RunInference\n",
+        "from apache_beam.options import pipeline_options\n",
+        "from apache_beam.options.pipeline_options import GoogleCloudOptions\n",
+        "from apache_beam.options.pipeline_options import PipelineOptions\n",
+        "from apache_beam.options.pipeline_options import SetupOptions\n",
+        "from apache_beam.options.pipeline_options import StandardOptions\n",
+        "from apache_beam.options.pipeline_options import WorkerOptions\n",
+        "from apache_beam.ml.inference.base import ModelHandler\n",
+        "from apache_beam.ml.inference.base import PredictionResult\n",
+        "from apache_beam.ml.inference.base import KeyedModelHandler\n",
+        "from keras_nlp.models import GemmaCausalLM\n",
+        "from typing import Any, Dict, Iterable, Optional, Sequence"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "id": "0gicDesYWdbu",
+      "metadata": {
+        "id": "0gicDesYWdbu"
+      },
+      "source": [
+        "Set the pipeline options and provide the input Pub/Sub topic. To run the pipeline on Google Cloud Dataflow, uncomment the options that are commented out and fill in the placeholder values."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "id": "wpG-ltLiTxWM",
+      "metadata": {
+        "id": "wpG-ltLiTxWM"
+      },
+      "outputs": [],
+      "source": [
+        "options = PipelineOptions()\n",
+        "options.view_as(StandardOptions).streaming = True\n",
+        "options.view_as(SetupOptions).save_main_session = True\n",
+        "\n",
+        "# options.view_as(StandardOptions).runner = \"DataflowRunner\"\n",
+        "# options.view_as(GoogleCloudOptions).project = <PROJECT>\n",
+        "# options.view_as(GoogleCloudOptions).temp_location = <TMP LOCATION>\n",
+        "# options.view_as(GoogleCloudOptions).region = \"us-west1\"\n",
+        "# options.view_as(WorkerOptions).machine_type = \"g2-standard-4\"\n",
+        "# options.view_as(WorkerOptions).worker_harness_container_image = <IMAGE YOU BUILT>\n",
+        "# options.view_as(WorkerOptions).disk_size_gb = 200\n",
+        "# options.view_as(GoogleCloudOptions).dataflow_service_options = [\"worker_accelerator=type:nvidia-l4;count:1;install-nvidia-driver\"]\n",
+        "\n",
+        "topic_reviews=\"<PubSub Topic>\""
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "id": "g8sWSMRmW-Ab",
+      "metadata": {
+        "id": "g8sWSMRmW-Ab"
+      },
+      "source": [
+        "Define a custom model handler that loads the Gemma model and handles inference calls."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "id": "vRVCN3qBUAt9",
+      "metadata": {
+        "id": "vRVCN3qBUAt9"
+      },
+      "outputs": [],
+      "source": [
+        "class GemmaModelHandler(ModelHandler[str, PredictionResult, GemmaCausalLM]):\n",
+        "    def __init__(\n",
+        "        self,\n",
+        "        model_name: str = \"\",\n",
+        "    ):\n",
+        "        \"\"\"Implementation of the ModelHandler interface for Gemma using text as input.\n",
+        "\n",
+        "        Example Usage::\n",
+        "\n",
+        "          pcoll | RunInference(GemmaModelHandler('gemma2'))\n",
+        "\n",
+        "        Args:\n",
+        "          model_name: The Gemma model preset name, or the path to a local\n",
+        "            directory that contains the model, such as `gemma2`.\n",
+        "        \"\"\"\n",
+        "        self._model_name = model_name\n",
+        "        self._env_vars = {}\n",
+        "\n",
+        "    def share_model_across_processes(self) -> bool:\n",
+        "        \"\"\"Returns whether to share a single model in memory across processes.\n",
+        "\n",
+        "        This is useful when the loaded model is large, preventing potential\n",
+        "        out-of-memory issues when running the pipeline.\n",
+        "\n",
+        "        Returns:\n",
+        "          bool\n",
+        "        \"\"\"\n",
+        "        return True\n",
+        "\n",
+        "    def load_model(self) -> GemmaCausalLM:\n",
+        "        \"\"\"Loads and initializes a model for processing.\"\"\"\n",
+        "        return keras_nlp.models.GemmaCausalLM.from_preset(self._model_name)\n",
+        "\n",
+        "    def run_inference(\n",
+        "        self,\n",
+        "        batch: Sequence[str],\n",
+        "        model: GemmaCausalLM,\n",
+        "        inference_args: Optional[Dict[str, Any]] = None\n",
+        "    ) -> Iterable[PredictionResult]:\n",
+        "        \"\"\"Runs inferences on a batch of text strings.\n",
+        "\n",
+        "        Args:\n",
+        "          batch: A sequence of examples as text strings.\n",
+        "          model: The loaded GemmaCausalLM model used for inference.\n",
+        "          inference_args: Any additional arguments for an inference.\n",
+        "\n",
+        "        Returns:\n",
+        "          An Iterable of type PredictionResult.\n",
+        "        \"\"\"\n",
+        "        # Generate a reply for each text string and collect the results in a list.\n",
+        "        predictions = []\n",
+        "        for one_text in batch:\n",
+        "            result = model.generate(one_text, max_length=1024)\n",
+        "            predictions.append(result)\n",
+        "        return utils._convert_to_result(batch, predictions, self._model_name)"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "id": "cSbAFPXmXPMc",
+      "metadata": {
+        "id": "cSbAFPXmXPMc"
+      },
+      "source": [
+        "Define a prompt template that formats a given input and instructs the model on the tasks to perform. This cell also creates an example input for the model."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "id": "hqh-Ro5-UNqy",
+      "metadata": {
+        "id": "hqh-Ro5-UNqy"
+      },
+      "outputs": [],
+      "source": [
+        "prompt_template = \"\"\"\n",
+        "<prompt>\n",
+        "Provide the results of doing these two tasks on the chat history provided below for the user {}\n",
+        "task 1 : assess if the tone is happy = 1 , neutral = 0 or unhappy = -1\n",
+        "task 2 : summarize the text with a maximum of 512 characters\n",
+        "Return the answer as a JSON string with fields [sentiment, summary] do NOT explain your answer\n",
+        "\n",
+        "@@@{}@@@\n",
+        "<answer>\n",
+        "\"\"\"\n",
+        "chat_text = \"\"\"\n",
+        "id 221: Hay I am really annoyed that your menu includes a pizza with pineapple on it!\n",
+        "id 331: Sorry to hear that , but pineapple is nice on pizza\n",
+        "id 221: What a terriable thing to say! Its never ok, so unhappy right now!\n",
+        "\"\"\"\n",
+        "\n",
+        "# Example input\n",
+        "chat = json.dumps({\"id\": 42, \"user_id\": 221, \"chat_message\": chat_text})\n",
+        "print(chat)"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "id": "8UFfKvSeYn0b",
+      "metadata": {
+        "id": "8UFfKvSeYn0b"
+      },
+      "source": [
+        "Define the pre- and post-processing functions. `CreatePrompt` creates a key-value pair of the chat ID and the formatted prompt. `extract_model_reply` parses the response and extracts the JSON string requested from the model. However, the LLM isn't *guaranteed* to return a JSON-formatted object, so the function also raises an exception if the reply is malformed. The `SentimentAnalysis` `DoFn` uses this helper to split out the sentiment score and the summary of the text. The [...]
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "id": "dDIys7XaUPvl",
+      "metadata": {
+        "id": "dDIys7XaUPvl"
+      },
+      "outputs": [],
+      "source": [
+        "# Load the model from the local `gemma2` folder downloaded from Kaggle.\n",
+        "keyed_model_handler = KeyedModelHandler(GemmaModelHandler('gemma2'))\n",
+        "\n",
+        "# Create the prompt by using the information from the chat.\n",
+        "class CreatePrompt(beam.DoFn):\n",
+        "  def process(self, element, *args, **kwargs):\n",
+        "    user_chat = json.loads(element)\n",
+        "    chat_id = user_chat['id']\n",
+        "    user_id = user_chat['user_id']\n",
+        "    messages = user_chat['chat_message']\n",
+        "    yield (chat_id, prompt_template.format(user_id, messages))\n",
+        "\n",
+        "def extract_model_reply(model_inference):\n",
+        "    # Find the first {...} span in the generated text.\n",
+        "    match = re.search(r\"(\\{[\\s\\S]*?\\})\", model_inference)\n",
+        "    if match is None:\n",
+        "        raise ValueError('No JSON object found in the model reply')\n",
+        "    logging.debug('Extracted model reply: %s', match.group(1))\n",
+        "    result = json.loads(match.group(1))\n",
+        "    if all(key in result for key in ['sentiment', 'summary']):\n",
+        "        return result\n",
+        "    raise ValueError('Malformed model reply')\n",
+        "\n",
+        "class SentimentAnalysis(beam.DoFn):\n",
+        "    def process(self, element):\n",
+        "        key = element[0]\n",
+        "        match = re.search(r\"@@@([\\s\\S]*?)@@@\", element[1].example)\n",
+        "        chats = match.group(1)\n",
+        "\n",
+        "        try:\n",
+        "            # The model output contains the prompt, so remove the prompt before parsing.\n",
+        "            result = extract_model_reply(element[1].inference.replace(element[1].example, \"\"))\n",
+        "            processed_result = (key, chats, result['sentiment'], result['summary'])\n",
+        "\n",
+        "            # Route negative chats to a separate output so that a reply can be drafted.\n",
+        "            if int(result['sentiment']) == -1:\n",
+        "                output = beam.TaggedOutput('negative', processed_result)\n",
+        "            else:\n",
+        "                output = beam.TaggedOutput('main', processed_result)\n",
+        "\n",
+        "        except Exception as err:\n",
+        "            logging.error('Failed to process the model reply: %s', err)\n",
+        "            output = beam.TaggedOutput('error', element)\n",
+        "\n",
+        "        yield output\n",
+        "\n",
+        "gemma_inference = RunInference(keyed_model_handler)"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "id": "Yj9aQ0q8YLOn",
+      "metadata": {
+        "id": "Yj9aQ0q8YLOn"
+      },
+      "source": [
+        "Finally, run the pipeline by using the following code. To use the example chat input created earlier instead of a Pub/Sub source, replace the Pub/Sub read with `chats = p | beam.Create([chat])`."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "id": "1fb47a17-9563-46f6-9768-73f4802694e8",
+      "metadata": {
+        "id": "1fb47a17-9563-46f6-9768-73f4802694e8",
+        "tags": []
+      },
+      "outputs": [],
+      "source": [
+        "\n",
+        "with beam.Pipeline(options=options) as p:\n",
+        "  chats = (p | \"Read Topic\" >>\n",
+        "            beam.io.ReadFromPubSub(topic=topic_reviews)\n",
+        "            | \"Parse\" >> beam.Map(lambda x: x.decode(\"utf-8\")))\n",
+        "  prompts = chats | \"Create Prompt\" >> beam.ParDo(CreatePrompt())\n",
+        "  results = prompts | \"RunInference-Gemma\" >> gemma_inference\n",
+        "  filtered_results = results | \"Process Results\" >> beam.ParDo(SentimentAnalysis()).with_outputs('main', 'negative', 'error')\n",
+        "  generated_responses = (\n",
+        "      filtered_results.negative\n",
+        "       | \"Generate Response\" >> beam.Map(lambda x: ((x[0], x[3]), \"<prompt>Generate an apology response for the user in this chat text: \" + x[1] + \"<answer>\"))\n",
+        "       | \"Gemma-Response\" >> gemma_inference\n",
+        "       )\n",
+        "\n",
+        "  generated_responses | \"Print Response\" >> beam.Map(lambda x: logging.info(x))\n",
+        "  filtered_results.main | \"Print Main\" >> beam.Map(lambda x: logging.info(x))\n",
+        "  filtered_results.error | \"Print Errors\" >> beam.Map(lambda x: logging.info(x))"
+      ]
+    }
+  ],
+  "metadata": {
+    "colab": {
+      "provenance": []
+    },
+    "environment": {
+      "kernel": "apache-beam-2.57.0",
+      "name": ".m121",
+      "type": "gcloud",
+      "uri": "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/:m121"
+    },
+    "kernelspec": {
+      "display_name": "Apache Beam 2.57.0 (Local)",
+      "language": "python",
+      "name": "apache-beam-2.57.0"
+    },
+    "language_info": {
+      "codemirror_mode": {
+        "name": "ipython",
+        "version": 3
+      },
+      "file_extension": ".py",
+      "mimetype": "text/x-python",
+      "name": "python",
+      "nbconvert_exporter": "python",
+      "pygments_lexer": "ipython3",
+      "version": "3.10.14"
+    }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 5
+}
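
The notebook's `extract_model_reply` helper and `SentimentAnalysis` `DoFn` turn free-form Gemma output into a sentiment score and a summary. The following is a minimal, standard-library-only sketch of that parse-and-validate step. It assumes that a reply is usable only when it contains a JSON object with `sentiment` and `summary` fields and a score of `-1`, `0`, or `1`. The function name `parse_model_reply` is hypothetical and is not part of the notebook or of Apache Beam.

```python
import json
import re

# Non-greedy match for the first {...} span, the same pattern that the
# notebook's extract_model_reply helper uses.
_JSON_OBJECT = re.compile(r"(\{[\s\S]*?\})")
_REQUIRED_FIELDS = ("sentiment", "summary")
_VALID_SCORES = {-1, 0, 1}


def parse_model_reply(text: str) -> dict:
    """Hypothetical helper: returns {"sentiment": int, "summary": str}.

    Raises ValueError if no JSON object is found, if the object is not valid
    JSON, or if a field is missing or the score is out of range.
    """
    match = _JSON_OBJECT.search(text)
    if match is None:
        raise ValueError("no JSON object in model reply")
    # json.JSONDecodeError is a subclass of ValueError, so bad JSON also raises ValueError.
    result = json.loads(match.group(1))
    if not all(field in result for field in _REQUIRED_FIELDS):
        raise ValueError("model reply is missing sentiment or summary")
    try:
        # Accept scores that the model returns as strings, such as "-1".
        score = int(result["sentiment"])
    except (TypeError, ValueError) as err:
        raise ValueError(f"non-numeric sentiment: {result['sentiment']!r}") from err
    if score not in _VALID_SCORES:
        raise ValueError(f"sentiment out of range: {score}")
    return {"sentiment": score, "summary": str(result["summary"])}


if __name__ == "__main__":
    reply = 'Sure! {"sentiment": "-1", "summary": "Customer dislikes pineapple pizza."}'
    print(parse_model_reply(reply))
```

Coercing the score with `int()` means that a reply such as `"sentiment": "-1"` is still treated as negative. Like the notebook's regular expression, the non-greedy match stops at the first `}`, so a summary that itself contains `}` is rejected as invalid JSON rather than silently truncated.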

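The `@@@` delimiters in the notebook's prompt template let the `SentimentAnalysis` `DoFn` recover the original chat text from the prompt. As the notebook's code comment notes, the text returned by the model contains the prompt, which is removed before the reply is parsed. The following sketch reproduces those two string operations in plain Python so that they can be checked without loading a model. `PROMPT_TEMPLATE`, `create_prompt`, and `split_generation` are hypothetical names, and the template is a shortened stand-in for the notebook's `prompt_template`.

```python
import json
import re

# Hypothetical, shortened stand-in for the notebook's prompt_template. It keeps
# the same two placeholders (user ID, then chat text) and the @@@ delimiters.
PROMPT_TEMPLATE = (
    "<prompt>\n"
    "Rate the tone (1, 0, or -1) and summarize the chat for user {}.\n"
    "@@@{}@@@\n"
    "<answer>\n"
)
# Same non-greedy pattern that the notebook uses to find the delimited chat text.
_DELIMITED_CHAT = re.compile(r"@@@([\s\S]*?)@@@")


def create_prompt(payload: str) -> tuple:
    """Like the notebook's CreatePrompt DoFn: JSON payload -> (chat_id, prompt)."""
    chat = json.loads(payload)
    return chat["id"], PROMPT_TEMPLATE.format(chat["user_id"], chat["chat_message"])


def split_generation(prompt: str, generated: str) -> tuple:
    """Returns (chat_text, model_reply) from a prompt and the generated text.

    Assumes that the generated text contains the prompt; the first occurrence
    of the prompt is removed so that only the model's answer remains.
    """
    match = _DELIMITED_CHAT.search(prompt)
    if match is None:
        raise ValueError("prompt has no @@@-delimited chat text")
    return match.group(1), generated.replace(prompt, "", 1)


if __name__ == "__main__":
    payload = json.dumps({"id": 42, "user_id": 221, "chat_message": "I want a refund!"})
    chat_id, prompt = create_prompt(payload)
    # Simulated model output: the echoed prompt followed by the model's answer.
    chat_text, reply = split_generation(prompt, prompt + '{"sentiment": -1, "summary": "Refund request."}')
    print(chat_id, chat_text, reply)
```

Because the chat text is passed as a `format()` argument rather than being part of the template, braces inside a customer's message don't break the prompt construction.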