ferruzzi commented on code in PR #39500:
URL: https://github.com/apache/airflow/pull/39500#discussion_r1596061444


##########
airflow/providers/amazon/aws/operators/bedrock.py:
##########
@@ -664,3 +669,198 @@ def execute(self, context: Context) -> str:
             )
 
         return ingestion_job_id
+
+
+class BedrockRaGOperator(AwsBaseOperator[BedrockAgentRuntimeHook]):
+    """
+    Query a knowledge base and generate responses based on the retrieved 
results with sources citations.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the 
guide:
+        :ref:`howto/operator:BedrockRaGOperator`
+
+    :param input: The query to be made to the knowledge base. (templated)
+    :param source_type: The type of resource that is queried by the request. 
(templated)
+        Must be one of 'KNOWLEDGE_BASE' or 'EXTERNAL_SOURCES', and the 
appropriate config values must also be provided.
+        If set to 'KNOWLEDGE_BASE' then `knowledge_base_id` must be provided, 
and `vector_search_config` may be.
+        If set to `EXTERNAL_SOURCES` then `sources` must also be provided.
+    :param model_arn: The ARN of the foundation model used to generate a 
response. (templated)
+    :param prompt_template: The template for the prompt that's sent to the 
model for response generation.
+        You can include prompt placeholders, which become replaced before the 
prompt is sent to the model
+        to provide instructions and context to the model. In addition, you can 
include XML tags to delineate
+        meaningful sections of the prompt template. (templated)
+    :param knowledge_base_id: The unique identifier of the knowledge base that 
is queried. (templated)
+            Can only be specified if source_type='KNOWLEDGE_BASE'.
+    :param vector_search_config: How the results from the vector search should 
be returned. (templated)
+        Can only be specified if source_type='KNOWLEDGE_BASE'.
+        For more information, see 
https://docs.aws.amazon.com/bedrock/latest/userguide/kb-test-config.html.
+    :param sources: The documents used as reference for the response. 
(templated)
+        Can only be specified if source_type='EXTERNAL_SOURCES'
+    :param rag_kwargs: Additional keyword arguments to pass to the  API call. 
(templated)
+    """
+
+    aws_hook_class = BedrockAgentRuntimeHook
+    template_fields: Sequence[str] = aws_template_fields(
+        "input",
+        "source_type",
+        "model_arn",
+        "prompt_template",
+        "knowledge_base_id",
+        "vector_search_config",
+        "sources",
+        "rag_kwargs",
+    )
+
+    def __init__(
+        self,
+        input: str,
+        source_type: str,
+        model_arn: str,
+        prompt_template: str | None = None,
+        knowledge_base_id: str | None = None,
+        vector_search_config: dict[str, Any] | None = None,
+        sources: list[dict[str, Any]] | None = None,
+        rag_kwargs: dict[str, Any] | None = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.input = input
+        self.prompt_template = prompt_template
+        self.source_type = source_type.upper()
+        self.knowledge_base_id = knowledge_base_id
+        self.model_arn = model_arn
+        self.vector_search_config = vector_search_config
+        self.sources = sources
+        self.rag_kwargs = rag_kwargs or {}
+
+    def validate_inputs(self):
+        if self.source_type == "KNOWLEDGE_BASE":
+            if self.knowledge_base_id is None:
+                raise AttributeError(
+                    "If `source_type` is set to 'KNOWLEDGE_BASE' then 
`knowledge_base_id` must be provided."
+                )
+            if self.sources is not None:
+                raise AttributeError(
+                    "`sources` can not be used when `source_type` is set to 
'KNOWLEDGE_BASE'."
+                )
+        elif self.source_type == "EXTERNAL_SOURCES":
+            if not self.sources is not None:
+                raise AttributeError(
+                    "If `source_type` is set to `EXTERNAL_SOURCES` then 
`sources` must also be provided."
+                )
+            if self.vector_search_config or self.knowledge_base_id:
+                raise AttributeError(
+                    "`vector_search_config` and `knowledge_base_id` can not be 
used "
+                    "when `source_type` is set to `EXTERNAL_SOURCES`"
+                )
+        else:
+            raise AttributeError(
+                "`source_type` must be one of 'KNOWLEDGE_BASE' or 
'EXTERNAL_SOURCES', "
+                "and the appropriate config values must also be provided."
+            )
+
+    def build_rag_config(self) -> dict[str, Any]:
+        result: dict[str, Any] = {}
+        base_config: dict[str, Any] = {
+            "modelArn": self.model_arn,
+        }
+
+        if self.prompt_template:
+            base_config["generationConfiguration"] = {
+                "promptTemplate": {"textPromptTemplate": self.prompt_template}
+            }
+
+        if self.source_type == "KNOWLEDGE_BASE":
+            if self.vector_search_config:
+                base_config["retrievalConfiguration"] = {
+                    "vectorSearchConfiguration": self.vector_search_config
+                }
+
+            result = {
+                "type": self.source_type,
+                "knowledgeBaseConfiguration": {
+                    **base_config,
+                    "knowledgeBaseId": self.knowledge_base_id,
+                },
+            }
+
+        if self.source_type == "EXTERNAL_SOURCES":
+            result = {
+                "type": self.source_type,
+                "externalSourcesConfiguration": {**base_config, "sources": 
self.sources},
+            }
+        return result
+
+    def execute(self, context: Context) -> Any:
+        self.validate_inputs()
+
+        result = self.hook.conn.retrieve_and_generate(
+            input={"text": self.input},
+            retrieveAndGenerateConfiguration=self.build_rag_config(),
+            **self.rag_kwargs,
+        )
+
+        self.log.info(
+            "\nPrompt: %s\nResponse: %s\nCitations: %s",
+            self.input,
+            result["output"]["text"],
+            result["citations"],
+        )
+        return result
+
+
+class BedrockRetrieveOperator(AwsBaseOperator[BedrockAgentRuntimeHook]):
+    """
+    Query a knowledge base and retrieve results with source citations.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the 
guide:
+        :ref:`howto/operator:BedrockRetrieveOperator`
+
+    :param retrieval_query: The query to be made to the knowledge base. 
(templated)
+    :param knowledge_base_id: The unique identifier of the knowledge base that 
is queried. (templated)
+            Can only be specified if source_type='KNOWLEDGE_BASE'.

Review Comment:
   Yup, good eye.  The `Retrieve` API only supports knowledge bases, so it's 
not needed here.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to