o-nikolas commented on code in PR #39500:
URL: https://github.com/apache/airflow/pull/39500#discussion_r1596032939
##########
airflow/providers/amazon/aws/operators/bedrock.py:
##########
@@ -664,3 +669,198 @@ def execute(self, context: Context) -> str:
)
return ingestion_job_id
+
+
+class BedrockRaGOperator(AwsBaseOperator[BedrockAgentRuntimeHook]):
+ """
+ Query a knowledge base and generate responses based on the retrieved
results with sources citations.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:BedrockRaGOperator`
+
+ :param input: The query to be made to the knowledge base. (templated)
+ :param source_type: The type of resource that is queried by the request.
(templated)
+ Must be one of 'KNOWLEDGE_BASE' or 'EXTERNAL_SOURCES', and the
appropriate config values must also be provided.
+ If set to 'KNOWLEDGE_BASE' then `knowledge_base_id` must be provided,
and `vector_search_config` may be.
+ If set to `EXTERNAL_SOURCES` then `sources` must also be provided.
+ :param model_arn: The ARN of the foundation model used to generate a
response. (templated)
+ :param prompt_template: The template for the prompt that's sent to the
model for response generation.
+ You can include prompt placeholders, which become replaced before the
prompt is sent to the model
Review Comment:
```suggestion
You can include prompt placeholders, which are replaced before the
prompt is sent to the model
```
##########
docs/apache-airflow-providers-amazon/operators/bedrock.rst:
##########
@@ -174,6 +177,40 @@ To add data from an Amazon S3 bucket into an Amazon
Bedrock Data Source, you can
:start-after: [START howto_operator_bedrock_ingest_data]
:end-before: [END howto_operator_bedrock_ingest_data]
+.. _howto/operator:BedrockRetrieveOperator:
+
+Amazon Bedrock Retrieve
+=======================
+
+To query a knowledge base, you can use
:class:`~airflow.providers.amazon.aws.operators.bedrock.BedrockRaGOperator`.
Review Comment:
This should be the `BedrockRetrieveOperator`, right? Or maybe the one below?
Right now both sections reference the RAG operator. I also don't fully understand
the difference between the two operators. Maybe add a sentence or two to tell the
two apart? The only difference I see in the descriptions is that one says "includes"
and the other says "cites".
##########
airflow/providers/amazon/aws/operators/bedrock.py:
##########
@@ -664,3 +669,198 @@ def execute(self, context: Context) -> str:
)
return ingestion_job_id
+
+
+class BedrockRaGOperator(AwsBaseOperator[BedrockAgentRuntimeHook]):
+ """
+ Query a knowledge base and generate responses based on the retrieved
results with sources citations.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:BedrockRaGOperator`
+
+ :param input: The query to be made to the knowledge base. (templated)
+ :param source_type: The type of resource that is queried by the request.
(templated)
+ Must be one of 'KNOWLEDGE_BASE' or 'EXTERNAL_SOURCES', and the
appropriate config values must also be provided.
+ If set to 'KNOWLEDGE_BASE' then `knowledge_base_id` must be provided,
and `vector_search_config` may be.
+ If set to `EXTERNAL_SOURCES` then `sources` must also be provided.
+ :param model_arn: The ARN of the foundation model used to generate a
response. (templated)
+ :param prompt_template: The template for the prompt that's sent to the
model for response generation.
+ You can include prompt placeholders, which become replaced before the
prompt is sent to the model
+ to provide instructions and context to the model. In addition, you can
include XML tags to delineate
+ meaningful sections of the prompt template. (templated)
+ :param knowledge_base_id: The unique identifier of the knowledge base that
is queried. (templated)
+ Can only be specified if source_type='KNOWLEDGE_BASE'.
+ :param vector_search_config: How the results from the vector search should
be returned. (templated)
+ Can only be specified if source_type='KNOWLEDGE_BASE'.
+ For more information, see
https://docs.aws.amazon.com/bedrock/latest/userguide/kb-test-config.html.
+ :param sources: The documents used as reference for the response.
(templated)
+ Can only be specified if source_type='EXTERNAL_SOURCES'
+ :param rag_kwargs: Additional keyword arguments to pass to the API call.
(templated)
+ """
+
+ aws_hook_class = BedrockAgentRuntimeHook
+ template_fields: Sequence[str] = aws_template_fields(
+ "input",
+ "source_type",
+ "model_arn",
+ "prompt_template",
+ "knowledge_base_id",
+ "vector_search_config",
+ "sources",
+ "rag_kwargs",
+ )
+
+ def __init__(
+ self,
+ input: str,
+ source_type: str,
+ model_arn: str,
+ prompt_template: str | None = None,
+ knowledge_base_id: str | None = None,
+ vector_search_config: dict[str, Any] | None = None,
+ sources: list[dict[str, Any]] | None = None,
+ rag_kwargs: dict[str, Any] | None = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.input = input
+ self.prompt_template = prompt_template
+ self.source_type = source_type.upper()
+ self.knowledge_base_id = knowledge_base_id
+ self.model_arn = model_arn
+ self.vector_search_config = vector_search_config
+ self.sources = sources
+ self.rag_kwargs = rag_kwargs or {}
+
+ def validate_inputs(self):
+ if self.source_type == "KNOWLEDGE_BASE":
+ if self.knowledge_base_id is None:
+ raise AttributeError(
+ "If `source_type` is set to 'KNOWLEDGE_BASE' then
`knowledge_base_id` must be provided."
+ )
+ if self.sources is not None:
+ raise AttributeError(
+ "`sources` can not be used when `source_type` is set to
'KNOWLEDGE_BASE'."
+ )
+ elif self.source_type == "EXTERNAL_SOURCES":
+ if not self.sources is not None:
+ raise AttributeError(
+ "If `source_type` is set to `EXTERNAL_SOURCES` then
`sources` must also be provided."
+ )
+ if self.vector_search_config or self.knowledge_base_id:
+ raise AttributeError(
+ "`vector_search_config` and `knowledge_base_id` can not be
used "
+ "when `source_type` is set to `EXTERNAL_SOURCES`"
+ )
+ else:
+ raise AttributeError(
+ "`source_type` must be one of 'KNOWLEDGE_BASE' or
'EXTERNAL_SOURCES', "
+ "and the appropriate config values must also be provided."
+ )
+
+ def build_rag_config(self) -> dict[str, Any]:
+ result: dict[str, Any] = {}
+ base_config: dict[str, Any] = {
+ "modelArn": self.model_arn,
+ }
+
+ if self.prompt_template:
+ base_config["generationConfiguration"] = {
+ "promptTemplate": {"textPromptTemplate": self.prompt_template}
+ }
+
+ if self.source_type == "KNOWLEDGE_BASE":
+ if self.vector_search_config:
+ base_config["retrievalConfiguration"] = {
+ "vectorSearchConfiguration": self.vector_search_config
+ }
+
+ result = {
+ "type": self.source_type,
+ "knowledgeBaseConfiguration": {
+ **base_config,
+ "knowledgeBaseId": self.knowledge_base_id,
+ },
+ }
+
+ if self.source_type == "EXTERNAL_SOURCES":
+ result = {
+ "type": self.source_type,
+ "externalSourcesConfiguration": {**base_config, "sources":
self.sources},
+ }
+ return result
+
+ def execute(self, context: Context) -> Any:
+ self.validate_inputs()
+
+ result = self.hook.conn.retrieve_and_generate(
+ input={"text": self.input},
+ retrieveAndGenerateConfiguration=self.build_rag_config(),
+ **self.rag_kwargs,
+ )
+
+ self.log.info(
+ "\nPrompt: %s\nResponse: %s\nCitations: %s",
+ self.input,
+ result["output"]["text"],
+ result["citations"],
+ )
+ return result
+
+
+class BedrockRetrieveOperator(AwsBaseOperator[BedrockAgentRuntimeHook]):
+ """
+ Query a knowledge base and retrieve results with source citations.
Review Comment:
Does this one also retrieve citations? I thought only the RAG operator did
that.
##########
tests/providers/amazon/aws/operators/test_bedrock.py:
##########
@@ -346,3 +347,168 @@ def test_id_returned(self, mock_conn):
result = self.operator.execute({})
assert result == self.INGESTION_JOB_ID
+
+
+class TestBedrockRaGOperator:
+ VECTOR_SEARCH_CONFIG = {"filter": {"equals": {"key": "some key", "value":
"some value"}}}
+ KNOWLEDGE_BASE_ID = "knowledge_base_id"
+ SOURCES = [{"sourceType": "S3", "s3Location": "bucket"}]
+ MODEL_ARN = "model arn"
+
+ @pytest.mark.parametrize(
+ "source_type, vector_search_config, knowledge_base_id, sources,
expect_success",
+ [
+ pytest.param(
+ "invalid_source_type",
+ None,
+ None,
+ None,
+ False,
+ id="invalid_source_type",
+ ),
+ pytest.param(
+ "KNOWLEDGE_BASE",
+ VECTOR_SEARCH_CONFIG,
+ None,
+ None,
+ False,
+ id="KNOWLEDGE_BASE_without_knowledge_base_id_fails",
+ ),
+ pytest.param(
+ "KNOWLEDGE_BASE",
+ VECTOR_SEARCH_CONFIG,
+ KNOWLEDGE_BASE_ID,
+ SOURCES,
+ False,
+ id="KNOWLEDGE_BASE_with_sources_fails",
+ ),
+ pytest.param(
+ "KNOWLEDGE_BASE",
+ VECTOR_SEARCH_CONFIG,
+ KNOWLEDGE_BASE_ID,
+ None,
+ True,
+ id="KNOWLEDGE_BASE_passes",
+ ),
Review Comment:
Should there also be a test case with just a knowledge_base_id (without a
vector search config) that is expected to pass?
##########
airflow/providers/amazon/aws/operators/bedrock.py:
##########
@@ -664,3 +669,198 @@ def execute(self, context: Context) -> str:
)
return ingestion_job_id
+
+
+class BedrockRaGOperator(AwsBaseOperator[BedrockAgentRuntimeHook]):
+ """
+ Query a knowledge base and generate responses based on the retrieved
results with sources citations.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:BedrockRaGOperator`
+
+ :param input: The query to be made to the knowledge base. (templated)
+ :param source_type: The type of resource that is queried by the request.
(templated)
+ Must be one of 'KNOWLEDGE_BASE' or 'EXTERNAL_SOURCES', and the
appropriate config values must also be provided.
+ If set to 'KNOWLEDGE_BASE' then `knowledge_base_id` must be provided,
and `vector_search_config` may be.
+ If set to `EXTERNAL_SOURCES` then `sources` must also be provided.
+ :param model_arn: The ARN of the foundation model used to generate a
response. (templated)
+ :param prompt_template: The template for the prompt that's sent to the
model for response generation.
+ You can include prompt placeholders, which become replaced before the
prompt is sent to the model
+ to provide instructions and context to the model. In addition, you can
include XML tags to delineate
+ meaningful sections of the prompt template. (templated)
+ :param knowledge_base_id: The unique identifier of the knowledge base that
is queried. (templated)
+ Can only be specified if source_type='KNOWLEDGE_BASE'.
+ :param vector_search_config: How the results from the vector search should
be returned. (templated)
+ Can only be specified if source_type='KNOWLEDGE_BASE'.
+ For more information, see
https://docs.aws.amazon.com/bedrock/latest/userguide/kb-test-config.html.
+ :param sources: The documents used as reference for the response.
(templated)
+ Can only be specified if source_type='EXTERNAL_SOURCES'
+ :param rag_kwargs: Additional keyword arguments to pass to the API call.
(templated)
+ """
+
+ aws_hook_class = BedrockAgentRuntimeHook
+ template_fields: Sequence[str] = aws_template_fields(
+ "input",
+ "source_type",
+ "model_arn",
+ "prompt_template",
+ "knowledge_base_id",
+ "vector_search_config",
+ "sources",
+ "rag_kwargs",
+ )
+
+ def __init__(
+ self,
+ input: str,
+ source_type: str,
+ model_arn: str,
+ prompt_template: str | None = None,
+ knowledge_base_id: str | None = None,
+ vector_search_config: dict[str, Any] | None = None,
+ sources: list[dict[str, Any]] | None = None,
+ rag_kwargs: dict[str, Any] | None = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.input = input
+ self.prompt_template = prompt_template
+ self.source_type = source_type.upper()
+ self.knowledge_base_id = knowledge_base_id
+ self.model_arn = model_arn
+ self.vector_search_config = vector_search_config
+ self.sources = sources
+ self.rag_kwargs = rag_kwargs or {}
+
+ def validate_inputs(self):
+ if self.source_type == "KNOWLEDGE_BASE":
+ if self.knowledge_base_id is None:
+ raise AttributeError(
+ "If `source_type` is set to 'KNOWLEDGE_BASE' then
`knowledge_base_id` must be provided."
+ )
+ if self.sources is not None:
+ raise AttributeError(
+ "`sources` can not be used when `source_type` is set to
'KNOWLEDGE_BASE'."
+ )
+ elif self.source_type == "EXTERNAL_SOURCES":
+ if not self.sources is not None:
+ raise AttributeError(
+ "If `source_type` is set to `EXTERNAL_SOURCES` then
`sources` must also be provided."
+ )
+ if self.vector_search_config or self.knowledge_base_id:
+ raise AttributeError(
+ "`vector_search_config` and `knowledge_base_id` can not be
used "
+ "when `source_type` is set to `EXTERNAL_SOURCES`"
+ )
+ else:
+ raise AttributeError(
+ "`source_type` must be one of 'KNOWLEDGE_BASE' or
'EXTERNAL_SOURCES', "
+ "and the appropriate config values must also be provided."
+ )
+
+ def build_rag_config(self) -> dict[str, Any]:
+ result: dict[str, Any] = {}
+ base_config: dict[str, Any] = {
+ "modelArn": self.model_arn,
+ }
+
+ if self.prompt_template:
+ base_config["generationConfiguration"] = {
+ "promptTemplate": {"textPromptTemplate": self.prompt_template}
+ }
+
+ if self.source_type == "KNOWLEDGE_BASE":
+ if self.vector_search_config:
+ base_config["retrievalConfiguration"] = {
+ "vectorSearchConfiguration": self.vector_search_config
+ }
+
+ result = {
+ "type": self.source_type,
+ "knowledgeBaseConfiguration": {
+ **base_config,
+ "knowledgeBaseId": self.knowledge_base_id,
+ },
+ }
+
+ if self.source_type == "EXTERNAL_SOURCES":
+ result = {
+ "type": self.source_type,
+ "externalSourcesConfiguration": {**base_config, "sources":
self.sources},
+ }
+ return result
+
+ def execute(self, context: Context) -> Any:
+ self.validate_inputs()
+
+ result = self.hook.conn.retrieve_and_generate(
+ input={"text": self.input},
+ retrieveAndGenerateConfiguration=self.build_rag_config(),
+ **self.rag_kwargs,
+ )
+
+ self.log.info(
+ "\nPrompt: %s\nResponse: %s\nCitations: %s",
+ self.input,
+ result["output"]["text"],
+ result["citations"],
+ )
+ return result
+
+
+class BedrockRetrieveOperator(AwsBaseOperator[BedrockAgentRuntimeHook]):
+ """
+ Query a knowledge base and retrieve results with source citations.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:BedrockRetrieveOperator`
+
+ :param retrieval_query: The query to be made to the knowledge base.
(templated)
+ :param knowledge_base_id: The unique identifier of the knowledge base that
is queried. (templated)
+ Can only be specified if source_type='KNOWLEDGE_BASE'.
Review Comment:
I don't see `source_type` as a param for this operator, so should this line be removed?
##########
tests/system/providers/amazon/aws/example_bedrock_knowledge_base.py:
##########
@@ -480,6 +483,24 @@ def delete_opensearch_policies(collection_name: str):
)
# [END howto_sensor_bedrock_ingest_data]
+ # [START howto_operator_bedrock_retrieve_and_generate]
+ retrieve_and_generate = BedrockRaGOperator(
+ task_id="retrieve_and_generate",
+ input="Who was the CEO of Amazon on 2022?",
+ source_type="KNOWLEDGE_BASE",
Review Comment:
Should we have another example that exercises the other source type (`EXTERNAL_SOURCES`)?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]