This is an automated email from the ASF dual-hosted git repository.
onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 7bd1aa6bc0f Add Bedrock Batch Inference Operator and accompanying
parts (#48468)
7bd1aa6bc0f is described below
commit 7bd1aa6bc0f0f4d6703d21e3accb87a8ed9e614c
Author: D. Ferruzzi <[email protected]>
AuthorDate: Mon Apr 7 12:47:30 2025 -0700
Add Bedrock Batch Inference Operator and accompanying parts (#48468)
Add a new operator for Bedrock Batch Inference Jobs along with all the
sensor/waiter/trigger/etc support, docs, and testing.
---
providers/amazon/docs/operators/bedrock.rst | 49 ++++++
.../providers/amazon/aws/operators/bedrock.py | 119 ++++++++++++++
.../providers/amazon/aws/sensors/bedrock.py | 110 +++++++++++++
.../providers/amazon/aws/triggers/bedrock.py | 98 +++++++++++
.../providers/amazon/aws/waiters/bedrock.json | 134 +++++++++++++++
.../amazon/aws/example_bedrock_batch_inference.py | 182 +++++++++++++++++++++
.../unit/amazon/aws/operators/test_bedrock.py | 60 +++++++
.../tests/unit/amazon/aws/sensors/test_bedrock.py | 95 ++++++++++-
.../tests/unit/amazon/aws/triggers/test_bedrock.py | 62 +++++++
.../tests/unit/amazon/aws/waiters/test_bedrock.py | 89 +++++++++-
10 files changed, 995 insertions(+), 3 deletions(-)
diff --git a/providers/amazon/docs/operators/bedrock.rst
b/providers/amazon/docs/operators/bedrock.rst
index 9e110856b7e..f166a10c2dd 100644
--- a/providers/amazon/docs/operators/bedrock.rst
+++ b/providers/amazon/docs/operators/bedrock.rst
@@ -74,6 +74,14 @@ To invoke a Claude V2 model using the Completions API you
would use:
:start-after: [START howto_operator_invoke_claude_model]
:end-before: [END howto_operator_invoke_claude_model]
+To invoke a Claude V3 Sonnet model using the Messages API you would use:
+
+.. exampleinclude::
/../../amazon/tests/system/amazon/aws/example_bedrock_batch_inference.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_invoke_claude_messages]
+ :end-before: [END howto_operator_invoke_claude_messages]
+
.. _howto/operator:BedrockCustomizeModelOperator:
@@ -237,6 +245,29 @@ Example using a PDF file in an Amazon S3 Bucket:
:start-after: [START howto_operator_bedrock_external_sources_rag]
:end-before: [END howto_operator_bedrock_external_sources_rag]
+.. _howto/operator:BedrockBatchInferenceOperator:
+
+Create an Amazon Bedrock Batch Inference Job
+============================================
+
+To create a batch inference job to invoke a model on multiple prompts, you
can use
+:class:`~airflow.providers.amazon.aws.operators.bedrock.BedrockBatchInferenceOperator`.
+
+The input must be formatted in jsonl and uploaded to an Amazon S3 bucket.
Please see
+https://docs.aws.amazon.com/bedrock/latest/userguide/batch-inference.html for
details.
+
+NOTE: Jobs are added to a queue and processed in order. Given the potential
wait times,
+and the fact that the optional timeout parameter is measured in hours,
deferrable mode is
+recommended over "wait_for_completion" in this case.
+
+Example using an Amazon Bedrock Batch Inference Job:
+
+.. exampleinclude::
/../../amazon/tests/system/amazon/aws/example_bedrock_batch_inference.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_bedrock_batch_inference]
+ :end-before: [END howto_operator_bedrock_batch_inference]
+
Sensors
-------
@@ -298,6 +329,24 @@ To wait on the state of an Amazon Bedrock data ingestion
job until it reaches a
:start-after: [START howto_sensor_bedrock_ingest_data]
:end-before: [END howto_sensor_bedrock_ingest_data]
+.. _howto/sensor:BedrockBatchInferenceSensor:
+
+Wait for an Amazon Bedrock batch inference job
+==============================================
+
+To wait on the state of an Amazon Bedrock batch inference job until it reaches
the "Scheduled" or "Completed"
+state you can use
:class:`~airflow.providers.amazon.aws.sensors.bedrock.BedrockBatchInferenceSensor`
+
+Bedrock adds batch inference jobs to a queue, and they may take some time to
complete. If you want to wait
+for the job to complete, use SuccessState.COMPLETED for the success_state, but
if you only want to wait until
+the service confirms that the job is in the queue, use SuccessState.SCHEDULED.
+
+.. exampleinclude::
/../../amazon/tests/system/amazon/aws/example_bedrock_batch_inference.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_sensor_bedrock_batch_inference_scheduled]
+ :end-before: [END howto_sensor_bedrock_batch_inference_scheduled]
+
Reference
---------
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/operators/bedrock.py
b/providers/amazon/src/airflow/providers/amazon/aws/operators/bedrock.py
index ac8e86b92c8..f9715114b74 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/bedrock.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/bedrock.py
@@ -33,6 +33,7 @@ from airflow.providers.amazon.aws.hooks.bedrock import (
)
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.providers.amazon.aws.triggers.bedrock import (
+ BedrockBatchInferenceCompletedTrigger,
BedrockCustomizeModelCompletedTrigger,
BedrockIngestionJobTrigger,
BedrockKnowledgeBaseActiveTrigger,
@@ -869,3 +870,121 @@ class
BedrockRetrieveOperator(AwsBaseOperator[BedrockAgentRuntimeHook]):
self.log.info("\nQuery: %s\nRetrieved: %s", self.retrieval_query,
result["retrievalResults"])
return result
+
+
+class BedrockBatchInferenceOperator(AwsBaseOperator[BedrockHook]):
+ """
+ Create a batch inference job to invoke a model on multiple prompts.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:BedrockBatchInferenceOperator`
+
+ :param job_name: A name to give the batch inference job. (templated)
+    :param role_arn: The ARN of the IAM role with permissions to create and
run the batch inference job. (templated)
+    :param model_id: Name or ARN of the model to use for the batch inference
job. (templated)
+ :param input_uri: The S3 location of the input data. (templated)
+ :param output_uri: The S3 location of the output data. (templated)
+ :param invoke_kwargs: Additional keyword arguments to pass to the API
call. (templated)
+
+    :param wait_for_completion: Whether to wait for the job to complete.
(default: True)
+ NOTE: The way batch inference jobs work, your jobs are added to a
queue and done "eventually"
+ so using deferrable mode is much more practical than using
wait_for_completion.
+ :param waiter_delay: Time in seconds to wait between status checks.
(default: 60)
+ :param waiter_max_attempts: Maximum number of attempts to check for job
completion. (default: 10)
+    :param deferrable: If True, the operator will wait asynchronously for the
job to complete.
+ This implies waiting for completion. This mode requires aiobotocore
module to be installed.
+ (default: False)
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ If this is ``None`` or empty then the default boto3 behaviour is used.
If
+ running Airflow in a distributed manner and aws_conn_id is None or
+ empty, then default boto3 configuration would be used (and must be
+ maintained on each worker node).
+ :param region_name: AWS region_name. If not specified then the default
boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore
client. See:
+
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+ """
+
+ aws_hook_class = BedrockHook
+ template_fields: Sequence[str] = aws_template_fields(
+ "job_name",
+ "role_arn",
+ "model_id",
+ "input_uri",
+ "output_uri",
+ "invoke_kwargs",
+ )
+
+ def __init__(
+ self,
+ job_name: str,
+ role_arn: str,
+ model_id: str,
+ input_uri: str,
+ output_uri: str,
+ invoke_kwargs: dict[str, Any] | None = None,
+ wait_for_completion: bool = True,
+ waiter_delay: int = 60,
+ waiter_max_attempts: int = 10,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.job_name = job_name
+ self.role_arn = role_arn
+ self.model_id = model_id
+ self.input_uri = input_uri
+ self.output_uri = output_uri
+ self.invoke_kwargs = invoke_kwargs or {}
+
+ self.wait_for_completion = wait_for_completion
+ self.waiter_delay = waiter_delay
+ self.waiter_max_attempts = waiter_max_attempts
+ self.deferrable = deferrable
+
+ self.activity = "Bedrock batch inference job"
+
+ def execute_complete(self, context: Context, event: dict[str, Any] | None
= None) -> str:
+ validated_event = validate_execute_complete_event(event)
+
+ if validated_event["status"] != "success":
+ raise AirflowException(f"Error while running {self.activity}:
{validated_event}")
+
+ self.log.info("%s '%s' complete.", self.activity,
validated_event["job_arn"])
+
+ return validated_event["job_arn"]
+
+ def execute(self, context: Context) -> str:
+ response = self.hook.conn.create_model_invocation_job(
+ jobName=self.job_name,
+ roleArn=self.role_arn,
+ modelId=self.model_id,
+ inputDataConfig={"s3InputDataConfig": {"s3Uri": self.input_uri}},
+ outputDataConfig={"s3OutputDataConfig": {"s3Uri":
self.output_uri}},
+ **self.invoke_kwargs,
+ )
+ job_arn = response["jobArn"]
+ self.log.info("%s '%s' started with ARN: %s", self.activity,
self.job_name, job_arn)
+
+ task_description = f"for {self.activity} '{self.job_name}' to
complete."
+ if self.deferrable:
+ self.log.info("Deferring %s", task_description)
+ self.defer(
+ trigger=BedrockBatchInferenceCompletedTrigger(
+ job_arn=job_arn,
+ waiter_delay=self.waiter_delay,
+ waiter_max_attempts=self.waiter_max_attempts,
+ aws_conn_id=self.aws_conn_id,
+ ),
+ method_name="execute_complete",
+ )
+ elif self.wait_for_completion:
+ self.log.info("Waiting %s", task_description)
+ self.hook.get_waiter(waiter_name="batch_inference_complete").wait(
+ jobIdentifier=job_arn,
+ WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts":
self.waiter_max_attempts},
+ )
+
+ return job_arn
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/sensors/bedrock.py
b/providers/amazon/src/airflow/providers/amazon/aws/sensors/bedrock.py
index 541702f4774..c74c3e772fb 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/bedrock.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/bedrock.py
@@ -26,6 +26,8 @@ from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook,
BedrockHook
from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
from airflow.providers.amazon.aws.triggers.bedrock import (
+ BedrockBatchInferenceCompletedTrigger,
+ BedrockBatchInferenceScheduledTrigger,
BedrockCustomizeModelCompletedTrigger,
BedrockIngestionJobTrigger,
BedrockKnowledgeBaseActiveTrigger,
@@ -34,6 +36,7 @@ from airflow.providers.amazon.aws.triggers.bedrock import (
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
if TYPE_CHECKING:
+ from airflow.providers.amazon.aws.triggers.bedrock import
BedrockBaseBatchInferenceTrigger
from airflow.utils.context import Context
@@ -368,3 +371,110 @@ class
BedrockIngestionJobSensor(BedrockBaseSensor[BedrockAgentHook]):
)
else:
super().execute(context=context)
+
+
+class BedrockBatchInferenceSensor(BedrockBaseSensor[BedrockHook]):
+ """
+ Poll the batch inference job status until it reaches a terminal state;
fails if creation fails.
+
+ .. seealso::
+ For more information on how to use this sensor, take a look at the
guide:
+ :ref:`howto/sensor:BedrockBatchInferenceSensor`
+
+ :param job_arn: The Amazon Resource Name (ARN) of the batch inference job.
(templated)
+    :param success_state: A BedrockBatchInferenceSensor.SuccessState; defaults
to 'SCHEDULED' (templated)
+
+    :param deferrable: If True, the sensor will operate in deferrable mode.
This mode requires aiobotocore
+ module to be installed.
+ (default: False, but can be overridden in config file by setting
default_deferrable to True)
+ :param poke_interval: Polling period in seconds to check for the status of
the job. (default: 5)
+ :param max_retries: Number of times before returning the current state
(default: 24)
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ If this is ``None`` or empty then the default boto3 behaviour is used.
If
+ running Airflow in a distributed manner and aws_conn_id is None or
+ empty, then default boto3 configuration would be used (and must be
+ maintained on each worker node).
+ :param region_name: AWS region_name. If not specified then the default
boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore
client. See:
+
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+ """
+
+ class SuccessState:
+ """
+ Target state for the BedrockBatchInferenceSensor.
+
+ Bedrock adds batch inference jobs to a queue, and they may take some
time to complete.
+        If you want to wait for the job to complete, use
SuccessState.COMPLETED, but if you only want
+        to wait until the service confirms that the job is in the queue, use
SuccessState.SCHEDULED.
+
+ The normal successful progression of states is:
+ Submitted > Validating > Scheduled > InProgress >
PartiallyCompleted > Completed
+ """
+
+ SCHEDULED = "scheduled"
+ COMPLETED = "completed"
+
+ INTERMEDIATE_STATES: tuple[str, ...] # Defined in __init__ based on
target state
+ FAILURE_STATES: tuple[str, ...] = ("Failed", "Stopped",
"PartiallyCompleted", "Expired")
+ SUCCESS_STATES: tuple[str, ...] # Defined in __init__ based on target
state
+ FAILURE_MESSAGE = "Bedrock batch inference job sensor failed."
+ INVALID_SUCCESS_STATE_MESSAGE = "success_state must be an instance of
TargetState."
+
+ aws_hook_class = BedrockHook
+
+ template_fields: Sequence[str] = aws_template_fields("job_arn",
"success_state")
+
+ def __init__(
+ self,
+ *,
+ job_arn: str,
+ success_state: SuccessState | str = SuccessState.SCHEDULED,
+ poke_interval: int = 120,
+ max_retries: int = 75,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.poke_interval = poke_interval
+ self.max_retries = max_retries
+ self.job_arn = job_arn
+ self.success_state = success_state
+
+ base_success_states: tuple[str, ...] = ("Completed",)
+ base_intermediate_states: tuple[str, ...] = ("Submitted",
"InProgress", "Stopping", "Validating")
+ scheduled_state = ("Scheduled",)
+ self.trigger_class: type[BedrockBaseBatchInferenceTrigger]
+
+ if self.success_state ==
BedrockBatchInferenceSensor.SuccessState.COMPLETED:
+ intermediate_states = base_intermediate_states + scheduled_state
+ success_states = base_success_states
+ self.trigger_class = BedrockBatchInferenceCompletedTrigger
+ elif self.success_state ==
BedrockBatchInferenceSensor.SuccessState.SCHEDULED:
+ intermediate_states = base_intermediate_states
+ success_states = base_success_states + scheduled_state
+ self.trigger_class = BedrockBatchInferenceScheduledTrigger
+ else:
+ raise ValueError(
+ "Success states for BedrockBatchInferenceSensor must be set
using a BedrockBatchInferenceSensor.SuccessState"
+ )
+
+ BedrockBatchInferenceSensor.INTERMEDIATE_STATES = intermediate_states
or base_intermediate_states
+ BedrockBatchInferenceSensor.SUCCESS_STATES = success_states or
base_success_states
+
+ def get_state(self) -> str:
+ return
self.hook.conn.get_model_invocation_job(jobIdentifier=self.job_arn)["status"]
+
+ def execute(self, context: Context) -> Any:
+ if self.deferrable:
+ self.defer(
+ trigger=self.trigger_class(
+ job_arn=self.job_arn,
+ waiter_delay=int(self.poke_interval),
+ waiter_max_attempts=self.max_retries,
+ aws_conn_id=self.aws_conn_id,
+ ),
+ method_name="poke",
+ )
+ else:
+ super().execute(context=context)
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/triggers/bedrock.py
b/providers/amazon/src/airflow/providers/amazon/aws/triggers/bedrock.py
index 99d632d26e0..faac0f90753 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/bedrock.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/bedrock.py
@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING
from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook,
BedrockHook
from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
+from airflow.utils.types import NOTSET, ArgNotSet
if TYPE_CHECKING:
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
@@ -182,3 +183,100 @@ class BedrockIngestionJobTrigger(AwsBaseWaiterTrigger):
def hook(self) -> AwsGenericHook:
return BedrockAgentHook(aws_conn_id=self.aws_conn_id)
+
+
+class BedrockBaseBatchInferenceTrigger(AwsBaseWaiterTrigger):
+ """
+ Trigger when a batch inference job is complete.
+
+ :param job_arn: The Amazon Resource Name (ARN) of the batch inference job.
+
+ :param waiter_delay: The amount of time in seconds to wait between
attempts. (default: 120)
+ :param waiter_max_attempts: The maximum number of attempts to be made.
(default: 75)
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ """
+
+ def __init__(
+ self,
+ *,
+ job_arn: str,
+ waiter_name: str | ArgNotSet = NOTSET, # This must be defined in the
child class.
+ waiter_delay: int = 120,
+ waiter_max_attempts: int = 75,
+ aws_conn_id: str | None = None,
+ ) -> None:
+ if waiter_name == NOTSET:
+ raise NotImplementedError("Triggers must provide a waiter name.")
+
+ super().__init__(
+ serialized_fields={"job_arn": job_arn},
+ waiter_name=str(waiter_name), # Cast a string to a string to make
mypy happy
+ waiter_args={"jobIdentifier": job_arn},
+ failure_message="Bedrock batch inference job failed.",
+ status_message="Status of Bedrock batch inference job is",
+ status_queries=["status"],
+ return_key="job_arn",
+ return_value=job_arn,
+ waiter_delay=waiter_delay,
+ waiter_max_attempts=waiter_max_attempts,
+ aws_conn_id=aws_conn_id,
+ )
+
+ def hook(self) -> AwsGenericHook:
+ return BedrockHook(aws_conn_id=self.aws_conn_id)
+
+
+class BedrockBatchInferenceCompletedTrigger(BedrockBaseBatchInferenceTrigger):
+ """
+ Trigger when a batch inference job is complete.
+
+ :param job_arn: The Amazon Resource Name (ARN) of the batch inference job.
+
+ :param waiter_delay: The amount of time in seconds to wait between
attempts. (default: 120)
+ :param waiter_max_attempts: The maximum number of attempts to be made.
(default: 75)
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ """
+
+ def __init__(
+ self,
+ *,
+ job_arn: str,
+ waiter_delay: int = 120,
+ waiter_max_attempts: int = 75,
+ aws_conn_id: str | None = None,
+ ) -> None:
+ super().__init__(
+ waiter_name="batch_inference_complete",
+ job_arn=job_arn,
+ waiter_delay=waiter_delay,
+ waiter_max_attempts=waiter_max_attempts,
+ aws_conn_id=aws_conn_id,
+ )
+
+
+class BedrockBatchInferenceScheduledTrigger(BedrockBaseBatchInferenceTrigger):
+ """
+ Trigger when a batch inference job is scheduled.
+
+ :param job_arn: The Amazon Resource Name (ARN) of the batch inference job.
+
+ :param waiter_delay: The amount of time in seconds to wait between
attempts. (default: 120)
+ :param waiter_max_attempts: The maximum number of attempts to be made.
(default: 75)
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ """
+
+ def __init__(
+ self,
+ *,
+ job_arn: str,
+ waiter_delay: int = 120,
+ waiter_max_attempts: int = 75,
+ aws_conn_id: str | None = None,
+ ) -> None:
+ super().__init__(
+ waiter_name="batch_inference_scheduled",
+ job_arn=job_arn,
+ waiter_delay=waiter_delay,
+ waiter_max_attempts=waiter_max_attempts,
+ aws_conn_id=aws_conn_id,
+ )
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/waiters/bedrock.json
b/providers/amazon/src/airflow/providers/amazon/aws/waiters/bedrock.json
index c913b4dc7c6..18a6294a128 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/waiters/bedrock.json
+++ b/providers/amazon/src/airflow/providers/amazon/aws/waiters/bedrock.json
@@ -68,6 +68,140 @@
"state": "failure"
}
]
+ },
+ "batch_inference_complete": {
+ "delay": 120,
+ "maxAttempts": 75,
+ "operation": "GetModelInvocationJob",
+ "acceptors": [
+ {
+ "matcher": "path",
+ "argument": "status",
+ "expected": "Completed",
+ "state": "success"
+ },
+ {
+ "matcher": "path",
+ "argument": "status",
+ "expected": "Failed",
+ "state": "failure"
+ },
+ {
+ "matcher": "path",
+ "argument": "status",
+ "expected": "Stopped",
+ "state": "failure"
+ },
+ {
+ "matcher": "path",
+ "argument": "status",
+ "expected": "PartiallyCompleted",
+ "state": "failure"
+ },
+ {
+ "matcher": "path",
+ "argument": "status",
+ "expected": "Expired",
+ "state": "failure"
+ },
+ {
+ "matcher": "path",
+ "argument": "status",
+ "expected": "Stopping",
+ "state": "failure"
+ },
+ {
+ "matcher": "path",
+ "argument": "status",
+ "expected": "Submitted",
+ "state": "retry"
+ },
+ {
+ "matcher": "path",
+ "argument": "status",
+ "expected": "InProgress",
+ "state": "retry"
+ },
+ {
+ "matcher": "path",
+ "argument": "status",
+ "expected": "Validating",
+ "state": "retry"
+ },
+ {
+ "matcher": "path",
+ "argument": "status",
+ "expected": "Scheduled",
+ "state": "retry"
+ }
+ ]
+ },
+ "batch_inference_scheduled": {
+ "delay": 120,
+ "maxAttempts": 75,
+ "operation": "GetModelInvocationJob",
+ "acceptors": [
+ {
+ "matcher": "path",
+ "argument": "status",
+ "expected": "Completed",
+ "state": "success"
+ },
+ {
+ "matcher": "path",
+ "argument": "status",
+ "expected": "Failed",
+ "state": "failure"
+ },
+ {
+ "matcher": "path",
+ "argument": "status",
+ "expected": "Stopped",
+ "state": "failure"
+ },
+ {
+ "matcher": "path",
+ "argument": "status",
+ "expected": "Stopping",
+ "state": "failure"
+ },
+ {
+ "matcher": "path",
+ "argument": "status",
+ "expected": "PartiallyCompleted",
+ "state": "failure"
+ },
+ {
+ "matcher": "path",
+ "argument": "status",
+ "expected": "Expired",
+ "state": "failure"
+ },
+ {
+ "matcher": "path",
+ "argument": "status",
+ "expected": "Submitted",
+ "state": "retry"
+ },
+ {
+ "matcher": "path",
+ "argument": "status",
+ "expected": "InProgress",
+ "state": "retry"
+ },
+ {
+ "matcher": "path",
+ "argument": "status",
+ "expected": "Validating",
+ "state": "retry"
+ },
+ {
+ "matcher": "path",
+ "argument": "status",
+ "expected": "Scheduled",
+ "state": "success"
+ }
+ ]
}
}
}
diff --git
a/providers/amazon/tests/system/amazon/aws/example_bedrock_batch_inference.py
b/providers/amazon/tests/system/amazon/aws/example_bedrock_batch_inference.py
new file mode 100644
index 00000000000..ed0da3b0aea
--- /dev/null
+++
b/providers/amazon/tests/system/amazon/aws/example_bedrock_batch_inference.py
@@ -0,0 +1,182 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import json
+import logging
+from datetime import datetime
+from tempfile import NamedTemporaryFile
+
+from airflow.decorators import task
+from airflow.models.baseoperator import chain
+from airflow.models.dag import DAG
+from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook
+from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.amazon.aws.operators.bedrock import (
+ BedrockBatchInferenceOperator,
+ BedrockInvokeModelOperator,
+)
+from airflow.providers.amazon.aws.operators.s3 import (
+ S3CreateBucketOperator,
+ S3DeleteBucketOperator,
+)
+from airflow.providers.amazon.aws.sensors.bedrock import
BedrockBatchInferenceSensor
+from airflow.utils.trigger_rule import TriggerRule
+
+from system.amazon.aws.utils import SystemTestContextBuilder
+
+log = logging.getLogger(__name__)
+
+
+# Externally fetched variables:
+ROLE_ARN_KEY = "ROLE_ARN"
+sys_test_context_task =
SystemTestContextBuilder().add_variable(ROLE_ARN_KEY).build()
+
+DAG_ID = "example_bedrock_batch_inference"
+
+#######################################################################
+# NOTE:
+# Access to the following foundation model must be requested via
+# the Amazon Bedrock console and may take up to 24 hours to apply:
+#######################################################################
+
+CLAUDE_MODEL_ID = "anthropic.claude-3-sonnet-20240229-v1:0"
+ANTHROPIC_VERSION = "bedrock-2023-05-31"
+
+# Batch inferences currently require a minimum of 100 prompts per batch.
+MIN_NUM_PROMPTS = 300
+PROMPT_TEMPLATE = "Even numbers are red. Odd numbers are blue. What color is
{n}?"
+
+
+@task
+def generate_prompts(_env_id: str, _bucket: str, _key: str):
+ """
+ Bedrock Batch Inference requires one or more jsonl-formatted files in an
S3 bucket.
+
+ The JSONL format requires one serialized json object per prompt per line.
+ """
+ with NamedTemporaryFile(mode="w") as tmp_file:
+ # Generate the required number of prompts.
+ prompts = [
+ {
+ "modelInput": {
+ "anthropic_version": ANTHROPIC_VERSION,
+ "max_tokens": 1000,
+ "messages": [PROMPT_TEMPLATE.format(n=n)],
+ },
+ }
+ for n in range(MIN_NUM_PROMPTS)
+ ]
+
+ # Convert each prompt to serialized json, append a newline, and write
that line to the temp file.
+ tmp_file.writelines(json.dumps(prompt) + "\n" for prompt in prompts)
+
+ # Upload the file to S3.
+ S3Hook().conn.upload_file(tmp_file.name, _bucket, _key)
+
+
+@task(trigger_rule=TriggerRule.ALL_DONE)
+def stop_batch_inference(job_arn: str):
+ log.info("Stopping Batch Inference Job.")
+ BedrockHook().conn.stop_model_invocation_job(jobIdentifier=job_arn)
+
+
+with DAG(
+ dag_id=DAG_ID,
+ schedule="@once",
+ start_date=datetime(2021, 1, 1),
+ tags={"example"},
+ catchup=False,
+) as dag:
+ test_context = sys_test_context_task()
+ env_id = test_context["ENV_ID"]
+
+ bucket_name = f"{env_id}-bedrock"
+ input_data_s3_key = f"{env_id}/prompt_list.jsonl"
+ input_uri = f"s3://{bucket_name}/{input_data_s3_key}"
+ output_uri = f"s3://{bucket_name}/output/"
+ job_name = f"batch-infer-{env_id}"
+
+ # Test that this configuration works for a single prompt before trying the
batch inferences.
+ # [START howto_operator_invoke_claude_messages]
+ invoke_claude_messages = BedrockInvokeModelOperator(
+ task_id="invoke_claude_messages",
+ model_id=CLAUDE_MODEL_ID,
+ input_data={
+ "anthropic_version": "bedrock-2023-05-31",
+ "max_tokens": 1000,
+ "messages": [{"role": "user", "content":
PROMPT_TEMPLATE.format(n=42)}],
+ },
+ )
+ # [END howto_operator_invoke_claude_messages]
+
+ create_bucket = S3CreateBucketOperator(task_id="create_bucket",
bucket_name=bucket_name)
+
+ # [START howto_operator_bedrock_batch_inference]
+ batch_infer = BedrockBatchInferenceOperator(
+ task_id="batch_infer",
+ job_name=job_name,
+ role_arn=test_context[ROLE_ARN_KEY],
+ model_id=CLAUDE_MODEL_ID,
+ input_uri=input_uri,
+ output_uri=output_uri,
+ )
+ # [END howto_operator_bedrock_batch_inference]
+ batch_infer.wait_for_completion = False
+
+ # [START howto_sensor_bedrock_batch_inference_scheduled]
+ await_job_scheduled = BedrockBatchInferenceSensor(
+ task_id="await_job_scheduled",
+ job_arn=batch_infer.output,
+ success_state=BedrockBatchInferenceSensor.SuccessState.SCHEDULED,
+ )
+ # [END howto_sensor_bedrock_batch_inference_scheduled]
+
+ stop_job = stop_batch_inference(batch_infer.output)
+
+ delete_bucket = S3DeleteBucketOperator(
+ task_id="delete_bucket",
+ trigger_rule=TriggerRule.ALL_DONE,
+ bucket_name=bucket_name,
+ force_delete=True,
+ )
+
+ chain(
+ # TEST SETUP
+ test_context,
+ invoke_claude_messages,
+ create_bucket,
+ generate_prompts(env_id, bucket_name, input_data_s3_key),
+ # TEST BODY
+ batch_infer,
+ await_job_scheduled,
+ stop_job,
+ # TEST TEARDOWN
+ delete_bucket,
+ )
+
+ from tests_common.test_utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "tearDown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+
+
+from tests_common.test_utils.system_tests import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see:
tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_bedrock.py
b/providers/amazon/tests/unit/amazon/aws/operators/test_bedrock.py
index 527f0385e60..4591ab97c1b 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_bedrock.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_bedrock.py
@@ -28,6 +28,7 @@ from moto import mock_aws
from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook,
BedrockHook, BedrockRuntimeHook
from airflow.providers.amazon.aws.operators.bedrock import (
+ BedrockBatchInferenceOperator,
BedrockCreateDataSourceOperator,
BedrockCreateKnowledgeBaseOperator,
BedrockCreateProvisionedModelThroughputOperator,
@@ -549,3 +550,62 @@ class TestBedrockRaGOperator:
vector_search_config=self.VECTOR_SEARCH_CONFIG,
)
validate_template_fields(op)
+
+
+class TestBedrockBatchInferenceOperator:
+ JOB_NAME = "job_name"
+ ROLE_ARN = "role_arn"
+ MODEL_ID = "model_id"
+ INPUT_URI = "input_uri"
+ OUTPUT_URI = "output_uri"
+ INVOKE_KWARGS = {"tags": {"key": "key", "value": "value"}}
+
+ JOB_ARN = "job_arn"
+
+ @pytest.fixture
+ def mock_conn(self) -> Generator[BaseAwsConnection, None, None]:
+ with mock.patch.object(BedrockHook, "conn") as _conn:
+ _conn.create_model_invocation_job.return_value = {"jobArn":
self.JOB_ARN}
+ yield _conn
+
+ @pytest.fixture
+ def bedrock_hook(self) -> Generator[BedrockHook, None, None]:
+ with mock_aws():
+ hook = BedrockHook(aws_conn_id="aws_default")
+ yield hook
+
+ def setup_method(self):
+ self.operator = BedrockBatchInferenceOperator(
+ task_id="test_task",
+ job_name=self.JOB_NAME,
+ role_arn=self.ROLE_ARN,
+ model_id=self.MODEL_ID,
+ input_uri=self.INPUT_URI,
+ output_uri=self.OUTPUT_URI,
+ invoke_kwargs=self.INVOKE_KWARGS,
+ )
+ self.operator.defer = mock.MagicMock()
+
+ @pytest.mark.parametrize(
+ "wait_for_completion, deferrable",
+ [
+ pytest.param(False, False, id="no_wait"),
+ pytest.param(True, False, id="wait"),
+ pytest.param(False, True, id="defer"),
+ ],
+ )
+ @mock.patch.object(BedrockHook, "get_waiter")
+ def test_customize_model_wait_combinations(
+ self, _, wait_for_completion, deferrable, mock_conn, bedrock_hook
+ ):
+ self.operator.wait_for_completion = wait_for_completion
+ self.operator.deferrable = deferrable
+
+ response = self.operator.execute({})
+
+ assert response == self.JOB_ARN
+ assert bedrock_hook.get_waiter.call_count == wait_for_completion
+ assert self.operator.defer.call_count == deferrable
+
+ def test_template_fields(self):
+ validate_template_fields(self.operator)
diff --git a/providers/amazon/tests/unit/amazon/aws/sensors/test_bedrock.py
b/providers/amazon/tests/unit/amazon/aws/sensors/test_bedrock.py
index 52151cd9539..dc5ce64c4ad 100644
--- a/providers/amazon/tests/unit/amazon/aws/sensors/test_bedrock.py
+++ b/providers/amazon/tests/unit/amazon/aws/sensors/test_bedrock.py
@@ -21,9 +21,10 @@ from unittest import mock
import pytest
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook,
BedrockHook
from airflow.providers.amazon.aws.sensors.bedrock import (
+ BedrockBatchInferenceSensor,
BedrockCustomizeModelCompletedSensor,
BedrockIngestionJobSensor,
BedrockKnowledgeBaseActiveSensor,
@@ -244,3 +245,95 @@ class TestBedrockIngestionJobSensor:
sensor = self.SENSOR(**self.default_op_kwargs, aws_conn_id=None)
with pytest.raises(AirflowException, match=sensor.FAILURE_MESSAGE):
sensor.poke({})
+
+
class TestBedrockBatchInferenceSensor:
    """Unit tests for BedrockBatchInferenceSensor in both success-state modes."""

    SENSOR = BedrockBatchInferenceSensor

    @pytest.fixture(
        params=[
            BedrockBatchInferenceSensor.SuccessState.COMPLETED,
            BedrockBatchInferenceSensor.SuccessState.SCHEDULED,
        ]
    )
    def success_state(self, request):
        """Parametrize tests over both supported success states."""
        return request.param

    @pytest.fixture(params=["deferrable", "not deferrable"])
    def is_deferrable(self, request):
        # I did it this way instead of passing True/False purely so the pytest
        # names are more descriptive.
        return request.param == "deferrable"

    def setup_method(self):
        # NOTE: pytest does not inject fixtures into setup_method -- its second
        # positional argument is the test *method* object.  The previous version
        # accepted ``is_deferrable`` here, which silently bound ``deferrable``
        # to that (always truthy) method object.  ``deferrable`` is now passed
        # explicitly by each test that needs it.
        self.default_op_kwargs = dict(
            task_id="test_bedrock_batch_inference_sensor",
            job_arn="job_arn",
        )

    def test_base_aws_op_attributes(self, success_state):
        """State buckets match the chosen success state; AWS attrs reach the hook."""
        op = self.SENSOR(**self.default_op_kwargs, success_state=success_state)

        if success_state == BedrockBatchInferenceSensor.SuccessState.COMPLETED:
            assert "Scheduled" in op.INTERMEDIATE_STATES
            assert "Scheduled" not in op.SUCCESS_STATES
        elif success_state == BedrockBatchInferenceSensor.SuccessState.SCHEDULED:
            assert "Scheduled" in op.SUCCESS_STATES
            assert "Scheduled" not in op.INTERMEDIATE_STATES
        assert op.hook.aws_conn_id == "aws_default"
        assert op.hook._region_name is None
        assert op.hook._verify is None
        assert op.hook._config is None

        op = self.SENSOR(
            **self.default_op_kwargs,
            aws_conn_id="aws-test-custom-conn",
            region_name="eu-west-1",
            verify=False,
            botocore_config={"read_timeout": 42},
        )
        assert op.hook.aws_conn_id == "aws-test-custom-conn"
        assert op.hook._region_name == "eu-west-1"
        assert op.hook._verify is False
        assert op.hook._config is not None
        assert op.hook._config.read_timeout == 42

    @mock.patch.object(BedrockHook, "conn")
    def test_poke_success_states(self, mock_conn, success_state, is_deferrable):
        """Success states: deferrable sensors defer; others report poke success."""
        op = self.SENSOR(
            **self.default_op_kwargs, success_state=success_state, deferrable=is_deferrable
        )

        for state in op.SUCCESS_STATES:
            mock_conn.get_model_invocation_job.return_value = {"status": state}

            if is_deferrable:
                with pytest.raises(TaskDeferred):
                    op.execute({})
            else:
                assert op.poke({}) is True

    @mock.patch.object(BedrockHook, "conn")
    def test_poke_intermediate_states(self, mock_conn, success_state, is_deferrable):
        """Intermediate states: deferrable sensors defer; others keep poking."""
        op = self.SENSOR(
            **self.default_op_kwargs, success_state=success_state, deferrable=is_deferrable
        )

        for state in op.INTERMEDIATE_STATES:
            mock_conn.get_model_invocation_job.return_value = {"status": state}

            if is_deferrable:
                with pytest.raises(TaskDeferred):
                    op.execute({})
            else:
                assert op.poke({}) is False

    @mock.patch.object(BedrockHook, "conn")
    def test_poke_failure_states(self, mock_conn, success_state, is_deferrable):
        """Failure states: non-deferrable poke raises; deferrable execute defers.

        The previous version had the two branches swapped (it expected
        AirflowException from ``poke`` in the deferrable case and TaskDeferred
        otherwise), which only passed because ``deferrable`` was accidentally
        always truthy (see setup_method note).
        """
        op = self.SENSOR(
            **self.default_op_kwargs, success_state=success_state, deferrable=is_deferrable
        )

        for state in op.FAILURE_STATES:
            mock_conn.get_model_invocation_job.return_value = {"status": state}

            if is_deferrable:
                with pytest.raises(TaskDeferred):
                    op.execute({})
            else:
                with pytest.raises(AirflowException, match=op.FAILURE_MESSAGE):
                    op.poke({})
diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_bedrock.py
b/providers/amazon/tests/unit/amazon/aws/triggers/test_bedrock.py
index be17bb958d0..963f96211e7 100644
--- a/providers/amazon/tests/unit/amazon/aws/triggers/test_bedrock.py
+++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_bedrock.py
@@ -23,6 +23,8 @@ import pytest
from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook,
BedrockHook
from airflow.providers.amazon.aws.triggers.bedrock import (
+ BedrockBatchInferenceCompletedTrigger,
+ BedrockBatchInferenceScheduledTrigger,
BedrockCustomizeModelCompletedTrigger,
BedrockIngestionJobTrigger,
BedrockKnowledgeBaseActiveTrigger,
@@ -169,3 +171,63 @@ class
TestBedrockIngestionJobTrigger(TestBaseBedrockTrigger):
assert_expected_waiter_type(mock_get_waiter, self.EXPECTED_WAITER_NAME)
assert response == TriggerEvent({"status": "success",
"ingestion_job_id": self.INGESTION_JOB_ID})
mock_get_waiter().wait.assert_called_once()
+
+
class TestBedrockBatchInferenceCompletedTrigger(TestBaseBedrockTrigger):
    EXPECTED_WAITER_NAME = "batch_inference_complete"

    JOB_ARN = "job_arn"

    def test_serialization(self):
        """Classpath and keyword arguments round-trip through serialize()."""
        classpath, kwargs = BedrockBatchInferenceCompletedTrigger(job_arn=self.JOB_ARN).serialize()

        assert classpath == BASE_TRIGGER_CLASSPATH + "BedrockBatchInferenceCompletedTrigger"
        assert kwargs.get("job_arn") == self.JOB_ARN

    @pytest.mark.asyncio
    @mock.patch.object(BedrockHook, "get_waiter")
    @mock.patch.object(BedrockHook, "get_async_conn")
    async def test_run_success(self, mock_async_conn, mock_get_waiter):
        """The trigger fires a success event after the async waiter returns."""
        mock_async_conn.__aenter__.return_value = mock.MagicMock()
        mock_get_waiter().wait = AsyncMock()

        trigger = BedrockBatchInferenceCompletedTrigger(job_arn=self.JOB_ARN)
        response = await trigger.run().asend(None)

        assert_expected_waiter_type(mock_get_waiter, self.EXPECTED_WAITER_NAME)
        assert response == TriggerEvent({"status": "success", "job_arn": self.JOB_ARN})
        mock_get_waiter().wait.assert_called_once()
+
+
class TestBedrockBatchInferenceScheduledTrigger(TestBaseBedrockTrigger):
    EXPECTED_WAITER_NAME = "batch_inference_scheduled"

    JOB_ARN = "job_arn"

    def test_serialization(self):
        """Classpath and keyword arguments round-trip through serialize()."""
        classpath, kwargs = BedrockBatchInferenceScheduledTrigger(job_arn=self.JOB_ARN).serialize()

        assert classpath == BASE_TRIGGER_CLASSPATH + "BedrockBatchInferenceScheduledTrigger"
        assert kwargs.get("job_arn") == self.JOB_ARN

    @pytest.mark.asyncio
    @mock.patch.object(BedrockHook, "get_waiter")
    @mock.patch.object(BedrockHook, "get_async_conn")
    async def test_run_success(self, mock_async_conn, mock_get_waiter):
        """The trigger fires a success event after the async waiter returns."""
        mock_async_conn.__aenter__.return_value = mock.MagicMock()
        mock_get_waiter().wait = AsyncMock()

        trigger = BedrockBatchInferenceScheduledTrigger(job_arn=self.JOB_ARN)
        response = await trigger.run().asend(None)

        assert_expected_waiter_type(mock_get_waiter, self.EXPECTED_WAITER_NAME)
        assert response == TriggerEvent({"status": "success", "job_arn": self.JOB_ARN})
        mock_get_waiter().wait.assert_called_once()
diff --git a/providers/amazon/tests/unit/amazon/aws/waiters/test_bedrock.py
b/providers/amazon/tests/unit/amazon/aws/waiters/test_bedrock.py
index 3d8a3a1af1f..3214e662493 100644
--- a/providers/amazon/tests/unit/amazon/aws/waiters/test_bedrock.py
+++ b/providers/amazon/tests/unit/amazon/aws/waiters/test_bedrock.py
@@ -17,6 +17,9 @@
from __future__ import annotations
+import inspect
+import re
+import sys
from unittest import mock
import boto3
@@ -25,6 +28,7 @@ import pytest
from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook
from airflow.providers.amazon.aws.sensors.bedrock import (
+ BedrockBatchInferenceSensor,
BedrockCustomizeModelCompletedSensor,
BedrockProvisionModelThroughputCompletedSensor,
)
@@ -32,8 +36,17 @@ from airflow.providers.amazon.aws.sensors.bedrock import (
class TestBedrockCustomWaiters:
    def test_service_waiters(self):
        """Ensure that all custom Bedrock waiters have unit tests."""

        def _class_tests_a_waiter(class_name: str) -> bool:
            """Check if the class name starts with 'Test' and ends with 'Waiter'."""
            return bool(re.match(r"^Test[A-Za-z]+Waiter$", class_name))

        # Every class in this module named Test...Waiter declares (via WAITER_NAME)
        # which custom waiter it covers.
        covered = sorted(
            cls.WAITER_NAME
            for name, cls in inspect.getmembers(sys.modules[__name__], inspect.isclass)
            if _class_tests_a_waiter(name)
        )

        assert covered == sorted(BedrockHook()._list_custom_waiters())
class TestBedrockCustomWaitersBase:
@@ -103,3 +116,75 @@ class
TestProvisionedModelThroughputCompleteWaiter(TestBedrockCustomWaitersBase)
BedrockHook().get_waiter(self.WAITER_NAME).wait(
jobIdentifier="job_id", WaiterConfig={"Delay": 0.01,
"MaxAttempts": 3}
)
+
+
class TestBatchInferenceCompleteWaiter(TestBedrockCustomWaitersBase):
    WAITER_NAME = "batch_inference_complete"
    SENSOR = BedrockBatchInferenceSensor(
        task_id="task_id",
        job_arn="job_arn",
        success_state=BedrockBatchInferenceSensor.SuccessState.COMPLETED,
    )

    @pytest.fixture
    def mock_get_job(self):
        """Patch the boto client's job-status call used by the waiter."""
        with mock.patch.object(self.client, "get_model_invocation_job") as getter:
            yield getter

    @pytest.mark.parametrize("state", SENSOR.SUCCESS_STATES)
    def test_batch_inference_complete(self, state, mock_get_job):
        """The waiter returns as soon as the job reports a success state."""
        mock_get_job.return_value = {"status": state}
        BedrockHook().get_waiter(self.WAITER_NAME).wait(jobIdentifier="job_arn")

    @pytest.mark.parametrize("state", SENSOR.FAILURE_STATES)
    def test_batch_inference_failed(self, state, mock_get_job):
        """A terminal failure state raises WaiterError."""
        mock_get_job.return_value = {"status": state}
        with pytest.raises(botocore.exceptions.WaiterError):
            BedrockHook().get_waiter(self.WAITER_NAME).wait(jobIdentifier="job_arn")

    def test_batch_inference_wait(self, mock_get_job):
        """Intermediate states are polled until a success state appears."""
        in_progress = {"status": "InProgress"}
        mock_get_job.side_effect = [in_progress, in_progress, {"status": "Completed"}]
        BedrockHook().get_waiter(self.WAITER_NAME).wait(
            jobIdentifier="job_arn", WaiterConfig={"Delay": 0.01, "MaxAttempts": 3}
        )
+
+
class TestBatchInferenceScheduledWaiter(TestBedrockCustomWaitersBase):
    WAITER_NAME = "batch_inference_scheduled"
    SENSOR = BedrockBatchInferenceSensor(
        task_id="task_id",
        job_arn="job_arn",
        success_state=BedrockBatchInferenceSensor.SuccessState.SCHEDULED,
    )

    @pytest.fixture
    def mock_get_job(self):
        """Patch the boto client's job-status call used by the waiter."""
        with mock.patch.object(self.client, "get_model_invocation_job") as getter:
            yield getter

    @pytest.mark.parametrize("state", SENSOR.SUCCESS_STATES)
    def test_batch_inference_complete(self, state, mock_get_job):
        """The waiter returns as soon as the job reports a success state."""
        mock_get_job.return_value = {"status": state}
        BedrockHook().get_waiter(self.WAITER_NAME).wait(jobIdentifier="job_arn")

    @pytest.mark.parametrize("state", SENSOR.FAILURE_STATES)
    def test_batch_inference_failed(self, state, mock_get_job):
        """A terminal failure state raises WaiterError."""
        mock_get_job.return_value = {"status": state}
        with pytest.raises(botocore.exceptions.WaiterError):
            BedrockHook().get_waiter(self.WAITER_NAME).wait(jobIdentifier="job_arn")

    def test_batch_inference_wait(self, mock_get_job):
        """Intermediate states are polled until the waiter succeeds."""
        # NOTE(review): this terminates on "Completed"; presumably the scheduled
        # waiter also treats "Completed" as success -- confirm against the
        # acceptor list in waiters/bedrock.json.
        in_progress = {"status": "InProgress"}
        mock_get_job.side_effect = [in_progress, in_progress, {"status": "Completed"}]
        BedrockHook().get_waiter(self.WAITER_NAME).wait(
            jobIdentifier="job_arn", WaiterConfig={"Delay": 0.01, "MaxAttempts": 3}
        )