This is an automated email from the ASF dual-hosted git repository.

onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7bd1aa6bc0f Add Bedrock Batch Inference Operator and accompanying parts (#48468)
7bd1aa6bc0f is described below

commit 7bd1aa6bc0f0f4d6703d21e3accb87a8ed9e614c
Author: D. Ferruzzi <[email protected]>
AuthorDate: Mon Apr 7 12:47:30 2025 -0700

    Add Bedrock Batch Inference Operator and accompanying parts (#48468)
    
    Add a new operator for Bedrock Batch Inference Jobs along with all the sensor/waiter/trigger/etc support, docs, and testing.
---
 providers/amazon/docs/operators/bedrock.rst        |  49 ++++++
 .../providers/amazon/aws/operators/bedrock.py      | 119 ++++++++++++++
 .../providers/amazon/aws/sensors/bedrock.py        | 110 +++++++++++++
 .../providers/amazon/aws/triggers/bedrock.py       |  98 +++++++++++
 .../providers/amazon/aws/waiters/bedrock.json      | 134 +++++++++++++++
 .../amazon/aws/example_bedrock_batch_inference.py  | 182 +++++++++++++++++++++
 .../unit/amazon/aws/operators/test_bedrock.py      |  60 +++++++
 .../tests/unit/amazon/aws/sensors/test_bedrock.py  |  95 ++++++++++-
 .../tests/unit/amazon/aws/triggers/test_bedrock.py |  62 +++++++
 .../tests/unit/amazon/aws/waiters/test_bedrock.py  |  89 +++++++++-
 10 files changed, 995 insertions(+), 3 deletions(-)

diff --git a/providers/amazon/docs/operators/bedrock.rst b/providers/amazon/docs/operators/bedrock.rst
index 9e110856b7e..f166a10c2dd 100644
--- a/providers/amazon/docs/operators/bedrock.rst
+++ b/providers/amazon/docs/operators/bedrock.rst
@@ -74,6 +74,14 @@ To invoke a Claude V2 model using the Completions API you would use:
     :start-after: [START howto_operator_invoke_claude_model]
     :end-before: [END howto_operator_invoke_claude_model]
 
+To invoke a Claude V3 Sonnet model using the Messages API you would use:
+
+.. exampleinclude:: /../../amazon/tests/system/amazon/aws/example_bedrock_batch_inference.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_invoke_claude_messages]
+    :end-before: [END howto_operator_invoke_claude_messages]
+
 
 .. _howto/operator:BedrockCustomizeModelOperator:
 
@@ -237,6 +245,29 @@ Example using a PDF file in an Amazon S3 Bucket:
     :start-after: [START howto_operator_bedrock_external_sources_rag]
     :end-before: [END howto_operator_bedrock_external_sources_rag]
 
+.. _howto/operator:BedrockBatchInferenceOperator:
+
+Create an Amazon Bedrock Batch Inference Job
+============================================
+
+To create a batch inference job that invokes a model on multiple prompts, you can use
+:class:`~airflow.providers.amazon.aws.operators.bedrock.BedrockBatchInferenceOperator`.
+
+The input must be in JSONL format and uploaded to an Amazon S3 bucket.  See
+https://docs.aws.amazon.com/bedrock/latest/userguide/batch-inference.html for details.
+
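+A minimal sketch of a single input record (one JSON object per line; the field
+names follow the Bedrock batch inference guide linked above, and the ``recordId``
+value here is illustrative since Bedrock generates one if it is omitted):
+
+.. code-block:: python
+
+    import json
+
+    # One record of the jsonl input file: "recordId" plus the model-native request body.
+    record = {
+        "recordId": "RECORD0000001",
+        "modelInput": {
+            "anthropic_version": "bedrock-2023-05-31",
+            "max_tokens": 1000,
+            "messages": [{"role": "user", "content": "What color is 7?"}],
+        },
+    }
+    print(json.dumps(record))  # write one such line per prompt
+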
+NOTE: Jobs are added to a queue and processed in order.  Given the potential wait times,
+and the fact that the optional timeout parameter is measured in hours, deferrable mode is
+recommended over "wait_for_completion" in this case.
+
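+As a sketch of that recommendation, the operator in deferrable mode (the role ARN,
+bucket names, and job name below are illustrative):
+
+.. code-block:: python
+
+    batch_infer = BedrockBatchInferenceOperator(
+        task_id="batch_infer",
+        job_name="my-batch-job",
+        role_arn="arn:aws:iam::123456789012:role/example-role",
+        model_id="anthropic.claude-3-sonnet-20240229-v1:0",
+        input_uri="s3://my-bucket/prompts.jsonl",
+        output_uri="s3://my-bucket/output/",
+        deferrable=True,
+    )
+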
+Example using an Amazon Bedrock Batch Inference Job:
+
+.. exampleinclude:: /../../amazon/tests/system/amazon/aws/example_bedrock_batch_inference.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_bedrock_batch_inference]
+    :end-before: [END howto_operator_bedrock_batch_inference]
+
 
 Sensors
 -------
@@ -298,6 +329,24 @@ To wait on the state of an Amazon Bedrock data ingestion job until it reaches a
     :start-after: [START howto_sensor_bedrock_ingest_data]
     :end-before: [END howto_sensor_bedrock_ingest_data]
 
+.. _howto/sensor:BedrockBatchInferenceSensor:
+
+Wait for an Amazon Bedrock batch inference job
+==============================================
+
+To wait on the state of an Amazon Bedrock batch inference job until it reaches the "Scheduled" or "Completed"
+state you can use :class:`~airflow.providers.amazon.aws.sensors.bedrock.BedrockBatchInferenceSensor`.
+
+Bedrock adds batch inference jobs to a queue, and they may take some time to complete.  If you want to wait
+for the job to complete, use SuccessState.COMPLETED for the success_state, but if you only want to wait until
+the service confirms that the job is in the queue, use SuccessState.SCHEDULED.
+
+.. exampleinclude:: /../../amazon/tests/system/amazon/aws/example_bedrock_batch_inference.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_sensor_bedrock_batch_inference_scheduled]
+    :end-before: [END howto_sensor_bedrock_batch_inference_scheduled]
+
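+If you need the job's results rather than just confirmation that it was queued, a
+sketch of the same sensor waiting for full completion (assuming a ``batch_infer``
+task as in the operator example above; the deferrable setting is optional):
+
+.. code-block:: python
+
+    await_job_completed = BedrockBatchInferenceSensor(
+        task_id="await_job_completed",
+        job_arn=batch_infer.output,
+        success_state=BedrockBatchInferenceSensor.SuccessState.COMPLETED,
+        deferrable=True,  # recommended, since queued jobs can take a long time
+    )
+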
 Reference
 ---------
 
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/bedrock.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/bedrock.py
index ac8e86b92c8..f9715114b74 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/bedrock.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/bedrock.py
@@ -33,6 +33,7 @@ from airflow.providers.amazon.aws.hooks.bedrock import (
 )
 from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
 from airflow.providers.amazon.aws.triggers.bedrock import (
+    BedrockBatchInferenceCompletedTrigger,
     BedrockCustomizeModelCompletedTrigger,
     BedrockIngestionJobTrigger,
     BedrockKnowledgeBaseActiveTrigger,
@@ -869,3 +870,121 @@ class BedrockRetrieveOperator(AwsBaseOperator[BedrockAgentRuntimeHook]):
 
         self.log.info("\nQuery: %s\nRetrieved: %s", self.retrieval_query, result["retrievalResults"])
         return result
+
+
+class BedrockBatchInferenceOperator(AwsBaseOperator[BedrockHook]):
+    """
+    Create a batch inference job to invoke a model on multiple prompts.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:BedrockBatchInferenceOperator`
+
+    :param job_name: A name to give the batch inference job. (templated)
+    :param role_arn: The ARN of the IAM role with permissions to run the batch inference job. (templated)
+    :param model_id: Name or ARN of the model to use for the batch inference job. (templated)
+    :param input_uri: The S3 location of the input data. (templated)
+    :param output_uri: The S3 location of the output data. (templated)
+    :param invoke_kwargs: Additional keyword arguments to pass to the API call. (templated)
+
+    :param wait_for_completion: Whether to wait for the job to complete. (default: True)
+        NOTE: Batch inference jobs are added to a queue and completed "eventually",
+        so using deferrable mode is much more practical than using wait_for_completion.
+    :param waiter_delay: Time in seconds to wait between status checks. (default: 60)
+    :param waiter_max_attempts: Maximum number of attempts to check for job completion. (default: 10)
+    :param deferrable: If True, the operator will wait asynchronously for the job to complete.
+        This implies waiting for completion. This mode requires aiobotocore module to be installed.
+        (default: False)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    aws_hook_class = BedrockHook
+    template_fields: Sequence[str] = aws_template_fields(
+        "job_name",
+        "role_arn",
+        "model_id",
+        "input_uri",
+        "output_uri",
+        "invoke_kwargs",
+    )
+
+    def __init__(
+        self,
+        job_name: str,
+        role_arn: str,
+        model_id: str,
+        input_uri: str,
+        output_uri: str,
+        invoke_kwargs: dict[str, Any] | None = None,
+        wait_for_completion: bool = True,
+        waiter_delay: int = 60,
+        waiter_max_attempts: int = 10,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.job_name = job_name
+        self.role_arn = role_arn
+        self.model_id = model_id
+        self.input_uri = input_uri
+        self.output_uri = output_uri
+        self.invoke_kwargs = invoke_kwargs or {}
+
+        self.wait_for_completion = wait_for_completion
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.deferrable = deferrable
+
+        self.activity = "Bedrock batch inference job"
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        validated_event = validate_execute_complete_event(event)
+
+        if validated_event["status"] != "success":
+            raise AirflowException(f"Error while running {self.activity}: {validated_event}")
+
+        self.log.info("%s '%s' complete.", self.activity, validated_event["job_arn"])
+
+        return validated_event["job_arn"]
+
+    def execute(self, context: Context) -> str:
+        response = self.hook.conn.create_model_invocation_job(
+            jobName=self.job_name,
+            roleArn=self.role_arn,
+            modelId=self.model_id,
+            inputDataConfig={"s3InputDataConfig": {"s3Uri": self.input_uri}},
+            outputDataConfig={"s3OutputDataConfig": {"s3Uri": self.output_uri}},
+            **self.invoke_kwargs,
+        )
+        job_arn = response["jobArn"]
+        self.log.info("%s '%s' started with ARN: %s", self.activity, self.job_name, job_arn)
+
+        task_description = f"for {self.activity} '{self.job_name}' to complete."
+        if self.deferrable:
+            self.log.info("Deferring %s", task_description)
+            self.defer(
+                trigger=BedrockBatchInferenceCompletedTrigger(
+                    job_arn=job_arn,
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
+        elif self.wait_for_completion:
+            self.log.info("Waiting %s", task_description)
+            self.hook.get_waiter(waiter_name="batch_inference_complete").wait(
+                jobIdentifier=job_arn,
+                WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
+            )
+
+        return job_arn
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/bedrock.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/bedrock.py
index 541702f4774..c74c3e772fb 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/bedrock.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/bedrock.py
@@ -26,6 +26,8 @@ from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook, BedrockHook
 from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
 from airflow.providers.amazon.aws.triggers.bedrock import (
+    BedrockBatchInferenceCompletedTrigger,
+    BedrockBatchInferenceScheduledTrigger,
     BedrockCustomizeModelCompletedTrigger,
     BedrockIngestionJobTrigger,
     BedrockKnowledgeBaseActiveTrigger,
@@ -34,6 +36,7 @@ from airflow.providers.amazon.aws.triggers.bedrock import (
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
+    from airflow.providers.amazon.aws.triggers.bedrock import BedrockBaseBatchInferenceTrigger
     from airflow.utils.context import Context
 
 
@@ -368,3 +371,110 @@ class BedrockIngestionJobSensor(BedrockBaseSensor[BedrockAgentHook]):
             )
         else:
             super().execute(context=context)
+
+
+class BedrockBatchInferenceSensor(BedrockBaseSensor[BedrockHook]):
+    """
+    Poll the batch inference job status until it reaches a terminal state; fails if creation fails.
+
+    .. seealso::
+        For more information on how to use this sensor, take a look at the guide:
+        :ref:`howto/sensor:BedrockBatchInferenceSensor`
+
+    :param job_arn: The Amazon Resource Name (ARN) of the batch inference job. (templated)
+    :param success_state: A BedrockBatchInferenceSensor.SuccessState; defaults to 'SCHEDULED' (templated)
+
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param poke_interval: Polling period in seconds to check for the status of the job. (default: 120)
+    :param max_retries: Number of times before returning the current state. (default: 75)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    class SuccessState:
+        """
+        Success state for the BedrockBatchInferenceSensor.
+
+        Bedrock adds batch inference jobs to a queue, and they may take some time to complete.
+        If you want to wait for the job to complete, use SuccessState.COMPLETED, but if you only want
+        to wait until the service confirms that the job is in the queue, use SuccessState.SCHEDULED.
+
+        The normal successful progression of states is:
+            Submitted > Validating > Scheduled > InProgress > PartiallyCompleted > Completed
+        """
+
+        SCHEDULED = "scheduled"
+        COMPLETED = "completed"
+
+    INTERMEDIATE_STATES: tuple[str, ...]  # Defined in __init__ based on the success state
+    FAILURE_STATES: tuple[str, ...] = ("Failed", "Stopped", "PartiallyCompleted", "Expired")
+    SUCCESS_STATES: tuple[str, ...]  # Defined in __init__ based on the success state
+    FAILURE_MESSAGE = "Bedrock batch inference job sensor failed."
+    INVALID_SUCCESS_STATE_MESSAGE = "success_state must be an instance of SuccessState."
+
+    aws_hook_class = BedrockHook
+
+    template_fields: Sequence[str] = aws_template_fields("job_arn", "success_state")
+
+    def __init__(
+        self,
+        *,
+        job_arn: str,
+        success_state: SuccessState | str = SuccessState.SCHEDULED,
+        poke_interval: int = 120,
+        max_retries: int = 75,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.poke_interval = poke_interval
+        self.max_retries = max_retries
+        self.job_arn = job_arn
+        self.success_state = success_state
+
+        base_success_states: tuple[str, ...] = ("Completed",)
+        base_intermediate_states: tuple[str, ...] = ("Submitted", "InProgress", "Stopping", "Validating")
+        scheduled_state = ("Scheduled",)
+        self.trigger_class: type[BedrockBaseBatchInferenceTrigger]
+
+        if self.success_state == BedrockBatchInferenceSensor.SuccessState.COMPLETED:
+            intermediate_states = base_intermediate_states + scheduled_state
+            success_states = base_success_states
+            self.trigger_class = BedrockBatchInferenceCompletedTrigger
+        elif self.success_state == BedrockBatchInferenceSensor.SuccessState.SCHEDULED:
+            intermediate_states = base_intermediate_states
+            success_states = base_success_states + scheduled_state
+            self.trigger_class = BedrockBatchInferenceScheduledTrigger
+        else:
+            raise ValueError(self.INVALID_SUCCESS_STATE_MESSAGE)
+
+        BedrockBatchInferenceSensor.INTERMEDIATE_STATES = intermediate_states or base_intermediate_states
+        BedrockBatchInferenceSensor.SUCCESS_STATES = success_states or base_success_states
+
+    def get_state(self) -> str:
+        return self.hook.conn.get_model_invocation_job(jobIdentifier=self.job_arn)["status"]
+
+    def execute(self, context: Context) -> Any:
+        if self.deferrable:
+            self.defer(
+                trigger=self.trigger_class(
+                    job_arn=self.job_arn,
+                    waiter_delay=int(self.poke_interval),
+                    waiter_max_attempts=self.max_retries,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="poke",
+            )
+        else:
+            super().execute(context=context)
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/bedrock.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/bedrock.py
index 99d632d26e0..faac0f90753 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/bedrock.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/bedrock.py
@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING
 
 from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook, BedrockHook
 from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
+from airflow.utils.types import NOTSET, ArgNotSet
 
 if TYPE_CHECKING:
     from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
@@ -182,3 +183,100 @@ class BedrockIngestionJobTrigger(AwsBaseWaiterTrigger):
 
     def hook(self) -> AwsGenericHook:
         return BedrockAgentHook(aws_conn_id=self.aws_conn_id)
+
+
+class BedrockBaseBatchInferenceTrigger(AwsBaseWaiterTrigger):
+    """
+    Base class for Bedrock batch inference job triggers.
+
+    :param job_arn: The Amazon Resource Name (ARN) of the batch inference job.
+
+    :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 120)
+    :param waiter_max_attempts: The maximum number of attempts to be made. (default: 75)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        *,
+        job_arn: str,
+        waiter_name: str | ArgNotSet = NOTSET,  # This must be defined in the child class.
+        waiter_delay: int = 120,
+        waiter_max_attempts: int = 75,
+        aws_conn_id: str | None = None,
+    ) -> None:
+        if waiter_name == NOTSET:
+            raise NotImplementedError("Triggers must provide a waiter name.")
+
+        super().__init__(
+            serialized_fields={"job_arn": job_arn},
+            waiter_name=str(waiter_name),  # str() narrows the str | ArgNotSet type for mypy
+            waiter_args={"jobIdentifier": job_arn},
+            failure_message="Bedrock batch inference job failed.",
+            status_message="Status of Bedrock batch inference job is",
+            status_queries=["status"],
+            return_key="job_arn",
+            return_value=job_arn,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return BedrockHook(aws_conn_id=self.aws_conn_id)
+
+
+class BedrockBatchInferenceCompletedTrigger(BedrockBaseBatchInferenceTrigger):
+    """
+    Trigger when a batch inference job is complete.
+
+    :param job_arn: The Amazon Resource Name (ARN) of the batch inference job.
+
+    :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 120)
+    :param waiter_max_attempts: The maximum number of attempts to be made. (default: 75)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        *,
+        job_arn: str,
+        waiter_delay: int = 120,
+        waiter_max_attempts: int = 75,
+        aws_conn_id: str | None = None,
+    ) -> None:
+        super().__init__(
+            waiter_name="batch_inference_complete",
+            job_arn=job_arn,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+        )
+
+
+class BedrockBatchInferenceScheduledTrigger(BedrockBaseBatchInferenceTrigger):
+    """
+    Trigger when a batch inference job is scheduled.
+
+    :param job_arn: The Amazon Resource Name (ARN) of the batch inference job.
+
+    :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 120)
+    :param waiter_max_attempts: The maximum number of attempts to be made. (default: 75)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        *,
+        job_arn: str,
+        waiter_delay: int = 120,
+        waiter_max_attempts: int = 75,
+        aws_conn_id: str | None = None,
+    ) -> None:
+        super().__init__(
+            waiter_name="batch_inference_scheduled",
+            job_arn=job_arn,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+        )
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/waiters/bedrock.json b/providers/amazon/src/airflow/providers/amazon/aws/waiters/bedrock.json
index c913b4dc7c6..18a6294a128 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/waiters/bedrock.json
+++ b/providers/amazon/src/airflow/providers/amazon/aws/waiters/bedrock.json
@@ -68,6 +68,140 @@
                     "state": "failure"
                 }
             ]
+        },
+        "batch_inference_complete": {
+            "delay": 120,
+            "maxAttempts": 75,
+            "operation": "GetModelInvocationJob",
+            "acceptors": [
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "Completed",
+                    "state": "success"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "Failed",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "Stopped",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "PartiallyCompleted",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "Expired",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "Stopping",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "Submitted",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "InProgress",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "Validating",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "Scheduled",
+                    "state": "retry"
+                }
+            ]
+        },
+        "batch_inference_scheduled": {
+            "delay": 120,
+            "maxAttempts": 75,
+            "operation": "GetModelInvocationJob",
+            "acceptors": [
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "Completed",
+                    "state": "success"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "Failed",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "Stopped",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "Stopping",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "PartiallyCompleted",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "Expired",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "Submitted",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "InProgress",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "Validating",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "Scheduled",
+                    "state": "success"
+                }
+            ]
         }
     }
 }
diff --git a/providers/amazon/tests/system/amazon/aws/example_bedrock_batch_inference.py b/providers/amazon/tests/system/amazon/aws/example_bedrock_batch_inference.py
new file mode 100644
index 00000000000..ed0da3b0aea
--- /dev/null
+++ b/providers/amazon/tests/system/amazon/aws/example_bedrock_batch_inference.py
@@ -0,0 +1,182 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import json
+import logging
+from datetime import datetime
+from tempfile import NamedTemporaryFile
+
+from airflow.decorators import task
+from airflow.models.baseoperator import chain
+from airflow.models.dag import DAG
+from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook
+from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.amazon.aws.operators.bedrock import (
+    BedrockBatchInferenceOperator,
+    BedrockInvokeModelOperator,
+)
+from airflow.providers.amazon.aws.operators.s3 import (
+    S3CreateBucketOperator,
+    S3DeleteBucketOperator,
+)
+from airflow.providers.amazon.aws.sensors.bedrock import BedrockBatchInferenceSensor
+from airflow.utils.trigger_rule import TriggerRule
+
+from system.amazon.aws.utils import SystemTestContextBuilder
+
+log = logging.getLogger(__name__)
+
+
+# Externally fetched variables:
+ROLE_ARN_KEY = "ROLE_ARN"
+sys_test_context_task = SystemTestContextBuilder().add_variable(ROLE_ARN_KEY).build()
+
+DAG_ID = "example_bedrock_batch_inference"
+
+#######################################################################
+# NOTE:
+#   Access to the following foundation model must be requested via
+#   the Amazon Bedrock console and may take up to 24 hours to apply:
+#######################################################################
+
+CLAUDE_MODEL_ID = "anthropic.claude-3-sonnet-20240229-v1:0"
+ANTHROPIC_VERSION = "bedrock-2023-05-31"
+
+# Batch inference currently requires a minimum of 100 prompts per batch; we send 300.
+MIN_NUM_PROMPTS = 300
+PROMPT_TEMPLATE = "Even numbers are red. Odd numbers are blue. What color is {n}?"
+
+
+@task
+def generate_prompts(_env_id: str, _bucket: str, _key: str):
+    """
+    Bedrock Batch Inference requires one or more jsonl-formatted files in an S3 bucket.
+
+    The JSONL format requires one serialized json object per prompt per line.
+    """
+    with NamedTemporaryFile(mode="w") as tmp_file:
+        # Generate the required number of prompts.
+        prompts = [
+            {
+                "modelInput": {
+                    "anthropic_version": ANTHROPIC_VERSION,
+                    "max_tokens": 1000,
+                    "messages": [{"role": "user", "content": PROMPT_TEMPLATE.format(n=n)}],
+                },
+            }
+            for n in range(MIN_NUM_PROMPTS)
+        ]
+
+        # Convert each prompt to serialized json, append a newline, and write that line to the temp file.
+        tmp_file.writelines(json.dumps(prompt) + "\n" for prompt in prompts)
+        # Flush so the full contents are on disk before the upload below.
+        tmp_file.flush()
+
+        # Upload the file to S3.
+        S3Hook().conn.upload_file(tmp_file.name, _bucket, _key)
+
+
+@task(trigger_rule=TriggerRule.ALL_DONE)
+def stop_batch_inference(job_arn: str):
+    log.info("Stopping Batch Inference Job.")
+    BedrockHook().conn.stop_model_invocation_job(jobIdentifier=job_arn)
+
+
+with DAG(
+    dag_id=DAG_ID,
+    schedule="@once",
+    start_date=datetime(2021, 1, 1),
+    tags=["example"],
+    catchup=False,
+) as dag:
+    test_context = sys_test_context_task()
+    env_id = test_context["ENV_ID"]
+
+    bucket_name = f"{env_id}-bedrock"
+    input_data_s3_key = f"{env_id}/prompt_list.jsonl"
+    input_uri = f"s3://{bucket_name}/{input_data_s3_key}"
+    output_uri = f"s3://{bucket_name}/output/"
+    job_name = f"batch-infer-{env_id}"
+
+    # Test that this configuration works for a single prompt before trying the batch inferences.
+    # [START howto_operator_invoke_claude_messages]
+    invoke_claude_messages = BedrockInvokeModelOperator(
+        task_id="invoke_claude_messages",
+        model_id=CLAUDE_MODEL_ID,
+        input_data={
+            "anthropic_version": "bedrock-2023-05-31",
+            "max_tokens": 1000,
+            "messages": [{"role": "user", "content": PROMPT_TEMPLATE.format(n=42)}],
+        },
+    )
+    # [END howto_operator_invoke_claude_messages]
+
+    create_bucket = S3CreateBucketOperator(task_id="create_bucket", bucket_name=bucket_name)
+
+    # [START howto_operator_bedrock_batch_inference]
+    batch_infer = BedrockBatchInferenceOperator(
+        task_id="batch_infer",
+        job_name=job_name,
+        role_arn=test_context[ROLE_ARN_KEY],
+        model_id=CLAUDE_MODEL_ID,
+        input_uri=input_uri,
+        output_uri=output_uri,
+    )
+    # [END howto_operator_bedrock_batch_inference]
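+    # The system test only needs the job to reach the queue (verified by the sensor
+    # below), so skip the operator's own wait for completion.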
+    batch_infer.wait_for_completion = False
+
+    # [START howto_sensor_bedrock_batch_inference_scheduled]
+    await_job_scheduled = BedrockBatchInferenceSensor(
+        task_id="await_job_scheduled",
+        job_arn=batch_infer.output,
+        success_state=BedrockBatchInferenceSensor.SuccessState.SCHEDULED,
+    )
+    # [END howto_sensor_bedrock_batch_inference_scheduled]
+
+    stop_job = stop_batch_inference(batch_infer.output)
+
+    delete_bucket = S3DeleteBucketOperator(
+        task_id="delete_bucket",
+        trigger_rule=TriggerRule.ALL_DONE,
+        bucket_name=bucket_name,
+        force_delete=True,
+    )
+
+    chain(
+        # TEST SETUP
+        test_context,
+        invoke_claude_messages,
+        create_bucket,
+        generate_prompts(env_id, bucket_name, input_data_s3_key),
+        # TEST BODY
+        batch_infer,
+        await_job_scheduled,
+        stop_job,
+        # TEST TEARDOWN
+        delete_bucket,
+    )
+
+    from tests_common.test_utils.watcher import watcher
+
+    # This test needs watcher in order to properly mark success/failure
+    # when "tearDown" task with trigger rule is part of the DAG
+    list(dag.tasks) >> watcher()
+
+
+from tests_common.test_utils.system_tests import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_bedrock.py b/providers/amazon/tests/unit/amazon/aws/operators/test_bedrock.py
index 527f0385e60..4591ab97c1b 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_bedrock.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_bedrock.py
@@ -28,6 +28,7 @@ from moto import mock_aws
 
 from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook, BedrockHook, BedrockRuntimeHook
 from airflow.providers.amazon.aws.operators.bedrock import (
+    BedrockBatchInferenceOperator,
     BedrockCreateDataSourceOperator,
     BedrockCreateKnowledgeBaseOperator,
     BedrockCreateProvisionedModelThroughputOperator,
@@ -549,3 +550,62 @@ class TestBedrockRaGOperator:
             vector_search_config=self.VECTOR_SEARCH_CONFIG,
         )
         validate_template_fields(op)
+
+
+class TestBedrockBatchInferenceOperator:
+    JOB_NAME = "job_name"
+    ROLE_ARN = "role_arn"
+    MODEL_ID = "model_id"
+    INPUT_URI = "input_uri"
+    OUTPUT_URI = "output_uri"
+    INVOKE_KWARGS = {"tags": {"key": "key", "value": "value"}}
+
+    JOB_ARN = "job_arn"
+
+    @pytest.fixture
+    def mock_conn(self) -> Generator[BaseAwsConnection, None, None]:
+        with mock.patch.object(BedrockHook, "conn") as _conn:
+            _conn.create_model_invocation_job.return_value = {"jobArn": self.JOB_ARN}
+            yield _conn
+
+    @pytest.fixture
+    def bedrock_hook(self) -> Generator[BedrockHook, None, None]:
+        with mock_aws():
+            hook = BedrockHook(aws_conn_id="aws_default")
+            yield hook
+
+    def setup_method(self):
+        self.operator = BedrockBatchInferenceOperator(
+            task_id="test_task",
+            job_name=self.JOB_NAME,
+            role_arn=self.ROLE_ARN,
+            model_id=self.MODEL_ID,
+            input_uri=self.INPUT_URI,
+            output_uri=self.OUTPUT_URI,
+            invoke_kwargs=self.INVOKE_KWARGS,
+        )
+        self.operator.defer = mock.MagicMock()
+
+    @pytest.mark.parametrize(
+        "wait_for_completion, deferrable",
+        [
+            pytest.param(False, False, id="no_wait"),
+            pytest.param(True, False, id="wait"),
+            pytest.param(False, True, id="defer"),
+        ],
+    )
+    @mock.patch.object(BedrockHook, "get_waiter")
+    def test_batch_inference_wait_combinations(
+        self, _, wait_for_completion, deferrable, mock_conn, bedrock_hook
+    ):
+        self.operator.wait_for_completion = wait_for_completion
+        self.operator.deferrable = deferrable
+
+        response = self.operator.execute({})
+
+        assert response == self.JOB_ARN
+        assert bedrock_hook.get_waiter.call_count == wait_for_completion
+        assert self.operator.defer.call_count == deferrable
+
+    def test_template_fields(self):
+        validate_template_fields(self.operator)
diff --git a/providers/amazon/tests/unit/amazon/aws/sensors/test_bedrock.py b/providers/amazon/tests/unit/amazon/aws/sensors/test_bedrock.py
index 52151cd9539..dc5ce64c4ad 100644
--- a/providers/amazon/tests/unit/amazon/aws/sensors/test_bedrock.py
+++ b/providers/amazon/tests/unit/amazon/aws/sensors/test_bedrock.py
@@ -21,9 +21,10 @@ from unittest import mock
 
 import pytest
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook, BedrockHook
 from airflow.providers.amazon.aws.sensors.bedrock import (
+    BedrockBatchInferenceSensor,
     BedrockCustomizeModelCompletedSensor,
     BedrockIngestionJobSensor,
     BedrockKnowledgeBaseActiveSensor,
@@ -244,3 +245,95 @@ class TestBedrockIngestionJobSensor:
         sensor = self.SENSOR(**self.default_op_kwargs, aws_conn_id=None)
         with pytest.raises(AirflowException, match=sensor.FAILURE_MESSAGE):
             sensor.poke({})
+
+
+class TestBedrockBatchInferenceSensor:
+    SENSOR = BedrockBatchInferenceSensor
+
+    @pytest.fixture(
+        params=[
+            BedrockBatchInferenceSensor.SuccessState.COMPLETED,
+            BedrockBatchInferenceSensor.SuccessState.SCHEDULED,
+        ]
+    )
+    def success_state(self, request):
+        return request.param
+
+    @pytest.fixture(params=["deferrable", "not deferrable"])
+    def is_deferrable(self, request):
+        # I did it this way instead of passing True/False purely so the pytest names are more descriptive.
+        return request.param == "deferrable"
+
+    @pytest.fixture(autouse=True)
+    def setup_test_cases(self, is_deferrable):
+        self.default_op_kwargs = dict(
+            task_id="test_bedrock_batch_inference_sensor",
+            job_arn="job_arn",
+            deferrable=is_deferrable,
+        )
+
+    def test_base_aws_op_attributes(self, success_state):
+        op = self.SENSOR(**self.default_op_kwargs, success_state=success_state)
+
+        if success_state == BedrockBatchInferenceSensor.SuccessState.COMPLETED:
+            assert "Scheduled" in op.INTERMEDIATE_STATES
+            assert "Scheduled" not in op.SUCCESS_STATES
+        elif success_state == BedrockBatchInferenceSensor.SuccessState.SCHEDULED:
+            assert "Scheduled" in op.SUCCESS_STATES
+            assert "Scheduled" not in op.INTERMEDIATE_STATES
+        assert op.hook.aws_conn_id == "aws_default"
+        assert op.hook._region_name is None
+        assert op.hook._verify is None
+        assert op.hook._config is None
+
+        op = self.SENSOR(
+            **self.default_op_kwargs,
+            aws_conn_id="aws-test-custom-conn",
+            region_name="eu-west-1",
+            verify=False,
+            botocore_config={"read_timeout": 42},
+        )
+        assert op.hook.aws_conn_id == "aws-test-custom-conn"
+        assert op.hook._region_name == "eu-west-1"
+        assert op.hook._verify is False
+        assert op.hook._config is not None
+        assert op.hook._config.read_timeout == 42
+
+    @mock.patch.object(BedrockHook, "conn")
+    def test_poke_success_states(self, mock_conn, success_state, is_deferrable):
+        op = self.SENSOR(**self.default_op_kwargs, success_state=success_state)
+
+        for state in op.SUCCESS_STATES:
+            mock_conn.get_model_invocation_job.return_value = {"status": state}
+
+            if is_deferrable:
+                with pytest.raises(TaskDeferred):
+                    op.execute({})
+            else:
+                assert op.poke({}) is True
+
+    @mock.patch.object(BedrockHook, "conn")
+    def test_poke_intermediate_states(self, mock_conn, success_state, is_deferrable):
+        op = self.SENSOR(**self.default_op_kwargs, success_state=success_state)
+
+        for state in op.INTERMEDIATE_STATES:
+            mock_conn.get_model_invocation_job.return_value = {"status": state}
+
+            if is_deferrable:
+                with pytest.raises(TaskDeferred):
+                    op.execute({})
+            else:
+                assert op.poke({}) is False
+
+    @mock.patch.object(BedrockHook, "conn")
+    def test_poke_failure_states(self, mock_conn, success_state, is_deferrable):
+        op = self.SENSOR(**self.default_op_kwargs, success_state=success_state)
+
+        for state in op.FAILURE_STATES:
+            mock_conn.get_model_invocation_job.return_value = {"status": state}
+
+            if is_deferrable:
+                with pytest.raises(TaskDeferred):
+                    op.execute({})
+            else:
+                with pytest.raises(AirflowException, match=op.FAILURE_MESSAGE):
+                    op.poke({})
diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_bedrock.py b/providers/amazon/tests/unit/amazon/aws/triggers/test_bedrock.py
index be17bb958d0..963f96211e7 100644
--- a/providers/amazon/tests/unit/amazon/aws/triggers/test_bedrock.py
+++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_bedrock.py
@@ -23,6 +23,8 @@ import pytest
 
 from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook, BedrockHook
 from airflow.providers.amazon.aws.triggers.bedrock import (
+    BedrockBatchInferenceCompletedTrigger,
+    BedrockBatchInferenceScheduledTrigger,
     BedrockCustomizeModelCompletedTrigger,
     BedrockIngestionJobTrigger,
     BedrockKnowledgeBaseActiveTrigger,
@@ -169,3 +171,63 @@ class TestBedrockIngestionJobTrigger(TestBaseBedrockTrigger):
         assert_expected_waiter_type(mock_get_waiter, self.EXPECTED_WAITER_NAME)
         assert response == TriggerEvent({"status": "success", "ingestion_job_id": self.INGESTION_JOB_ID})
         mock_get_waiter().wait.assert_called_once()
+
+
+class TestBedrockBatchInferenceCompletedTrigger(TestBaseBedrockTrigger):
+    EXPECTED_WAITER_NAME = "batch_inference_complete"
+
+    JOB_ARN = "job_arn"
+
+    def test_serialization(self):
+        """Assert that arguments and classpath are correctly serialized."""
+        trigger = BedrockBatchInferenceCompletedTrigger(job_arn=self.JOB_ARN)
+
+        classpath, kwargs = trigger.serialize()
+
+        assert classpath == BASE_TRIGGER_CLASSPATH + "BedrockBatchInferenceCompletedTrigger"
+        assert kwargs.get("job_arn") == self.JOB_ARN
+
+    @pytest.mark.asyncio
+    @mock.patch.object(BedrockHook, "get_waiter")
+    @mock.patch.object(BedrockHook, "get_async_conn")
+    async def test_run_success(self, mock_async_conn, mock_get_waiter):
+        mock_async_conn.__aenter__.return_value = mock.MagicMock()
+        mock_get_waiter().wait = AsyncMock()
+        trigger = BedrockBatchInferenceCompletedTrigger(job_arn=self.JOB_ARN)
+
+        generator = trigger.run()
+        response = await generator.asend(None)
+
+        assert_expected_waiter_type(mock_get_waiter, self.EXPECTED_WAITER_NAME)
+        assert response == TriggerEvent({"status": "success", "job_arn": self.JOB_ARN})
+        mock_get_waiter().wait.assert_called_once()
+
+
+class TestBedrockBatchInferenceScheduledTrigger(TestBaseBedrockTrigger):
+    EXPECTED_WAITER_NAME = "batch_inference_scheduled"
+
+    JOB_ARN = "job_arn"
+
+    def test_serialization(self):
+        """Assert that arguments and classpath are correctly serialized."""
+        trigger = BedrockBatchInferenceScheduledTrigger(job_arn=self.JOB_ARN)
+
+        classpath, kwargs = trigger.serialize()
+
+        assert classpath == BASE_TRIGGER_CLASSPATH + "BedrockBatchInferenceScheduledTrigger"
+        assert kwargs.get("job_arn") == self.JOB_ARN
+
+    @pytest.mark.asyncio
+    @mock.patch.object(BedrockHook, "get_waiter")
+    @mock.patch.object(BedrockHook, "get_async_conn")
+    async def test_run_success(self, mock_async_conn, mock_get_waiter):
+        mock_async_conn.__aenter__.return_value = mock.MagicMock()
+        mock_get_waiter().wait = AsyncMock()
+        trigger = BedrockBatchInferenceScheduledTrigger(job_arn=self.JOB_ARN)
+
+        generator = trigger.run()
+        response = await generator.asend(None)
+
+        assert_expected_waiter_type(mock_get_waiter, self.EXPECTED_WAITER_NAME)
+        assert response == TriggerEvent({"status": "success", "job_arn": self.JOB_ARN})
+        mock_get_waiter().wait.assert_called_once()
diff --git a/providers/amazon/tests/unit/amazon/aws/waiters/test_bedrock.py b/providers/amazon/tests/unit/amazon/aws/waiters/test_bedrock.py
index 3d8a3a1af1f..3214e662493 100644
--- a/providers/amazon/tests/unit/amazon/aws/waiters/test_bedrock.py
+++ b/providers/amazon/tests/unit/amazon/aws/waiters/test_bedrock.py
@@ -17,6 +17,9 @@
 
 from __future__ import annotations
 
+import inspect
+import re
+import sys
 from unittest import mock
 
 import boto3
@@ -25,6 +28,7 @@ import pytest
 
 from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook
 from airflow.providers.amazon.aws.sensors.bedrock import (
+    BedrockBatchInferenceSensor,
     BedrockCustomizeModelCompletedSensor,
     BedrockProvisionModelThroughputCompletedSensor,
 )
@@ -32,8 +36,17 @@ from airflow.providers.amazon.aws.sensors.bedrock import (
 
 class TestBedrockCustomWaiters:
     def test_service_waiters(self):
-        assert "model_customization_job_complete" in BedrockHook().list_waiters()
-        assert "provisioned_model_throughput_complete" in BedrockHook().list_waiters()
+        """Ensure that all custom Bedrock waiters have unit tests."""
+
+        def _class_tests_a_waiter(class_name: str) -> bool:
+            """Check if the class name starts with 'Test' and ends with 'Waiter'."""
+            return bool(re.match(r"^Test[A-Za-z]+Waiter$", class_name))
+
+        # Collect WAITER_NAME from each waiter test class in this module.
+        test_classes = inspect.getmembers(sys.modules[__name__], inspect.isclass)
+        waiters_tested = [cls.WAITER_NAME for (name, cls) in test_classes if _class_tests_a_waiter(name)]
+
+        assert sorted(BedrockHook()._list_custom_waiters()) == sorted(waiters_tested)
 
 
 class TestBedrockCustomWaitersBase:
@@ -103,3 +116,75 @@ class TestProvisionedModelThroughputCompleteWaiter(TestBedrockCustomWaitersBase)
         BedrockHook().get_waiter(self.WAITER_NAME).wait(
             jobIdentifier="job_id", WaiterConfig={"Delay": 0.01, "MaxAttempts": 3}
         )
+
+
+class TestBatchInferenceCompleteWaiter(TestBedrockCustomWaitersBase):
+    WAITER_NAME = "batch_inference_complete"
+    SENSOR = BedrockBatchInferenceSensor(
+        task_id="task_id",
+        job_arn="job_arn",
+        success_state=BedrockBatchInferenceSensor.SuccessState.COMPLETED,
+    )
+
+    @pytest.fixture
+    def mock_get_job(self):
+        with mock.patch.object(self.client, "get_model_invocation_job") as mock_getter:
+            yield mock_getter
+
+    @pytest.mark.parametrize("state", SENSOR.SUCCESS_STATES)
+    def test_batch_inference_complete(self, state, mock_get_job):
+        mock_get_job.return_value = {"status": state}
+
+        BedrockHook().get_waiter(self.WAITER_NAME).wait(jobIdentifier="job_arn")
+
+    @pytest.mark.parametrize("state", SENSOR.FAILURE_STATES)
+    def test_batch_inference_failed(self, state, mock_get_job):
+        mock_get_job.return_value = {"status": state}
+
+        with pytest.raises(botocore.exceptions.WaiterError):
+            BedrockHook().get_waiter(self.WAITER_NAME).wait(jobIdentifier="job_arn")
+
+    def test_batch_inference_wait(self, mock_get_job):
+        wait = {"status": "InProgress"}
+        success = {"status": "Completed"}
+        mock_get_job.side_effect = [wait, wait, success]
+
+        BedrockHook().get_waiter(self.WAITER_NAME).wait(
+            jobIdentifier="job_arn", WaiterConfig={"Delay": 0.01, "MaxAttempts": 3}
+        )
+
+
+class TestBatchInferenceScheduledWaiter(TestBedrockCustomWaitersBase):
+    WAITER_NAME = "batch_inference_scheduled"
+    SENSOR = BedrockBatchInferenceSensor(
+        task_id="task_id",
+        job_arn="job_arn",
+        success_state=BedrockBatchInferenceSensor.SuccessState.SCHEDULED,
+    )
+
+    @pytest.fixture
+    def mock_get_job(self):
+        with mock.patch.object(self.client, "get_model_invocation_job") as mock_getter:
+            yield mock_getter
+
+    @pytest.mark.parametrize("state", SENSOR.SUCCESS_STATES)
+    def test_batch_inference_complete(self, state, mock_get_job):
+        mock_get_job.return_value = {"status": state}
+
+        BedrockHook().get_waiter(self.WAITER_NAME).wait(jobIdentifier="job_arn")
+
+    @pytest.mark.parametrize("state", SENSOR.FAILURE_STATES)
+    def test_batch_inference_failed(self, state, mock_get_job):
+        mock_get_job.return_value = {"status": state}
+
+        with pytest.raises(botocore.exceptions.WaiterError):
+            BedrockHook().get_waiter(self.WAITER_NAME).wait(jobIdentifier="job_arn")
+
+    def test_batch_inference_wait(self, mock_get_job):
+        wait = {"status": "InProgress"}
+        success = {"status": "Scheduled"}
+        mock_get_job.side_effect = [wait, wait, success]
+
+        BedrockHook().get_waiter(self.WAITER_NAME).wait(
+            jobIdentifier="job_arn", WaiterConfig={"Delay": 0.01, "MaxAttempts": 3}
+        )
