This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9dd77520be Introduce Amazon Comprehend Service (#39592)
9dd77520be is described below

commit 9dd77520be3d8492156958d57b63b5779a3f55eb
Author: gopidesupavan <[email protected]>
AuthorDate: Wed May 15 15:05:53 2024 +0100

    Introduce Amazon Comprehend Service (#39592)
---
 airflow/providers/amazon/aws/hooks/comprehend.py   |  37 ++++
 .../providers/amazon/aws/operators/comprehend.py   | 192 +++++++++++++++++++++
 airflow/providers/amazon/aws/sensors/comprehend.py | 147 ++++++++++++++++
 .../providers/amazon/aws/triggers/comprehend.py    |  61 +++++++
 .../providers/amazon/aws/waiters/comprehend.json   |  49 ++++++
 airflow/providers/amazon/provider.yaml             |  18 ++
 .../operators/comprehend.rst                       |  74 ++++++++
 .../aws/[email protected]          | Bin 0 -> 7254 bytes
 docs/spelling_wordlist.txt                         |   2 +
 tests/always/test_project_structure.py             |   2 +
 .../providers/amazon/aws/hooks/test_comprehend.py  |  31 ++++
 .../amazon/aws/operators/test_comprehend.py        | 163 +++++++++++++++++
 .../amazon/aws/sensors/test_comprehend.py          |  94 ++++++++++
 .../amazon/aws/triggers/test_comprehend.py         |  67 +++++++
 .../amazon/aws/waiters/test_comprehend.py          |  71 ++++++++
 .../providers/amazon/aws/example_comprehend.py     | 137 +++++++++++++++
 16 files changed, 1145 insertions(+)

diff --git a/airflow/providers/amazon/aws/hooks/comprehend.py b/airflow/providers/amazon/aws/hooks/comprehend.py
new file mode 100644
index 0000000000..897aaf72ee
--- /dev/null
+++ b/airflow/providers/amazon/aws/hooks/comprehend.py
@@ -0,0 +1,37 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+
+
+class ComprehendHook(AwsBaseHook):
+    """
+    Interact with AWS Comprehend.
+
+    Provide a thin wrapper around :external+boto3:py:class:`boto3.client("comprehend") <Comprehend.Client>`.
+
+    Additional arguments (such as ``aws_conn_id``) may be specified and
+    are passed down to the underlying AwsBaseHook.
+
+    .. seealso::
+        - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
+    """
+
+    def __init__(self, *args, **kwargs) -> None:
+        kwargs["client_type"] = "comprehend"
+        super().__init__(*args, **kwargs)
diff --git a/airflow/providers/amazon/aws/operators/comprehend.py b/airflow/providers/amazon/aws/operators/comprehend.py
new file mode 100644
index 0000000000..780e227af4
--- /dev/null
+++ b/airflow/providers/amazon/aws/operators/comprehend.py
@@ -0,0 +1,192 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from functools import cached_property
+from typing import TYPE_CHECKING, Any, Sequence
+
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+from airflow.providers.amazon.aws.triggers.comprehend import ComprehendPiiEntitiesDetectionJobCompletedTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
+from airflow.utils.timezone import utcnow
+
+if TYPE_CHECKING:
+    import boto3
+
+    from airflow.utils.context import Context
+
+
+class ComprehendBaseOperator(AwsBaseOperator[ComprehendHook]):
+    """
+    This is the base operator for Comprehend Service operators (not supposed to be used directly in DAGs).
+
+    :param input_data_config: The input properties for a PII entities detection job. (templated)
+    :param output_data_config: Provides `configuration` parameters for the output of PII entity detection
+        jobs. (templated)
+    :param data_access_role_arn: The Amazon Resource Name (ARN) of the IAM role that grants Amazon Comprehend
+        read access to your input data. (templated)
+    :param language_code: The language of the input documents. (templated)
+    """
+
+    aws_hook_class = ComprehendHook
+
+    template_fields: Sequence[str] = aws_template_fields(
+        "input_data_config", "output_data_config", "data_access_role_arn", "language_code"
+    )
+
+    template_fields_renderers: dict = {"input_data_config": "json", "output_data_config": "json"}
+
+    def __init__(
+        self,
+        input_data_config: dict,
+        output_data_config: dict,
+        data_access_role_arn: str,
+        language_code: str,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.input_data_config = input_data_config
+        self.output_data_config = output_data_config
+        self.data_access_role_arn = data_access_role_arn
+        self.language_code = language_code
+
+    @cached_property
+    def client(self) -> boto3.client:
+        """Create and return the Comprehend client."""
+        return self.hook.conn
+
+    def execute(self, context: Context):
+        """Must overwrite in child classes."""
+        raise NotImplementedError("Please implement execute() in subclass")
+
+
+class ComprehendStartPiiEntitiesDetectionJobOperator(ComprehendBaseOperator):
+    """
+    Create an Amazon Comprehend PII entities detection job for a collection of documents.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:ComprehendStartPiiEntitiesDetectionJobOperator`
+
+    :param input_data_config: The input properties for a PII entities detection job. (templated)
+    :param output_data_config: Provides `configuration` parameters for the output of PII entity detection
+        jobs. (templated)
+    :param mode: Specifies whether the output provides the locations (offsets) of PII entities or a file in
+        which PII entities are redacted. If you set the mode parameter to ONLY_REDACTION, you must
+        provide a RedactionConfig in start_pii_entities_kwargs.
+    :param data_access_role_arn: The Amazon Resource Name (ARN) of the IAM role that grants Amazon Comprehend
+        read access to your input data. (templated)
+    :param language_code: The language of the input documents. (templated)
+    :param start_pii_entities_kwargs: Any optional parameters to pass to the job. If JobName is not provided
+        in start_pii_entities_kwargs, the operator will generate one.
+
+    :param wait_for_completion: Whether to wait for the job to complete. (default: True)
+    :param waiter_delay: Time in seconds to wait between status checks. (default: 60)
+    :param waiter_max_attempts: Maximum number of attempts to check for job completion. (default: 20)
+    :param deferrable: If True, the operator will wait asynchronously for the job to complete.
+        This implies waiting for completion. This mode requires the aiobotocore module to be installed.
+        (default: False)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then the default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for the botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    def __init__(
+        self,
+        input_data_config: dict,
+        output_data_config: dict,
+        mode: str,
+        data_access_role_arn: str,
+        language_code: str,
+        start_pii_entities_kwargs: dict[str, Any] | None = None,
+        wait_for_completion: bool = True,
+        waiter_delay: int = 60,
+        waiter_max_attempts: int = 20,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        **kwargs,
+    ):
+        super().__init__(
+            input_data_config=input_data_config,
+            output_data_config=output_data_config,
+            data_access_role_arn=data_access_role_arn,
+            language_code=language_code,
+            **kwargs,
+        )
+        self.mode = mode
+        self.start_pii_entities_kwargs = start_pii_entities_kwargs or {}
+        self.wait_for_completion = wait_for_completion
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.deferrable = deferrable
+
+    def execute(self, context: Context) -> str:
+        if self.start_pii_entities_kwargs.get("JobName", None) is None:
+            self.start_pii_entities_kwargs["JobName"] = (
+                f"start_pii_entities_detection_job-{int(utcnow().timestamp())}"
+            )
+
+        self.log.info(
+            "Submitting start pii entities detection job '%s'.", self.start_pii_entities_kwargs["JobName"]
+        )
+        job_id = self.client.start_pii_entities_detection_job(
+            InputDataConfig=self.input_data_config,
+            OutputDataConfig=self.output_data_config,
+            Mode=self.mode,
+            DataAccessRoleArn=self.data_access_role_arn,
+            LanguageCode=self.language_code,
+            **self.start_pii_entities_kwargs,
+        )["JobId"]
+
+        message_description = f"start pii entities detection job {job_id} to 
complete."
+        if self.deferrable:
+            self.log.info("Deferring %s", message_description)
+            self.defer(
+                trigger=ComprehendPiiEntitiesDetectionJobCompletedTrigger(
+                    job_id=job_id,
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
+        elif self.wait_for_completion:
+            self.log.info("Waiting for %s", message_description)
+            self.hook.get_waiter("pii_entities_detection_job_complete").wait(
+                JobId=job_id,
+                WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": 
self.waiter_max_attempts},
+            )
+
+        return job_id
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        event = validate_execute_complete_event(event)
+        if event["status"] != "success":
+            raise AirflowException(f"Error while running job: {event}")
+
+        self.log.info("Comprehend pii entities detection job `%s` complete.", event["job_id"])
+        return event["job_id"]
diff --git a/airflow/providers/amazon/aws/sensors/comprehend.py b/airflow/providers/amazon/aws/sensors/comprehend.py
new file mode 100644
index 0000000000..8f0e328cbc
--- /dev/null
+++ b/airflow/providers/amazon/aws/sensors/comprehend.py
@@ -0,0 +1,147 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import abc
+from typing import TYPE_CHECKING, Any, Sequence
+
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.triggers.comprehend import ComprehendPiiEntitiesDetectionJobCompletedTrigger
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+
+class ComprehendBaseSensor(AwsBaseSensor[ComprehendHook]):
+    """
+    General sensor behavior for Amazon Comprehend.
+
+    Subclasses must implement the following methods:
+        - ``get_state()``
+
+    Subclasses must set the following fields:
+        - ``INTERMEDIATE_STATES``
+        - ``FAILURE_STATES``
+        - ``SUCCESS_STATES``
+        - ``FAILURE_MESSAGE``
+
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    """
+
+    aws_hook_class = ComprehendHook
+
+    INTERMEDIATE_STATES: tuple[str, ...] = ()
+    FAILURE_STATES: tuple[str, ...] = ()
+    SUCCESS_STATES: tuple[str, ...] = ()
+    FAILURE_MESSAGE = ""
+
+    ui_color = "#66c3ff"
+
+    def __init__(
+        self,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        **kwargs: Any,
+    ):
+        super().__init__(**kwargs)
+        self.deferrable = deferrable
+
+    def poke(self, context: Context, **kwargs) -> bool:
+        state = self.get_state()
+        if state in self.FAILURE_STATES:
+            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
+            if self.soft_fail:
+                raise AirflowSkipException(self.FAILURE_MESSAGE)
+            raise AirflowException(self.FAILURE_MESSAGE)
+
+        return state not in self.INTERMEDIATE_STATES
+
+    @abc.abstractmethod
+    def get_state(self) -> str:
+        """Implement in subclasses."""
+
+
+class ComprehendStartPiiEntitiesDetectionJobCompletedSensor(ComprehendBaseSensor):
+    """
+    Poll the state of the pii entities detection job until it reaches a completed state; fails if the job fails.
+
+    .. seealso::
+        For more information on how to use this sensor, take a look at the guide:
+        :ref:`howto/sensor:ComprehendStartPiiEntitiesDetectionJobCompletedSensor`
+
+    :param job_id: The id of the Comprehend pii entities detection job.
+
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param poke_interval: Polling period in seconds to check for the status of the job. (default: 120)
+    :param max_retries: Number of times before returning the current state. (default: 75)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    INTERMEDIATE_STATES: tuple[str, ...] = ("IN_PROGRESS",)
+    FAILURE_STATES: tuple[str, ...] = ("FAILED", "STOP_REQUESTED", "STOPPED")
+    SUCCESS_STATES: tuple[str, ...] = ("COMPLETED",)
+    FAILURE_MESSAGE = "Comprehend start pii entities detection job sensor 
failed."
+
+    template_fields: Sequence[str] = aws_template_fields("job_id")
+
+    def __init__(
+        self,
+        *,
+        job_id: str,
+        max_retries: int = 75,
+        poke_interval: int = 120,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.job_id = job_id
+        self.max_retries = max_retries
+        self.poke_interval = poke_interval
+
+    def execute(self, context: Context) -> Any:
+        if self.deferrable:
+            self.defer(
+                trigger=ComprehendPiiEntitiesDetectionJobCompletedTrigger(
+                    job_id=self.job_id,
+                    waiter_delay=int(self.poke_interval),
+                    waiter_max_attempts=self.max_retries,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="poke",
+            )
+        else:
+            super().execute(context=context)
+
+    def get_state(self) -> str:
+        return self.hook.conn.describe_pii_entities_detection_job(JobId=self.job_id)[
+            "PiiEntitiesDetectionJobProperties"
+        ]["JobStatus"]
diff --git a/airflow/providers/amazon/aws/triggers/comprehend.py b/airflow/providers/amazon/aws/triggers/comprehend.py
new file mode 100644
index 0000000000..7de6650c87
--- /dev/null
+++ b/airflow/providers/amazon/aws/triggers/comprehend.py
@@ -0,0 +1,61 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+
+from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook
+from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
+
+
+class ComprehendPiiEntitiesDetectionJobCompletedTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger when a Comprehend pii entities detection job is complete.
+
+    :param job_id: The id of the Comprehend pii entities detection job.
+    :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 120)
+    :param waiter_max_attempts: The maximum number of attempts to be made. (default: 75)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        *,
+        job_id: str,
+        waiter_delay: int = 120,
+        waiter_max_attempts: int = 75,
+        aws_conn_id: str | None = "aws_default",
+    ) -> None:
+        super().__init__(
+            serialized_fields={"job_id": job_id},
+            waiter_name="pii_entities_detection_job_complete",
+            waiter_args={"JobId": job_id},
+            failure_message="Comprehend start pii entities detection job 
failed.",
+            status_message="Status of Comprehend start pii entities detection 
job is",
+            status_queries=["PiiEntitiesDetectionJobProperties.JobStatus"],
+            return_key="job_id",
+            return_value=job_id,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return ComprehendHook(aws_conn_id=self.aws_conn_id)
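As an aside, a minimal sketch (not part of the commit) of how the trigger serializes itself for the triggerer process, mirroring the unit test shipped below; the job id is a placeholder:

    from airflow.providers.amazon.aws.triggers.comprehend import (
        ComprehendPiiEntitiesDetectionJobCompletedTrigger,
    )

    trigger = ComprehendPiiEntitiesDetectionJobCompletedTrigger(job_id="my-job-id")
    # The triggerer persists the classpath and kwargs and re-instantiates the trigger.
    classpath, kwargs = trigger.serialize()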
diff --git a/airflow/providers/amazon/aws/waiters/comprehend.json b/airflow/providers/amazon/aws/waiters/comprehend.json
new file mode 100644
index 0000000000..9df82f319f
--- /dev/null
+++ b/airflow/providers/amazon/aws/waiters/comprehend.json
@@ -0,0 +1,49 @@
+{
+    "version": 2,
+    "waiters": {
+        "pii_entities_detection_job_complete": {
+            "delay": 120,
+            "maxAttempts": 75,
+            "operation": "DescribePiiEntitiesDetectionJob",
+            "acceptors": [
+                {
+                    "matcher": "path",
+                    "argument": "PiiEntitiesDetectionJobProperties.JobStatus",
+                    "expected": "SUBMITTED",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "PiiEntitiesDetectionJobProperties.JobStatus",
+                    "expected": "IN_PROGRESS",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "PiiEntitiesDetectionJobProperties.JobStatus",
+                    "expected": "COMPLETED",
+                    "state": "success"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "PiiEntitiesDetectionJobProperties.JobStatus",
+                    "expected": "FAILED",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "PiiEntitiesDetectionJobProperties.JobStatus",
+                    "expected": "STOP_REQUESTED",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "PiiEntitiesDetectionJobProperties.JobStatus",
+                    "expected": "STOPPED",
+                    "state": "failure"
+                }
+            ]
+        }
+    }
+}
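For reference, a hedged sketch (not part of the commit) of driving this custom waiter directly through the hook, as the operator does internally; the job id and WaiterConfig values are assumptions:

    from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook

    # Polls DescribePiiEntitiesDetectionJob until COMPLETED,
    # or raises a WaiterError on FAILED/STOP_REQUESTED/STOPPED.
    ComprehendHook().get_waiter("pii_entities_detection_job_complete").wait(
        JobId="my-job-id",
        WaiterConfig={"Delay": 60, "MaxAttempts": 20},
    )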
diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml
index 641c315d62..7c06879143 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -166,6 +166,12 @@ integrations:
     external-doc-url: https://aws.amazon.com/cloudwatch/
     logo: /integration-logos/aws/[email protected]
     tags: [aws]
+  - integration-name: Amazon Comprehend
+    external-doc-url: https://aws.amazon.com/comprehend/
+    logo: /integration-logos/aws/[email protected]
+    how-to-guide:
+      - /docs/apache-airflow-providers-amazon/operators/comprehend.rst
+    tags: [aws]
   - integration-name: Amazon DataSync
     external-doc-url: https://aws.amazon.com/datasync/
     how-to-guide:
@@ -385,6 +391,9 @@ operators:
   - integration-name: Amazon CloudFormation
     python-modules:
       - airflow.providers.amazon.aws.operators.cloud_formation
+  - integration-name: Amazon Comprehend
+    python-modules:
+      - airflow.providers.amazon.aws.operators.comprehend
   - integration-name: Amazon DataSync
     python-modules:
       - airflow.providers.amazon.aws.operators.datasync
@@ -470,6 +479,9 @@ sensors:
   - integration-name: Amazon CloudFormation
     python-modules:
       - airflow.providers.amazon.aws.sensors.cloud_formation
+  - integration-name: Amazon Comprehend
+    python-modules:
+      - airflow.providers.amazon.aws.sensors.comprehend
   - integration-name: AWS Database Migration Service
     python-modules:
       - airflow.providers.amazon.aws.sensors.dms
@@ -545,6 +557,9 @@ hooks:
   - integration-name: Amazon Chime
     python-modules:
       - airflow.providers.amazon.aws.hooks.chime
+  - integration-name: Amazon Comprehend
+    python-modules:
+      - airflow.providers.amazon.aws.hooks.comprehend
   - integration-name: Amazon DynamoDB
     python-modules:
       - airflow.providers.amazon.aws.hooks.dynamodb
@@ -672,6 +687,9 @@ triggers:
   - integration-name: Amazon Bedrock
     python-modules:
       - airflow.providers.amazon.aws.triggers.bedrock
+  - integration-name: Amazon Comprehend
+    python-modules:
+      - airflow.providers.amazon.aws.triggers.comprehend
   - integration-name: Amazon EC2
     python-modules:
       - airflow.providers.amazon.aws.triggers.ec2
diff --git a/docs/apache-airflow-providers-amazon/operators/comprehend.rst b/docs/apache-airflow-providers-amazon/operators/comprehend.rst
new file mode 100644
index 0000000000..dd79e2df6a
--- /dev/null
+++ b/docs/apache-airflow-providers-amazon/operators/comprehend.rst
@@ -0,0 +1,74 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+ ..   http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+=================
+Amazon Comprehend
+=================
+
+`Amazon Comprehend <https://aws.amazon.com/comprehend/>`__ uses natural language processing (NLP) to
+extract insights about the content of documents. It develops insights by recognizing the entities, key phrases,
+language, sentiments, and other common elements in a document.
+
+Prerequisite Tasks
+------------------
+
+.. include:: ../_partials/prerequisite_tasks.rst
+
+Generic Parameters
+------------------
+
+.. include:: ../_partials/generic_parameters.rst
+
+Operators
+---------
+
+.. _howto/operator:ComprehendStartPiiEntitiesDetectionJobOperator:
+
+Create an Amazon Comprehend Start PII Entities Detection Job
+============================================================
+
+To create an Amazon Comprehend Start PII Entities Detection Job, you can use
+:class:`~airflow.providers.amazon.aws.operators.comprehend.ComprehendStartPiiEntitiesDetectionJobOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_comprehend.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_start_pii_entities_detection_job]
+    :end-before: [END howto_operator_start_pii_entities_detection_job]
+
+Sensors
+-------
+
+.. _howto/sensor:ComprehendStartPiiEntitiesDetectionJobCompletedSensor:
+
+Wait for an Amazon Comprehend Start PII Entities Detection Job
+==============================================================
+
+To wait on the state of an Amazon Comprehend Start PII Entities Detection Job until it reaches a terminal
+state you can use
+:class:`~airflow.providers.amazon.aws.sensors.comprehend.ComprehendStartPiiEntitiesDetectionJobCompletedSensor`.
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_comprehend.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_sensor_start_pii_entities_detection_job]
+    :end-before: [END howto_sensor_start_pii_entities_detection_job]
+
+Reference
+---------
+
+* `AWS boto3 library documentation for Amazon Comprehend <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/comprehend.html>`__
diff --git a/docs/integration-logos/aws/[email protected] b/docs/integration-logos/aws/[email protected]
new file mode 100644
index 0000000000..24e8c34962
Binary files /dev/null and b/docs/integration-logos/aws/[email protected] differ
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index da04944231..40b744fc5d 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1177,6 +1177,8 @@ picklable
 pid
 pidbox
 pigcmd
+Pii
+pii
 pinecone
 pinodb
 Pinot
diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py
index 7ad127ac3b..428570ee58 100644
--- a/tests/always/test_project_structure.py
+++ b/tests/always/test_project_structure.py
@@ -522,6 +522,8 @@ class TestAmazonProviderProjectStructure(ExampleCoverageTest):
         "airflow.providers.amazon.aws.sensors.ecs.EcsBaseSensor",
         "airflow.providers.amazon.aws.sensors.eks.EksBaseSensor",
         "airflow.providers.amazon.aws.transfers.base.AwsToAwsBaseOperator",
+        "airflow.providers.amazon.aws.operators.comprehend.ComprehendBaseOperator",
+        "airflow.providers.amazon.aws.sensors.comprehend.ComprehendBaseSensor",
     }
 
     MISSING_EXAMPLES_FOR_CLASSES = {
diff --git a/tests/providers/amazon/aws/hooks/test_comprehend.py b/tests/providers/amazon/aws/hooks/test_comprehend.py
new file mode 100644
index 0000000000..fded25e446
--- /dev/null
+++ b/tests/providers/amazon/aws/hooks/test_comprehend.py
@@ -0,0 +1,31 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook
+
+
+class TestComprehendHook:
+    @pytest.mark.parametrize(
+        "test_hook, service_name",
+        [pytest.param(ComprehendHook(), "comprehend", id="comprehend")],
+    )
+    def test_comprehend_hook(self, test_hook, service_name):
+        assert test_hook.client_type == service_name
+        assert test_hook.conn is not None
diff --git a/tests/providers/amazon/aws/operators/test_comprehend.py b/tests/providers/amazon/aws/operators/test_comprehend.py
new file mode 100644
index 0000000000..b970b590ad
--- /dev/null
+++ b/tests/providers/amazon/aws/operators/test_comprehend.py
@@ -0,0 +1,163 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Generator
+from unittest import mock
+
+import pytest
+from moto import mock_aws
+
+from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook
+from airflow.providers.amazon.aws.operators.comprehend import (
+    ComprehendBaseOperator,
+    ComprehendStartPiiEntitiesDetectionJobOperator,
+)
+from airflow.utils.types import NOTSET
+
+if TYPE_CHECKING:
+    from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection
+
+INPUT_DATA_CONFIG = {
+    "S3Uri": "s3://input-data-comprehend/sample_data.txt",
+    "InputFormat": "ONE_DOC_PER_LINE",
+}
+OUTPUT_DATA_CONFIG = {"S3Uri": "s3://output-data-comprehend/redacted_output/"}
+LANGUAGE_CODE = "en"
+ROLE_ARN = "role_arn"
+
+
+class TestComprehendBaseOperator:
+    @pytest.mark.parametrize("aws_conn_id", [None, NOTSET, "aws_test_conn"])
+    @pytest.mark.parametrize("region_name", [None, NOTSET, "ca-central-1"])
+    def test_initialize_comprehend_base_operator(self, aws_conn_id, region_name):
+        op_kw = {"aws_conn_id": aws_conn_id, "region_name": region_name}
+        op_kw = {k: v for k, v in op_kw.items() if v is not NOTSET}
+
+        comprehend_base_op = ComprehendBaseOperator(
+            task_id="comprehend_base_operator",
+            input_data_config=INPUT_DATA_CONFIG,
+            output_data_config=OUTPUT_DATA_CONFIG,
+            language_code=LANGUAGE_CODE,
+            data_access_role_arn=ROLE_ARN,
+            **op_kw,
+        )
+
+        assert comprehend_base_op.aws_conn_id == (aws_conn_id if aws_conn_id is not NOTSET else "aws_default")
+        assert comprehend_base_op.region_name == (region_name if region_name is not NOTSET else None)
+
+    @mock.patch.object(ComprehendBaseOperator, "hook", 
new_callable=mock.PropertyMock)
+    def test_initialize_comprehend_base_operator_hook(self, 
comprehend_base_operator_mock_hook):
+        comprehend_base_op = ComprehendBaseOperator(
+            task_id="comprehend_base_operator",
+            input_data_config=INPUT_DATA_CONFIG,
+            output_data_config=OUTPUT_DATA_CONFIG,
+            language_code=LANGUAGE_CODE,
+            data_access_role_arn=ROLE_ARN,
+        )
+        mocked_hook = mock.MagicMock(name="MockHook")
+        mocked_client = mock.MagicMock(name="MockClient")
+        mocked_hook.conn = mocked_client
+        comprehend_base_operator_mock_hook.return_value = mocked_hook
+        assert comprehend_base_op.client == mocked_client
+        comprehend_base_operator_mock_hook.assert_called_once()
+
+
+class TestComprehendStartPiiEntitiesDetectionJobOperator:
+    JOB_ID = "random-job-id-1234567"
+    MODE = "ONLY_REDACTION"
+    JOB_NAME = "TEST_START_PII_ENTITIES_DETECTION_JOB-1"
+    DEFAULT_JOB_NAME_STARTS_WITH = "start_pii_entities_detection_job"
+    REDACTION_CONFIG = {"PiiEntityTypes": ["NAME", "ADDRESS"], "MaskMode": 
"REPLACE_WITH_PII_ENTITY_TYPE"}
+
+    @pytest.fixture
+    def mock_conn(self) -> Generator[BaseAwsConnection, None, None]:
+        with mock.patch.object(ComprehendHook, "conn") as _conn:
+            _conn.start_pii_entities_detection_job.return_value = {"JobId": 
self.JOB_ID}
+            yield _conn
+
+    @pytest.fixture
+    def comprehend_hook(self) -> Generator[ComprehendHook, None, None]:
+        with mock_aws():
+            hook = ComprehendHook(aws_conn_id="aws_default")
+            yield hook
+
+    def setup_method(self):
+        self.operator = ComprehendStartPiiEntitiesDetectionJobOperator(
+            task_id="start_pii_entities_detection_job",
+            input_data_config=INPUT_DATA_CONFIG,
+            output_data_config=OUTPUT_DATA_CONFIG,
+            data_access_role_arn=ROLE_ARN,
+            mode=self.MODE,
+            language_code=LANGUAGE_CODE,
+            start_pii_entities_kwargs={"JobName": self.JOB_NAME, 
"RedactionConfig": self.REDACTION_CONFIG},
+        )
+        self.operator.defer = mock.MagicMock()
+
+    def test_init(self):
+        assert self.operator.input_data_config == INPUT_DATA_CONFIG
+        assert self.operator.output_data_config == OUTPUT_DATA_CONFIG
+        assert self.operator.data_access_role_arn == ROLE_ARN
+        assert self.operator.mode == self.MODE
+        assert self.operator.language_code == LANGUAGE_CODE
+        assert self.operator.start_pii_entities_kwargs.get("JobName") == 
self.JOB_NAME
+        assert self.operator.start_pii_entities_kwargs.get("RedactionConfig") 
== self.REDACTION_CONFIG
+
+    @mock.patch.object(ComprehendHook, "conn")
+    def test_start_pii_entities_detection_job_name_starts_with_service_name(self, comprehend_mock_conn):
+        self.op = ComprehendStartPiiEntitiesDetectionJobOperator(
+            task_id="start_pii_entities_detection_job",
+            input_data_config=INPUT_DATA_CONFIG,
+            output_data_config=OUTPUT_DATA_CONFIG,
+            data_access_role_arn=ROLE_ARN,
+            mode=self.MODE,
+            language_code=LANGUAGE_CODE,
+            start_pii_entities_kwargs={"RedactionConfig": 
self.REDACTION_CONFIG},
+        )
+        self.op.wait_for_completion = False
+        self.op.execute({})
+        assert self.op.start_pii_entities_kwargs.get("JobName").startswith(self.DEFAULT_JOB_NAME_STARTS_WITH)
+        comprehend_mock_conn.start_pii_entities_detection_job.assert_called_once_with(
+            InputDataConfig=INPUT_DATA_CONFIG,
+            OutputDataConfig=OUTPUT_DATA_CONFIG,
+            Mode=self.MODE,
+            DataAccessRoleArn=ROLE_ARN,
+            LanguageCode=LANGUAGE_CODE,
+            RedactionConfig=self.REDACTION_CONFIG,
+            JobName=self.op.start_pii_entities_kwargs.get("JobName"),
+        )
+
+    @pytest.mark.parametrize(
+        "wait_for_completion, deferrable",
+        [
+            pytest.param(False, False, id="no_wait"),
+            pytest.param(True, False, id="wait"),
+            pytest.param(False, True, id="defer"),
+        ],
+    )
+    @mock.patch.object(ComprehendHook, "get_waiter")
+    def test_start_pii_entities_detection_job_wait_combinations(
+        self, _, wait_for_completion, deferrable, mock_conn, comprehend_hook
+    ):
+        self.operator.wait_for_completion = wait_for_completion
+        self.operator.deferrable = deferrable
+
+        response = self.operator.execute({})
+
+        assert response == self.JOB_ID
+        assert comprehend_hook.get_waiter.call_count == wait_for_completion
+        assert self.operator.defer.call_count == deferrable
diff --git a/tests/providers/amazon/aws/sensors/test_comprehend.py b/tests/providers/amazon/aws/sensors/test_comprehend.py
new file mode 100644
index 0000000000..e066349031
--- /dev/null
+++ b/tests/providers/amazon/aws/sensors/test_comprehend.py
@@ -0,0 +1,94 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook
+from airflow.providers.amazon.aws.sensors.comprehend import (
+    ComprehendStartPiiEntitiesDetectionJobCompletedSensor,
+)
+
+
+class TestComprehendStartPiiEntitiesDetectionJobCompletedSensor:
+    SENSOR = ComprehendStartPiiEntitiesDetectionJobCompletedSensor
+
+    def setup_method(self):
+        self.default_op_kwargs = dict(
+            task_id="test_pii_entities_detection_job_sensor",
+            job_id="job_id",
+            poke_interval=5,
+            max_retries=1,
+        )
+        self.sensor = self.SENSOR(**self.default_op_kwargs, aws_conn_id=None)
+
+    def test_base_aws_op_attributes(self):
+        op = self.SENSOR(**self.default_op_kwargs)
+        assert op.hook.aws_conn_id == "aws_default"
+        assert op.hook._region_name is None
+        assert op.hook._verify is None
+        assert op.hook._config is None
+
+        op = self.SENSOR(
+            **self.default_op_kwargs,
+            aws_conn_id="aws-test-custom-conn",
+            region_name="eu-west-1",
+            verify=False,
+            botocore_config={"read_timeout": 42},
+        )
+        assert op.hook.aws_conn_id == "aws-test-custom-conn"
+        assert op.hook._region_name == "eu-west-1"
+        assert op.hook._verify is False
+        assert op.hook._config is not None
+        assert op.hook._config.read_timeout == 42
+
+    @pytest.mark.parametrize("state", SENSOR.SUCCESS_STATES)
+    @mock.patch.object(ComprehendHook, "conn")
+    def test_poke_success_state(self, mock_conn, state):
+        mock_conn.describe_pii_entities_detection_job.return_value = {
+            "PiiEntitiesDetectionJobProperties": {"JobStatus": state}
+        }
+        assert self.sensor.poke({}) is True
+
+    @pytest.mark.parametrize("state", SENSOR.INTERMEDIATE_STATES)
+    @mock.patch.object(ComprehendHook, "conn")
+    def test_intermediate_state(self, mock_conn, state):
+        mock_conn.describe_pii_entities_detection_job.return_value = {
+            "PiiEntitiesDetectionJobProperties": {"JobStatus": state}
+        }
+        assert self.sensor.poke({}) is False
+
+    @pytest.mark.parametrize(
+        "soft_fail, expected_exception",
+        [
+            pytest.param(False, AirflowException, id="not-soft-fail"),
+            pytest.param(True, AirflowSkipException, id="soft-fail"),
+        ],
+    )
+    @pytest.mark.parametrize("state", SENSOR.FAILURE_STATES)
+    @mock.patch.object(ComprehendHook, "conn")
+    def test_poke_failure_states(self, mock_conn, state, soft_fail, expected_exception):
+        mock_conn.describe_pii_entities_detection_job.return_value = {
+            "PiiEntitiesDetectionJobProperties": {"JobStatus": state}
+        }
+        sensor = self.SENSOR(**self.default_op_kwargs, aws_conn_id=None, soft_fail=soft_fail)
+
+        with pytest.raises(expected_exception, match=sensor.FAILURE_MESSAGE):
+            sensor.poke({})
diff --git a/tests/providers/amazon/aws/triggers/test_comprehend.py b/tests/providers/amazon/aws/triggers/test_comprehend.py
new file mode 100644
index 0000000000..1c52aa8810
--- /dev/null
+++ b/tests/providers/amazon/aws/triggers/test_comprehend.py
@@ -0,0 +1,67 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+from unittest.mock import AsyncMock
+
+import pytest
+
+from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook
+from airflow.providers.amazon.aws.triggers.comprehend import ComprehendPiiEntitiesDetectionJobCompletedTrigger
+from airflow.triggers.base import TriggerEvent
+from tests.providers.amazon.aws.utils.test_waiter import assert_expected_waiter_type
+
+BASE_TRIGGER_CLASSPATH = "airflow.providers.amazon.aws.triggers.comprehend."
+
+
+class TestBaseComprehendTrigger:
+    EXPECTED_WAITER_NAME: str | None = None
+    JOB_ID: str | None = None
+
+    def test_setup(self):
+        # Ensure that all subclasses have an expected waiter name set.
+        if self.__class__.__name__ != "TestBaseComprehendTrigger":
+            assert isinstance(self.EXPECTED_WAITER_NAME, str)
+            assert isinstance(self.JOB_ID, str)
+
+
+class TestComprehendPiiEntitiesDetectionJobCompletedTrigger(TestBaseComprehendTrigger):
+    EXPECTED_WAITER_NAME = "pii_entities_detection_job_complete"
+    JOB_ID = "job_id"
+
+    def test_serialization(self):
+        """Assert that arguments and classpath are correctly serialized."""
+        trigger = ComprehendPiiEntitiesDetectionJobCompletedTrigger(job_id=self.JOB_ID)
+        classpath, kwargs = trigger.serialize()
+        assert classpath == BASE_TRIGGER_CLASSPATH + "ComprehendPiiEntitiesDetectionJobCompletedTrigger"
+        assert kwargs.get("job_id") == self.JOB_ID
+
+    @pytest.mark.asyncio
+    @mock.patch.object(ComprehendHook, "get_waiter")
+    @mock.patch.object(ComprehendHook, "async_conn")
+    async def test_run_success(self, mock_async_conn, mock_get_waiter):
+        mock_async_conn.__aenter__.return_value = mock.MagicMock()
+        mock_get_waiter().wait = AsyncMock()
+        trigger = ComprehendPiiEntitiesDetectionJobCompletedTrigger(job_id=self.JOB_ID)
+
+        generator = trigger.run()
+        response = await generator.asend(None)
+
+        assert response == TriggerEvent({"status": "success", "job_id": 
self.JOB_ID})
+        assert_expected_waiter_type(mock_get_waiter, self.EXPECTED_WAITER_NAME)
+        mock_get_waiter().wait.assert_called_once()
diff --git a/tests/providers/amazon/aws/waiters/test_comprehend.py b/tests/providers/amazon/aws/waiters/test_comprehend.py
new file mode 100644
index 0000000000..a514ea198f
--- /dev/null
+++ b/tests/providers/amazon/aws/waiters/test_comprehend.py
@@ -0,0 +1,71 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import boto3
+import botocore
+import pytest
+
+from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook
+from airflow.providers.amazon.aws.sensors.comprehend import (
+    ComprehendStartPiiEntitiesDetectionJobCompletedSensor,
+)
+
+
+class TestComprehendCustomWaiters:
+    def test_service_waiters(self):
+        assert "pii_entities_detection_job_complete" in 
ComprehendHook().list_waiters()
+
+
+class TestComprehendCustomWaitersBase:
+    @pytest.fixture(autouse=True)
+    def mock_conn(self, monkeypatch):
+        self.client = boto3.client("comprehend")
+        monkeypatch.setattr(ComprehendHook, "conn", self.client)
+
+
+class TestComprehendStartPiiEntitiesDetectionJobCompleteWaiter(TestComprehendCustomWaitersBase):
+    WAITER_NAME = "pii_entities_detection_job_complete"
+
+    @pytest.fixture
+    def mock_get_job(self):
+        with mock.patch.object(self.client, "describe_pii_entities_detection_job") as mock_getter:
+            yield mock_getter
+
+    @pytest.mark.parametrize("state", 
ComprehendStartPiiEntitiesDetectionJobCompletedSensor.SUCCESS_STATES)
+    def test_pii_entities_detection_job_complete(self, state, mock_get_job):
+        mock_get_job.return_value = {"PiiEntitiesDetectionJobProperties": 
{"JobStatus": state}}
+
+        ComprehendHook().get_waiter(self.WAITER_NAME).wait(JobId="job_id")
+
+    @pytest.mark.parametrize("state", 
ComprehendStartPiiEntitiesDetectionJobCompletedSensor.FAILURE_STATES)
+    def test_pii_entities_detection_job_failed(self, state, mock_get_job):
+        mock_get_job.return_value = {"PiiEntitiesDetectionJobProperties": 
{"JobStatus": state}}
+
+        with pytest.raises(botocore.exceptions.WaiterError):
+            ComprehendHook().get_waiter(self.WAITER_NAME).wait(JobId="job_id")
+
+    def test_pii_entities_detection_job_wait(self, mock_get_job):
+        wait = {"PiiEntitiesDetectionJobProperties": {"JobStatus": 
"IN_PROGRESS"}}
+        success = {"PiiEntitiesDetectionJobProperties": {"JobStatus": 
"COMPLETED"}}
+        mock_get_job.side_effect = [wait, wait, success]
+
+        ComprehendHook().get_waiter(self.WAITER_NAME).wait(
+            JobId="job_id", WaiterConfig={"Delay": 0.01, "MaxAttempts": 3}
+        )
diff --git a/tests/system/providers/amazon/aws/example_comprehend.py b/tests/system/providers/amazon/aws/example_comprehend.py
new file mode 100644
index 0000000000..58e34329b6
--- /dev/null
+++ b/tests/system/providers/amazon/aws/example_comprehend.py
@@ -0,0 +1,137 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import json
+from datetime import datetime
+
+from airflow import DAG
+from airflow.decorators import task_group
+from airflow.models.baseoperator import chain
+from airflow.providers.amazon.aws.operators.comprehend import ComprehendStartPiiEntitiesDetectionJobOperator
+from airflow.providers.amazon.aws.operators.s3 import (
+    S3CreateBucketOperator,
+    S3CreateObjectOperator,
+    S3DeleteBucketOperator,
+)
+from airflow.providers.amazon.aws.sensors.comprehend import (
+    ComprehendStartPiiEntitiesDetectionJobCompletedSensor,
+)
+from airflow.utils.trigger_rule import TriggerRule
+from tests.system.providers.amazon.aws.utils import SystemTestContextBuilder
+
+ROLE_ARN_KEY = "ROLE_ARN"
+sys_test_context_task = SystemTestContextBuilder().add_variable(ROLE_ARN_KEY).build()
+
+DAG_ID = "example_comprehend"
+INPUT_S3_KEY_START_PII_ENTITIES_DETECTION_JOB = "start-pii-entities-detection-job/sample_data.txt"
+
+SAMPLE_DATA = {
+    "username": "bob1234",
+    "name": "Bob",
+    "sex": "M",
+    "address": "1773 Raymond Ville Suite 682",
+    "mail": "[email protected]",
+}
+
+
+@task_group
+def pii_entities_detection_job_workflow():
+    # [START howto_operator_start_pii_entities_detection_job]
+    start_pii_entities_detection_job = ComprehendStartPiiEntitiesDetectionJobOperator(
+        task_id="start_pii_entities_detection_job",
+        input_data_config=input_data_configurations,
+        output_data_config=output_data_configurations,
+        mode="ONLY_REDACTION",
+        data_access_role_arn=test_context[ROLE_ARN_KEY],
+        language_code="en",
+        start_pii_entities_kwargs=pii_entities_kwargs,
+    )
+    # [END howto_operator_start_pii_entities_detection_job]
+    start_pii_entities_detection_job.wait_for_completion = False
+
+    # [START howto_sensor_start_pii_entities_detection_job]
+    await_start_pii_entities_detection_job = ComprehendStartPiiEntitiesDetectionJobCompletedSensor(
+        task_id="await_start_pii_entities_detection_job", job_id=start_pii_entities_detection_job.output
+    )
+    # [END howto_sensor_start_pii_entities_detection_job]
+
+    chain(start_pii_entities_detection_job, await_start_pii_entities_detection_job)
+
+
+with DAG(
+    dag_id=DAG_ID,
+    schedule="@once",
+    start_date=datetime(2021, 1, 1),
+    tags=["example"],
+    catchup=False,
+) as dag:
+    test_context = sys_test_context_task()
+    env_id = test_context["ENV_ID"]
+    bucket_name = f"{env_id}-comprehend"
+    input_data_configurations = {
+        "S3Uri": 
f"s3://{bucket_name}/{INPUT_S3_KEY_START_PII_ENTITIES_DETECTION_JOB}",
+        "InputFormat": "ONE_DOC_PER_LINE",
+    }
+    output_data_configurations = {"S3Uri": 
f"s3://{bucket_name}/redacted_output/"}
+    pii_entities_kwargs = {
+        "RedactionConfig": {
+            "PiiEntityTypes": ["NAME", "ADDRESS"],
+            "MaskMode": "REPLACE_WITH_PII_ENTITY_TYPE",
+        }
+    }
+
+    create_bucket = S3CreateBucketOperator(
+        task_id="create_bucket",
+        bucket_name=bucket_name,
+    )
+
+    upload_sample_data = S3CreateObjectOperator(
+        task_id="upload_sample_data",
+        s3_bucket=bucket_name,
+        s3_key=INPUT_S3_KEY_START_PII_ENTITIES_DETECTION_JOB,
+        data=json.dumps(SAMPLE_DATA),
+    )
+
+    delete_bucket = S3DeleteBucketOperator(
+        task_id="delete_bucket",
+        trigger_rule=TriggerRule.ALL_DONE,
+        bucket_name=bucket_name,
+        force_delete=True,
+    )
+
+    chain(
+        # TEST SETUP
+        test_context,
+        create_bucket,
+        upload_sample_data,
+        # TEST BODY
+        pii_entities_detection_job_workflow(),
+        # TEST TEARDOWN
+        delete_bucket,
+    )
+
+    from tests.system.utils.watcher import watcher
+
+    # This test needs watcher in order to properly mark success/failure
+    # when "tearDown" task with trigger rule is part of the DAG
+    list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run  # noqa: E402
+
# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
