ferruzzi commented on code in PR #39245:
URL: https://github.com/apache/airflow/pull/39245#discussion_r1581512161
##########
airflow/providers/amazon/aws/operators/bedrock.py:
##########
@@ -351,3 +353,288 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None
         self.log.info("Bedrock provisioned throughput job `%s` complete.", event["provisioned_model_id"])
         return event["provisioned_model_id"]
+
+
+class BedrockCreateKnowledgeBaseOperator(AwsBaseOperator[BedrockAgentHook]):
+ """
+    Create a knowledge base that contains data sources used by Amazon Bedrock LLMs and Agents.
+
+    To create a knowledge base, you must first set up your data sources and configure a supported vector store.
+
+ .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:BedrockCreateKnowledgeBaseOperator`
+
+ :param name: The name of the knowledge base. (templated)
+    :param embedding_model_arn: ARN of the model used to create vector embeddings for the knowledge base. (templated)
+    :param role_arn: The ARN of the IAM role with permissions to create the knowledge base. (templated)
+    :param storage_config: Configuration details of the vector database used for the knowledge base. (templated)
+    :param wait_for_indexing: Vector indexing can take some time and there is no apparent way to check the state
+        before trying to create the Knowledge Base. If this is True and creation fails due to the index not
+        being available, the operator will wait and retry. (default: True) (templated)
+    :param indexing_error_retry_delay: Seconds between retries if an index error is encountered. (default: 5) (templated)
+    :param indexing_error_max_attempts: Maximum number of times to retry when encountering an index error. (default: 20) (templated)
+
+    :param wait_for_completion: Whether to wait for the knowledge base to become active. (default: True)
+    :param waiter_delay: Time in seconds to wait between status checks. (default: 60)
+    :param waiter_max_attempts: Maximum number of attempts to check for job completion. (default: 20)
+    :param deferrable: If True, the operator will wait asynchronously for the knowledge base to become active.
+        This implies waiting for completion. This mode requires aiobotocore module to be installed.
+        (default: False)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+ """
+
+ aws_hook_class = BedrockAgentHook
+ template_fields: Sequence[str] = aws_template_fields(
+ "name",
+ "embedding_model_arn",
+ "role_arn",
+ "storage_config",
+ "wait_for_indexing",
+ "indexing_error_retry_delay",
+ "indexing_error_max_attempts",
+ )
+
+ def __init__(
+ self,
+ name: str,
+ embedding_model_arn: str,
+ role_arn: str,
+ storage_config: dict[str, Any],
+ create_knowledge_base_kwargs: dict[str, Any] | None = None,
+ wait_for_indexing: bool = True,
+ indexing_error_retry_delay: int = 5, # seconds
+ indexing_error_max_attempts: int = 20,
+ wait_for_completion: bool = True,
+ waiter_delay: int = 60,
+ waiter_max_attempts: int = 20,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.name = name
+ self.role_arn = role_arn
+ self.storage_config = storage_config
+ self.create_knowledge_base_kwargs = create_knowledge_base_kwargs or {}
+ self.embedding_model_arn = embedding_model_arn
+ self.knowledge_base_config = {
+ "type": "VECTOR",
+ "vectorKnowledgeBaseConfiguration": {"embeddingModelArn":
self.embedding_model_arn},
+ }
+ self.wait_for_indexing = wait_for_indexing
+ self.indexing_error_retry_delay = indexing_error_retry_delay
+ self.indexing_error_max_attempts = indexing_error_max_attempts
+
+ self.wait_for_completion = wait_for_completion
+ self.waiter_delay = waiter_delay
+ self.waiter_max_attempts = waiter_max_attempts
+ self.deferrable = deferrable
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+ event = validate_execute_complete_event(event)
+
+ if event["status"] != "success":
+ raise AirflowException(f"Error while running job: {event}")
+
+ self.log.info("Bedrock knowledge base creation job `%s` complete.",
self.name)
+ return
self.hook.conn.get_knowledge_base(knowledgeBaseId=event["knowledge_base_id"])["knowledgeBase"][
+ "status"
+ ]
+
+ def execute(self, context: Context) -> str:
+ def _create_kb():
+            # This API call will return the following if the index has not completed, but there is no apparent
+            # way to check the state of the index beforehand, so retry on index failure if set to do so.
+            #     botocore.errorfactory.ValidationException: An error occurred (ValidationException)
+            #     when calling the CreateKnowledgeBase operation: The knowledge base storage configuration
+            #     provided is invalid... no such index [bedrock-sample-rag-index-abc108]
+ try:
+ return self.hook.conn.create_knowledge_base(
+ name=self.name,
+ roleArn=self.role_arn,
+ knowledgeBaseConfiguration=self.knowledge_base_config,
+ storageConfiguration=self.storage_config,
+ **self.create_knowledge_base_kwargs,
+ )["knowledgeBase"]["knowledgeBaseId"]
+ except ClientError as error:
+ if all(
+ [
+                        error.response["Error"]["Code"] == "ValidationException",
+                        "no such index" in error.response["Error"]["Message"],
+                        self.wait_for_indexing,
+                        self.indexing_error_max_attempts > 0,
+                    ]
+                ):
+ self.indexing_error_max_attempts -= 1
+                    self.log.warning(
+                        "Vector index not ready, retrying in %s seconds.", self.indexing_error_retry_delay
+                    )
+                    self.log.debug("%s retries remaining.", self.indexing_error_max_attempts)
+ sleep(self.indexing_error_retry_delay)
+ return _create_kb()
+ raise
+
+ self.log.info("Creating Amazon Bedrock Knowledge Base %s", self.name)
+ knowledge_base_id = _create_kb()
+
+ if self.deferrable:
+ self.log.info("Deferring for Knowledge base creation.")
+ self.defer(
+ trigger=BedrockKnowledgeBaseActiveTrigger(
+ knowledge_base_id=knowledge_base_id,
+ waiter_delay=self.waiter_delay,
+ waiter_max_attempts=self.waiter_max_attempts,
+ aws_conn_id=self.aws_conn_id,
+ ),
+ method_name="execute_complete",
+ )
+ if self.wait_for_completion:
+ self.log.info("Waiting for Knowledge Base creation.")
+ self.hook.get_waiter("knowledge_base_active").wait(
+ knowledgeBaseId=knowledge_base_id,
+                WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
+ )
+
+ return knowledge_base_id
+
+
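A note on the storage_config parameter above: its value is passed straight through as the boto3 bedrock-agent create_knowledge_base storageConfiguration argument. A minimal sketch for an OpenSearch Serverless vector store follows; the collection ARN, index name, and field names are placeholders, not values from this PR:

    # Sketch only: all values below are placeholders.
    storage_config = {
        "type": "OPENSEARCH_SERVERLESS",
        "opensearchServerlessConfiguration": {
            "collectionArn": "arn:aws:aoss:us-east-1:123456789012:collection/abcd1234",
            "vectorIndexName": "bedrock-sample-rag-index",
            "fieldMapping": {
                "vectorField": "vector",
                "textField": "text",
                "metadataField": "metadata",
            },
        },
    }
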
+class BedrockCreateDataSourceOperator(AwsBaseOperator[BedrockAgentHook]):
+ """
+    Set up an Amazon Bedrock Data Source to be added to an Amazon Bedrock Knowledge Base.
+
+ .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:BedrockCreateDataSourceOperator`
+
+    :param name: The name for the Amazon Bedrock Data Source being created. (templated)
+    :param bucket_name: The name of the Amazon S3 bucket to use for data source storage. (templated)
+    :param knowledge_base_id: The unique identifier of the knowledge base to which to add the data source. (templated)
+
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+ """
+
+ aws_hook_class = BedrockAgentHook
+
+ def __init__(
+ self,
+ name: str,
+ knowledge_base_id: str,
+ bucket_name: str | None = None,
+ create_data_source_kwargs: dict[str, Any] | None = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.name = name
+ self.knowledge_base_id = knowledge_base_id
+ self.bucket_name = bucket_name
+ self.create_data_source_kwargs = create_data_source_kwargs or {}
+
+ template_fields: Sequence[str] = aws_template_fields(
+ "name",
+ "bucket_name",
+ "knowledge_base_id",
+ )
+
+ def execute(self, context: Context) -> str:
+ create_ds_response = self.hook.conn.create_data_source(
+ name=self.name,
+ knowledgeBaseId=self.knowledge_base_id,
+ dataSourceConfiguration={
+ "type": "S3",
+                "s3Configuration": {"bucketArn": f"arn:aws:s3:::{self.bucket_name}"},
+ },
+ **self.create_data_source_kwargs,
+ )
+
+ return create_ds_response["dataSource"]["dataSourceId"]
+
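As a usage sketch for chaining the two operators above (the task ID, name, and bucket are hypothetical): because knowledge_base_id is a templated field, the knowledge base ID returned by a BedrockCreateKnowledgeBaseOperator task can be passed in directly as an XCom reference.

    # Sketch only: assumes create_knowledge_base is a BedrockCreateKnowledgeBaseOperator
    # task defined earlier in the same DAG.
    create_data_source = BedrockCreateDataSourceOperator(
        task_id="create_data_source",
        name="example-data-source",
        bucket_name="example-bucket",
        knowledge_base_id=create_knowledge_base.output,
    )
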
+
+class BedrockIngestDataOperator(AwsBaseOperator[BedrockAgentHook]):
+ """
+    Begin an ingestion job, in which an Amazon Bedrock data source is added to an Amazon Bedrock knowledge base.
+
+ .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:BedrockIngestDataOperator`
+
+    :param knowledge_base_id: The unique identifier of the knowledge base to which to add the data source. (templated)
+    :param data_source_id: The unique identifier of the data source to ingest. (templated)
+
+    :param wait_for_completion: Whether to wait for the ingestion job to complete. (default: True)
+    :param waiter_delay: Time in seconds to wait between status checks. (default: 60)
+    :param waiter_max_attempts: Maximum number of attempts to check for job completion. (default: 10)
+    :param deferrable: If True, the operator will wait asynchronously for the ingestion job to complete.
+        This implies waiting for completion. This mode requires aiobotocore module to be installed.
+        (default: False)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+ """
+
+ aws_hook_class = BedrockAgentHook
+
+ def __init__(
+ self,
+ knowledge_base_id: str,
+ data_source_id: str,
+ ingest_data_kwargs: dict[str, Any] | None = None,
+ wait_for_completion: bool = True,
+ waiter_delay: int = 60,
+ waiter_max_attempts: int = 10,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.knowledge_base_id = knowledge_base_id
+ self.data_source_id = data_source_id
+ self.ingest_data_kwargs = ingest_data_kwargs or {}
+
+ self.wait_for_completion = wait_for_completion
+ self.waiter_delay = waiter_delay
+ self.waiter_max_attempts = waiter_max_attempts
+ self.deferrable = deferrable
+
+ template_fields: Sequence[str] = aws_template_fields(
+ "knowledge_base_id",
+ "data_source_id",
+ )
+
+ def execute(self, context: Context) -> str:
+ ingestion_job_id = self.hook.conn.start_ingestion_job(
+            knowledgeBaseId=self.knowledge_base_id, dataSourceId=self.data_source_id
+ )["ingestionJob"]["ingestionJobId"]
+
+ if self.wait_for_completion:
Review Comment:
implemented
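
For context on the truncated hunk above: the completion handling that "implemented" refers to presumably mirrors the wait/defer pattern of BedrockCreateKnowledgeBaseOperator earlier in this diff. A sketch of that pattern follows; the BedrockIngestionJobTrigger class and the "ingestion_job_complete" waiter name are assumptions, not lines confirmed by this diff:

    # Sketch only: trigger class and waiter name are assumptions based on
    # the pattern used by BedrockCreateKnowledgeBaseOperator above.
    if self.deferrable:
        self.log.info("Deferring for ingestion job.")
        self.defer(
            trigger=BedrockIngestionJobTrigger(
                knowledge_base_id=self.knowledge_base_id,
                data_source_id=self.data_source_id,
                ingestion_job_id=ingestion_job_id,
                waiter_delay=self.waiter_delay,
                waiter_max_attempts=self.waiter_max_attempts,
                aws_conn_id=self.aws_conn_id,
            ),
            method_name="execute_complete",
        )
    if self.wait_for_completion:
        self.log.info("Waiting for ingestion job %s", ingestion_job_id)
        self.hook.get_waiter("ingestion_job_complete").wait(
            knowledgeBaseId=self.knowledge_base_id,
            dataSourceId=self.data_source_id,
            ingestionJobId=ingestion_job_id,
            WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
        )

    return ingestion_job_id
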
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]