This is an automated email from the ASF dual-hosted git repository.
jrmccluskey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new dbd719ba144 [WIP] Gemma Sentiment and Summarization Example Notebook (#32172)
dbd719ba144 is described below
commit dbd719ba1448c13bc97237d1b701e2978c0e29d4
Author: Jack McCluskey <[email protected]>
AuthorDate: Wed Aug 14 09:50:17 2024 -0400
[WIP] Gemma Sentiment and Summarization Example Notebook (#32172)
---
.../gemma_2_sentiment_and_summarization.ipynb | 625 +++++++++++++++++++++
1 file changed, 625 insertions(+)
diff --git a/examples/notebooks/beam-ml/gemma_2_sentiment_and_summarization.ipynb b/examples/notebooks/beam-ml/gemma_2_sentiment_and_summarization.ipynb
new file mode 100644
index 00000000000..b45d9d7aea9
--- /dev/null
+++ b/examples/notebooks/beam-ml/gemma_2_sentiment_and_summarization.ipynb
@@ -0,0 +1,625 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "BrKf6TQ98qIJ",
+ "metadata": {
+ "id": "BrKf6TQ98qIJ"
+ },
+ "outputs": [],
+ "source": [
+ "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n",
+ "\n",
+ "# Licensed to the Apache Software Foundation (ASF) under one\n",
+ "# or more contributor license agreements. See the NOTICE file\n",
+ "# distributed with this work for additional information\n",
+ "# regarding copyright ownership. The ASF licenses this file\n",
+ "# to you under the Apache License, Version 2.0 (the\n",
+ "# \"License\"); you may not use this file except in compliance\n",
+ "# with the License. You may obtain a copy of the License at\n",
+ "#\n",
+ "# http://www.apache.org/licenses/LICENSE-2.0\n",
+ "#\n",
+ "# Unless required by applicable law or agreed to in writing,\n",
+ "# software distributed under the License is distributed on an\n",
+ "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
+ "# KIND, either express or implied. See the License for the\n",
+ "# specific language governing permissions and limitations\n",
+ "# under the License."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "hHg4SoUr8qIK",
+ "metadata": {
+ "id": "hHg4SoUr8qIK"
+ },
+ "source": [
+ "# Use Gemma to gauge sentiment and summarize conversations\n",
+ "\n",
+ "<table align=\"left\">\n",
+ " <td>\n",
+ " <a target=\"_blank\" href=\"https://colab.research.google.com/github/apache/beam/blob/master/examples/notebooks/beam-ml/gemma_2_sentiment_and_summarization.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/colab_32px.png\" />Run in Google Colab</a>\n",
+ " </td>\n",
+ " <td>\n",
+ " <a target=\"_blank\" href=\"https://github.com/apache/beam/blob/master/examples/notebooks/beam-ml/gemma_2_sentiment_and_summarization.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/github_32px.png\" />View source on GitHub</a>\n",
+ " </td>\n",
+ "</table>"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "yOs5SCyPdYNi",
+ "metadata": {
+ "id": "yOs5SCyPdYNi"
+ },
+ "source": [
+ "Gemma is a family of lightweight, state-of-the-art open models built from research and technology used to create the Gemini models. You can use Gemma models in your Apache Beam inference pipelines.\n",
+ "\n",
+ "Because large language models (LLMs) like Gemma are versatile, you can integrate them into business processes. The example in this notebook demonstrates how to use Gemma to gauge the sentiment of a conversation, summarize that conversation's content, and draft a reply for a difficult conversation. The system allows a person to review the reply before it's sent to customers. For more information, see the blog post [Gemma for Streaming ML with Dataflow](https://developers.googlebl [...]
+ "\n",
+ "A requirement of this work is that customers who express a negative sentiment receive a reply in near real time. As a result, the workflow needs to use a streaming data pipeline with an LLM that has minimal latency.\n",
+ "\n",
+ "## Use case\n",
+ "\n",
+ "An example use case is a bustling food chain grappling with analyzing and storing a high volume of customer support requests. Customer interactions include both chats generated by automated chatbots and nuanced conversations that require the attention of live support staff.\n",
+ "\n",
+ "### Requirements\n",
+ "\n",
+ "To address both types of interactions, the workflow has the following requirements:\n",
+ "\n",
+ "- It needs to efficiently manage and store chat data by summarizing positive interactions for easy reference and future analysis.\n",
+ "\n",
+ "- It must detect and resolve issues in real time.\n",
+ "\n",
+ "- Sentiment analysis must identify dissatisfied customers and generate tailored responses to address their concerns.\n",
+ "\n",
+ "### Workflow\n",
+ "\n",
+ "To meet these requirements, the pipeline processes completed chat messages in near real time. First, the pipeline uses Gemma to monitor the sentiment of the customer chats. All chats are then summarized, with positive or neutral sentiment chats sent directly to a data platform, BigQuery, by using the available Dataflow I/Os.\n",
+ "\n",
+ "For chats that have a negative sentiment, the Gemma model crafts a contextually appropriate response for the customer. This response is sent to a human for review so that they can refine the message before it reaches the customer.\n",
+ "\n",
+ "This example addresses important complexities inherent in using an LLM within a pipeline. For example, processing the responses in code is challenging because the generated text is non-deterministic. In this example, the workflow requires the LLM to generate JSON responses, which is not the default format. The workflow needs to parse and validate each response, a process similar to processing data from sources that don't always provide correctly structured data.\n",
+ "\n",
+ "This workflow allows businesses to respond to customers faster and to provide personalized responses when needed.\n",
+ "\n",
+ "- The automation of positive chat summarization allows support staff to focus on more complex interactions.\n",
+ "- The scalability of the system makes it possible to adapt to increasing chat volumes without compromising response quality.\n",
+ "\n",
+ "You can also analyze the stored chat data in depth to support data-driven decision-making."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "tGZfleinj3xM",
+ "metadata": {
+ "id": "tGZfleinj3xM"
+ },
+ "source": [
+ "## The data processing pipeline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "G-VpT7-FjcSu",
+ "metadata": {
+ "id": "G-VpT7-FjcSu"
+ },
+ "source": [
+ ". To facilitate rapid testing of the pipeline, follow the guide [Run a pipeline with GPU [...]
+ "\n",
+ "After you configure your environment, download the model [gemma2_instruct_2b_en](https://www.kaggle.com/models/google/gemma-2/keras) into a folder. In this example, the folder is named `gemma2`."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "jMrjYGW9spFG",
+ "metadata": {
+ "id": "jMrjYGW9spFG"
+ },
+ "source": [
+ "### Build the base image\n",
+ "\n",
+ "Add the following Dockerfile to your folder, and then build the base image. Use this base image while you develop and test the `pipeline.py` file. The build is split into two images, this base image and a pipeline image that adds the pipeline code, to facilitate testing and development."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "29dOdG_kxzTv",
+ "metadata": {
+ "id": "29dOdG_kxzTv"
+ },
+ "outputs": [],
+ "source": [
+ "ARG SERVING_BUILD_IMAGE=tensorflow/tensorflow:2.16.1-gpu\n",
+ "\n",
+ "FROM ${SERVING_BUILD_IMAGE}\n",
+ "WORKDIR /workspace\n",
+ "\n",
+ "COPY gemma2 gemma2\n",
+ "RUN apt-get update -y && apt-get install -y cmake && apt-get install -y vim"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3eWt8AatpEuG",
+ "metadata": {
+ "id": "3eWt8AatpEuG"
+ },
+ "source": [
+ "When you test the pipeline code and when you launch the job on Dataflow, run the commands from inside the container. Doing so prevents dependency mismatches when the pipeline runs on Dataflow."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "lyS0uYpsoeOW",
+ "metadata": {
+ "id": "lyS0uYpsoeOW"
+ },
+ "source": [
+ "The `requirements.txt` file contains the following dependencies: "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "K4gJQ0e9pCR4",
+ "metadata": {
+ "id": "K4gJQ0e9pCR4"
+ },
+ "outputs": [],
+ "source": [
+ "apache_beam[gcp]==2.54.0\n",
+ "keras_nlp==0.14.3\n",
+ "keras==3.4.1\n",
+ "jax[cuda12]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "k9gGtkqvn6Ar",
+ "metadata": {
+ "id": "k9gGtkqvn6Ar"
+ },
+ "source": [
+ "The next Dockerfile adds the files needed to construct the pipeline. The contents of the `pipeline.py` file are provided in a later section of this notebook."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "aqPS_p3Pp37b",
+ "metadata": {
+ "id": "aqPS_p3Pp37b"
+ },
+ "source": [
+ "Replace `<DOCKERFILE_IMAGE>` with the base image that you built by using the first Dockerfile."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "EdUxmUjqx58r",
+ "metadata": {
+ "id": "EdUxmUjqx58r"
+ },
+ "outputs": [],
+ "source": [
+ "FROM <DOCKERFILE_IMAGE>\n",
+ "\n",
+ "WORKDIR /workspace\n",
+ "\n",
+ "# Copy files from the official SDK image, including the script and dependencies.\n",
+ "COPY --from=apache/beam_python3.11_sdk:2.54.0 /opt/apache/beam /opt/apache/beam\n",
+ "\n",
+ "\n",
+ "COPY requirements.txt requirements.txt\n",
+ "RUN pip install --upgrade --no-cache-dir pip \\\n",
+ " && pip install --no-cache-dir -r requirements.txt\n",
+ "\n",
+ "# Copy the pipeline code. The model directory downloaded from Kaggle is already in the base image.\n",
+ "COPY pipeline.py pipeline.py\n",
+ "\n",
+ "# This notebook was tested with the JAX backend. Set the Keras backend in the\n",
+ "# image environment so that Dataflow workers pick it up.\n",
+ "ENV KERAS_BACKEND=\"jax\"\n",
+ "ENV XLA_PYTHON_CLIENT_MEM_FRACTION=\"0.9\"\n",
+ "\n",
+ "\n",
+ "# Set the entrypoint to the Apache Beam SDK launcher.\n",
+ "ENTRYPOINT [\"/opt/apache/beam/boot\"]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "i63FUxXwsSsO",
+ "metadata": {
+ "id": "i63FUxXwsSsO"
+ },
+ "source": [
+ "### Run the pipeline\n",
+ "\n",
+ "The following code creates and runs the pipeline.\n",
+ "\n",
+ "- The `pip install` steps are needed to run the code in the notebook, but aren't needed when running the code in your container.\n",
+ "\n",
+ "- Without a GPU, the inference takes a long time to complete."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ebb44382-ee7b-4cec-af67-1fe220cfb40d",
+ "metadata": {
+ "id": "ebb44382-ee7b-4cec-af67-1fe220cfb40d",
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "%pip install \"apache_beam[gcp]==2.54.0\" \"keras_nlp==0.14.3\" \"keras>=3\" \"jax[cuda12]\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "oPgRBScKThZg",
+ "metadata": {
+ "id": "oPgRBScKThZg"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "os.environ[\"KERAS_BACKEND\"] = \"jax\"  # Or \"tensorflow\" or \"torch\".\n",
+ "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = \"0.9\"\n",
+ "\n",
+ "import keras\n",
+ "import keras_nlp\n",
+ "import numpy as np\n",
+ "import json\n",
+ "import ast\n",
+ "import re\n",
+ "import logging\n",
+ "\n",
+ "import apache_beam as beam\n",
+ "from apache_beam.ml.inference import utils\n",
+ "from apache_beam.ml.inference.base import RunInference\n",
+ "from apache_beam.options import pipeline_options\n",
+ "from apache_beam.options.pipeline_options import GoogleCloudOptions\n",
+ "from apache_beam.options.pipeline_options import PipelineOptions\n",
+ "from apache_beam.options.pipeline_options import SetupOptions\n",
+ "from apache_beam.options.pipeline_options import StandardOptions\n",
+ "from apache_beam.options.pipeline_options import WorkerOptions\n",
+ "from apache_beam.ml.inference.base import ModelHandler\n",
+ "from apache_beam.ml.inference.base import PredictionResult\n",
+ "from apache_beam.ml.inference.base import KeyedModelHandler\n",
+ "from keras_nlp.models import GemmaCausalLM\n",
+ "from typing import Any, Dict, Iterable, Optional, Sequence"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0gicDesYWdbu",
+ "metadata": {
+ "id": "0gicDesYWdbu"
+ },
+ "source": [
+ "Set pipeline options and provide the input Pub/Sub topic. The options that are commented out enable running the pipeline on Google Cloud Dataflow."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "wpG-ltLiTxWM",
+ "metadata": {
+ "id": "wpG-ltLiTxWM"
+ },
+ "outputs": [],
+ "source": [
+ "options = PipelineOptions()\n",
+ "options.view_as(StandardOptions).streaming = True\n",
+ "options.view_as(SetupOptions).save_main_session = True\n",
+ "\n",
+ "# options.view_as(StandardOptions).runner = \"DataflowRunner\"\n",
+ "# options.view_as(GoogleCloudOptions).project = <PROJECT>\n",
+ "# options.view_as(GoogleCloudOptions).temp_location = <TMP LOCATION>\n",
+ "# options.view_as(GoogleCloudOptions).region = \"us-west1\"\n",
+ "# options.view_as(WorkerOptions).machine_type = \"g2-standard-4\"\n",
+ "# options.view_as(WorkerOptions).worker_harness_container_image = <IMAGE YOU BUILT>\n",
+ "# options.view_as(WorkerOptions).disk_size_gb = 200\n",
+ "# options.view_as(GoogleCloudOptions).dataflow_service_options = [\"worker_accelerator=type:nvidia-l4;count:1;install-nvidia-driver\"]\n",
+ "\n",
+ "# The full topic path, for example projects/<PROJECT>/topics/<TOPIC>.\n",
+ "topic_reviews = \"<PubSub Topic>\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "g8sWSMRmW-Ab",
+ "metadata": {
+ "id": "g8sWSMRmW-Ab"
+ },
+ "source": [
+ "Define a custom model handler that loads the Gemma model and handles inference calls."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "vRVCN3qBUAt9",
+ "metadata": {
+ "id": "vRVCN3qBUAt9"
+ },
+ "outputs": [],
+ "source": [
+ "class GemmaModelHandler(ModelHandler[str, PredictionResult, GemmaCausalLM]):\n",
+ "    def __init__(\n",
+ "        self,\n",
+ "        model_name: str = \"\",\n",
+ "    ):\n",
+ "        \"\"\"Implementation of the ModelHandler interface for Gemma using text as input.\n",
+ "\n",
+ "        Example Usage::\n",
+ "\n",
+ "          pcoll | RunInference(GemmaModelHandler(\"gemma2\"))\n",
+ "\n",
+ "        Args:\n",
+ "          model_name: The Gemma model URI, or the path to a local model directory.\n",
+ "        \"\"\"\n",
+ "        self._model_name = model_name\n",
+ "        self._env_vars = {}\n",
+ "\n",
+ "    def share_model_across_processes(self) -> bool:\n",
+ "        \"\"\"Returns whether to share a single model in memory across processes.\n",
+ "\n",
+ "        This is useful when the loaded model is large, preventing potential\n",
+ "        out-of-memory issues when running the pipeline.\n",
+ "\n",
+ "        Returns:\n",
+ "          bool\n",
+ "        \"\"\"\n",
+ "        return True\n",
+ "\n",
+ "    def load_model(self) -> GemmaCausalLM:\n",
+ "        \"\"\"Loads and initializes a model for processing.\"\"\"\n",
+ "        return keras_nlp.models.GemmaCausalLM.from_preset(self._model_name)\n",
+ "\n",
+ "    def run_inference(\n",
+ "        self,\n",
+ "        batch: Sequence[str],\n",
+ "        model: GemmaCausalLM,\n",
+ "        inference_args: Optional[Dict[str, Any]] = None\n",
+ "    ) -> Iterable[PredictionResult]:\n",
+ "        \"\"\"Runs inferences on a batch of text strings.\n",
+ "\n",
+ "        Args:\n",
+ "          batch: A sequence of examples as text strings.\n",
+ "          model: The GemmaCausalLM model used to generate text.\n",
+ "          inference_args: Any additional arguments for an inference.\n",
+ "\n",
+ "        Returns:\n",
+ "          An Iterable of type PredictionResult.\n",
+ "        \"\"\"\n",
+ "        # Generate a response for each text string in the batch and collect\n",
+ "        # the results in a list.\n",
+ "        predictions = []\n",
+ "        for one_text in batch:\n",
+ "            result = model.generate(one_text, max_length=1024)\n",
+ "            predictions.append(result)\n",
+ "        return utils._convert_to_result(batch, predictions, self._model_name)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cSbAFPXmXPMc",
+ "metadata": {
+ "id": "cSbAFPXmXPMc"
+ },
+ "source": [
+ "Define a prompt template that formats a given input and instructs the model on the task. This cell also defines an example input for the model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "hqh-Ro5-UNqy",
+ "metadata": {
+ "id": "hqh-Ro5-UNqy"
+ },
+ "outputs": [],
+ "source": [
+ "prompt_template = \"\"\"\n",
+ "<prompt>\n",
+ "Provide the results of doing these two tasks on the chat history provided below for the user {}\n",
+ "task 1 : assess if the tone is happy = 1 , neutral = 0 or unhappy = -1\n",
+ "task 2 : summarize the text with a maximum of 512 characters\n",
+ "Return the answer as a JSON string with fields [sentiment, summary] do NOT explain your answer\n",
+ "\n",
+ "@@@{}@@@\n",
+ "<answer>\n",
+ "\"\"\"\n",
+ "chat_text = \"\"\"\n",
+ "id 221: Hay I am really annoyed that your menu includes a pizza with pineapple on it!\n",
+ "id 331: Sorry to hear that , but pineapple is nice on pizza\n",
+ "id 221: What a terriable thing to say! Its never ok, so unhappy right now!\n",
+ "\"\"\"\n",
+ "\n",
+ "# Example input\n",
+ "chat = json.dumps({\"id\" : 42, \"user_id\" : 221 , \"chat_message\" : chat_text})\n",
+ "print(chat)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8UFfKvSeYn0b",
+ "metadata": {
+ "id": "8UFfKvSeYn0b"
+ },
+ "source": [
+ "Define the pre- and post-processing functions. `CreatePrompt` creates a key-value pair of the chat ID and the formatted prompt. `extract_model_reply` parses the response and extracts the JSON string requested from the model. However, the LLM is not *guaranteed* to return a JSON-formatted object, so the function also raises an exception if the reply is malformed. This helper is then used in the `SentimentAnalysis` `DoFn` to split out the sentiment score as well as the summary of the text. The [...]
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "dDIys7XaUPvl",
+ "metadata": {
+ "id": "dDIys7XaUPvl"
+ },
+ "outputs": [],
+ "source": [
+ "# The model path matches the `gemma2` folder that the Dockerfile copies into the image.\n",
+ "keyed_model_handler = KeyedModelHandler(GemmaModelHandler('gemma2'))\n",
+ "\n",
+ "# Create the prompt by using the information from the chat.\n",
+ "class CreatePrompt(beam.DoFn):\n",
+ "    def process(self, element, *args, **kwargs):\n",
+ "        user_chat = json.loads(element)\n",
+ "        chat_id = user_chat['id']\n",
+ "        user_id = user_chat['user_id']\n",
+ "        messages = user_chat['chat_message']\n",
+ "        yield (chat_id, prompt_template.format(user_id, messages))\n",
+ "\n",
+ "# Extract the JSON object from the model reply and check that it has the expected fields.\n",
+ "def extract_model_reply(model_inference):\n",
+ "    print(model_inference)\n",
+ "    match = re.search(r\"(\\{[\\s\\S]*?\\})\", model_inference)\n",
+ "    if match is None:\n",
+ "        raise ValueError('No JSON object found in model reply')\n",
+ "    json_str = match.group(1)\n",
+ "    print(json_str)\n",
+ "    result = json.loads(json_str)\n",
+ "    if all(key in result for key in ['sentiment', 'summary']):\n",
+ "        return result\n",
+ "    raise ValueError('Malformed model reply')\n",
+ "\n",
+ "# Split the sentiment score and the summary out of the model reply, and route\n",
+ "# each chat to the 'main', 'negative', or 'error' output.\n",
+ "class SentimentAnalysis(beam.DoFn):\n",
+ "    def process(self, element):\n",
+ "        key = element[0]\n",
+ "        try:\n",
+ "            # Recover the original chat text that is wrapped in @@@ markers in the prompt.\n",
+ "            match = re.search(r\"@@@([\\s\\S]*?)@@@\", element[1].example)\n",
+ "            chats = match.group(1)\n",
+ "\n",
+ "            # The result contains the prompt. Replace the prompt with \"\".\n",
+ "            result = extract_model_reply(element[1].inference.replace(element[1].example, \"\"))\n",
+ "            processed_result = (key, chats, result['sentiment'], result['summary'])\n",
+ "\n",
+ "            if int(result['sentiment']) == -1:\n",
+ "                output = beam.TaggedOutput('negative', processed_result)\n",
+ "            else:\n",
+ "                output = beam.TaggedOutput('main', processed_result)\n",
+ "\n",
+ "        except Exception as err:\n",
+ "            print(\"ERROR!\" + str(err))\n",
+ "            output = beam.TaggedOutput('error', element)\n",
+ "\n",
+ "        yield output\n",
+ "\n",
+ "gemma_inference = RunInference(keyed_model_handler)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "Yj9aQ0q8YLOn",
+ "metadata": {
+ "id": "Yj9aQ0q8YLOn"
+ },
+ "source": [
+ "Finally, execute the pipeline using the code below. To use the example chat input created earlier instead of a custom Pub/Sub source, use `chats = p | beam.Create([chat])` instead of the Pub/Sub read."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1fb47a17-9563-46f6-9768-73f4802694e8",
+ "metadata": {
+ "id": "1fb47a17-9563-46f6-9768-73f4802694e8",
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "with beam.Pipeline(options=options) as p:\n",
+ "    chats = (p | \"Read Topic\" >>\n",
+ "             beam.io.ReadFromPubSub(topic=topic_reviews)\n",
+ "             | \"Parse\" >> beam.Map(lambda x: x.decode(\"utf-8\")))\n",
+ "    prompts = chats | \"Create Prompt\" >> beam.ParDo(CreatePrompt())\n",
+ "    results = prompts | \"RunInference-Gemma\" >> gemma_inference\n",
+ "    filtered_results = results | \"Process Results\" >> beam.ParDo(SentimentAnalysis()).with_outputs('main', 'negative', 'error')\n",
+ "    generated_responses = (\n",
+ "        filtered_results.negative\n",
+ "        | \"Generate Response\" >> beam.Map(lambda x: ((x[0], x[3]), \"<prompt>Generate an apology response for the user in this chat text: \" + x[1] + \"<answer>\"))\n",
+ "        | \"Gemma-Response\" >> gemma_inference\n",
+ "    )\n",
+ "\n",
+ "    generated_responses | \"Print Response\" >> beam.Map(lambda x: logging.info(x))\n",
+ "    filtered_results.main | \"Print Main\" >> beam.Map(lambda x: logging.info(x))\n",
+ "    filtered_results.error | \"Print Errors\" >> beam.Map(lambda x: logging.info(x))"
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "provenance": []
+ },
+ "environment": {
+ "kernel": "apache-beam-2.57.0",
+ "name": ".m121",
+ "type": "gcloud",
+ "uri": "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/:m121"
+ },
+ "kernelspec": {
+ "display_name": "Apache Beam 2.57.0 (Local)",
+ "language": "python",
+ "name": "apache-beam-2.57.0"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.14"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
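The notebook's `extract_model_reply` takes the first `{...}` span in the Gemma reply and raises an exception if that span cannot be parsed. Because the model is not guaranteed to return well-formed JSON, the following is a minimal, standard-library-only sketch of a stricter variant. `parse_model_reply` and `route` are hypothetical helpers that are not part of the notebook or of Apache Beam: the parser skips brace-delimited spans that are not valid JSON, requires both the `sentiment` and `summary` fields, and accepts only the sentiment values -1, 0, and 1 that the prompt asks for. Like the notebook's regular expression, it assumes the reply contains a flat (non-nested) JSON object.

```python
import json
import re
from typing import Any, Dict

# Non-greedy match of a brace-delimited span. Like the notebook's pattern,
# this assumes the JSON object in the reply is flat (no nested braces).
_JSON_OBJECT = re.compile(r"\{[\s\S]*?\}")
_ALLOWED_SENTIMENTS = {"-1", "0", "1"}


def parse_model_reply(model_output: str) -> Dict[str, Any]:
    """Hypothetical stricter variant of extract_model_reply.

    Returns the first JSON object in model_output that has both a 'sentiment'
    and a 'summary' field, with 'sentiment' converted to an int in {-1, 0, 1}.
    Raises ValueError if no such object is found.
    """
    for match in _JSON_OBJECT.finditer(model_output):
        try:
            candidate = json.loads(match.group(0))
        except json.JSONDecodeError:
            continue  # Not valid JSON; try the next brace-delimited span.
        if not isinstance(candidate, dict):
            continue
        if "sentiment" not in candidate or "summary" not in candidate:
            continue
        # Accept -1, 0, or 1 given either as an integer or as a string.
        raw_sentiment = str(candidate["sentiment"]).strip()
        if raw_sentiment not in _ALLOWED_SENTIMENTS:
            continue
        candidate["sentiment"] = int(raw_sentiment)
        return candidate
    raise ValueError("Malformed model reply: no valid sentiment/summary JSON found")


def route(sentiment: int) -> str:
    """Hypothetical helper mirroring the notebook: -1 goes to 'negative'."""
    return "negative" if sentiment == -1 else "main"


if __name__ == "__main__":
    reply = 'Sure! {"sentiment": -1, "summary": "User 221 dislikes pineapple pizza."}'
    result = parse_model_reply(reply)
    print(route(result["sentiment"]), result["summary"])
    # prints: negative User 221 dislikes pineapple pizza.
```

Inside `SentimentAnalysis.process`, a helper like this could stand in for `extract_model_reply`, with `route(...)` supplying the tag passed to `beam.TaggedOutput` and any `ValueError` still sending the element to the `'error'` output.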