This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 9dd77520be Introduce Amazon Comprehend Service (#39592)
9dd77520be is described below
commit 9dd77520be3d8492156958d57b63b5779a3f55eb
Author: gopidesupavan <[email protected]>
AuthorDate: Wed May 15 15:05:53 2024 +0100
Introduce Amazon Comprehend Service (#39592)
---
airflow/providers/amazon/aws/hooks/comprehend.py | 37 ++++
.../providers/amazon/aws/operators/comprehend.py | 192 +++++++++++++++++++++
airflow/providers/amazon/aws/sensors/comprehend.py | 147 ++++++++++++++++
.../providers/amazon/aws/triggers/comprehend.py | 61 +++++++
.../providers/amazon/aws/waiters/comprehend.json | 49 ++++++
airflow/providers/amazon/provider.yaml | 18 ++
.../operators/comprehend.rst | 74 ++++++++
.../aws/[email protected] | Bin 0 -> 7254 bytes
docs/spelling_wordlist.txt | 2 +
tests/always/test_project_structure.py | 2 +
.../providers/amazon/aws/hooks/test_comprehend.py | 31 ++++
.../amazon/aws/operators/test_comprehend.py | 163 +++++++++++++++++
.../amazon/aws/sensors/test_comprehend.py | 94 ++++++++++
.../amazon/aws/triggers/test_comprehend.py | 67 +++++++
.../amazon/aws/waiters/test_comprehend.py | 71 ++++++++
.../providers/amazon/aws/example_comprehend.py | 137 +++++++++++++++
16 files changed, 1145 insertions(+)
diff --git a/airflow/providers/amazon/aws/hooks/comprehend.py b/airflow/providers/amazon/aws/hooks/comprehend.py
new file mode 100644
index 0000000000..897aaf72ee
--- /dev/null
+++ b/airflow/providers/amazon/aws/hooks/comprehend.py
@@ -0,0 +1,37 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+
+
+class ComprehendHook(AwsBaseHook):
+ """
+ Interact with AWS Comprehend.
+
+ Provide a thin wrapper around :external+boto3:py:class:`boto3.client("comprehend") <Comprehend.Client>`.
+
+ Additional arguments (such as ``aws_conn_id``) may be specified and
+ are passed down to the underlying AwsBaseHook.
+
+ .. seealso::
+ - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
+ """
+
+ def __init__(self, *args, **kwargs) -> None:
+ kwargs["client_type"] = "comprehend"
+ super().__init__(*args, **kwargs)
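
For orientation (not part of the diff itself): because the hook simply wires up a boto3 "comprehend" client, any Comprehend API call is available through it. A minimal sketch, assuming a configured "aws_default" connection and illustrative input text:

    from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook

    hook = ComprehendHook(aws_conn_id="aws_default")
    # hook.conn is the boto3 "comprehend" client created by AwsBaseHook
    response = hook.conn.detect_dominant_language(Text="Bonjour tout le monde")
    print(response["Languages"])
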
diff --git a/airflow/providers/amazon/aws/operators/comprehend.py b/airflow/providers/amazon/aws/operators/comprehend.py
new file mode 100644
index 0000000000..780e227af4
--- /dev/null
+++ b/airflow/providers/amazon/aws/operators/comprehend.py
@@ -0,0 +1,192 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from functools import cached_property
+from typing import TYPE_CHECKING, Any, Sequence
+
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+from airflow.providers.amazon.aws.triggers.comprehend import ComprehendPiiEntitiesDetectionJobCompletedTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
+from airflow.utils.timezone import utcnow
+
+if TYPE_CHECKING:
+ import boto3
+
+ from airflow.utils.context import Context
+
+
+class ComprehendBaseOperator(AwsBaseOperator[ComprehendHook]):
+ """
+ This is the base operator for Comprehend Service operators (not supposed to be used directly in DAGs).
+
+ :param input_data_config: The input properties for a PII entities detection job. (templated)
+ :param output_data_config: Provides configuration parameters for the output of PII entity detection
+ jobs. (templated)
+ :param data_access_role_arn: The Amazon Resource Name (ARN) of the IAM role that grants Amazon Comprehend
+ read access to your input data. (templated)
+ :param language_code: The language of the input documents. (templated)
+ """
+
+ aws_hook_class = ComprehendHook
+
+ template_fields: Sequence[str] = aws_template_fields(
+ "input_data_config", "output_data_config", "data_access_role_arn",
"language_code"
+ )
+
+ template_fields_renderers: dict = {"input_data_config": "json", "output_data_config": "json"}
+
+ def __init__(
+ self,
+ input_data_config: dict,
+ output_data_config: dict,
+ data_access_role_arn: str,
+ language_code: str,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.input_data_config = input_data_config
+ self.output_data_config = output_data_config
+ self.data_access_role_arn = data_access_role_arn
+ self.language_code = language_code
+
+ @cached_property
+ def client(self) -> boto3.client:
+ """Create and return the Comprehend client."""
+ return self.hook.conn
+
+ def execute(self, context: Context):
+ """Must overwrite in child classes."""
+ raise NotImplementedError("Please implement execute() in subclass")
+
+
+class ComprehendStartPiiEntitiesDetectionJobOperator(ComprehendBaseOperator):
+ """
+ Create a comprehend pii entities detection job for a collection of documents.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:ComprehendStartPiiEntitiesDetectionJobOperator`
+
+ :param input_data_config: The input properties for a PII entities detection job. (templated)
+ :param output_data_config: Provides configuration parameters for the output of PII entity detection
+ jobs. (templated)
+ :param mode: Specifies whether the output provides the locations (offsets) of PII entities or a file
+ in which PII entities are redacted. If you set the mode parameter to ONLY_REDACTION, you must
+ provide a RedactionConfig in start_pii_entities_kwargs.
+ :param data_access_role_arn: The Amazon Resource Name (ARN) of the IAM role that grants Amazon Comprehend
+ read access to your input data. (templated)
+ :param language_code: The language of the input documents. (templated)
+ :param start_pii_entities_kwargs: Any optional parameters to pass to the job. If JobName is not provided
+ in start_pii_entities_kwargs, the operator will generate one.
+
+ :param wait_for_completion: Whether to wait for the job to complete. (default: True)
+ :param waiter_delay: Time in seconds to wait between status checks. (default: 60)
+ :param waiter_max_attempts: Maximum number of attempts to check for job completion. (default: 20)
+ :param deferrable: If True, the operator will wait asynchronously for the job to complete.
+ This implies waiting for completion. This mode requires the aiobotocore module to be installed.
+ (default: False)
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
+ running Airflow in a distributed manner and aws_conn_id is None or
+ empty, then default boto3 configuration would be used (and must be
+ maintained on each worker node).
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+ https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+ """
+
+ def __init__(
+ self,
+ input_data_config: dict,
+ output_data_config: dict,
+ mode: str,
+ data_access_role_arn: str,
+ language_code: str,
+ start_pii_entities_kwargs: dict[str, Any] | None = None,
+ wait_for_completion: bool = True,
+ waiter_delay: int = 60,
+ waiter_max_attempts: int = 20,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+ **kwargs,
+ ):
+ super().__init__(
+ input_data_config=input_data_config,
+ output_data_config=output_data_config,
+ data_access_role_arn=data_access_role_arn,
+ language_code=language_code,
+ **kwargs,
+ )
+ self.mode = mode
+ self.start_pii_entities_kwargs = start_pii_entities_kwargs or {}
+ self.wait_for_completion = wait_for_completion
+ self.waiter_delay = waiter_delay
+ self.waiter_max_attempts = waiter_max_attempts
+ self.deferrable = deferrable
+
+ def execute(self, context: Context) -> str:
+ if self.start_pii_entities_kwargs.get("JobName", None) is None:
+ self.start_pii_entities_kwargs["JobName"] = (
+ f"start_pii_entities_detection_job-{int(utcnow().timestamp())}"
+ )
+
+ self.log.info(
+ "Submitting start pii entities detection job '%s'.",
self.start_pii_entities_kwargs["JobName"]
+ )
+ job_id = self.client.start_pii_entities_detection_job(
+ InputDataConfig=self.input_data_config,
+ OutputDataConfig=self.output_data_config,
+ Mode=self.mode,
+ DataAccessRoleArn=self.data_access_role_arn,
+ LanguageCode=self.language_code,
+ **self.start_pii_entities_kwargs,
+ )["JobId"]
+
+ message_description = f"start pii entities detection job {job_id} to complete."
+ if self.deferrable:
+ self.log.info("Deferring %s", message_description)
+ self.defer(
+ trigger=ComprehendPiiEntitiesDetectionJobCompletedTrigger(
+ job_id=job_id,
+ waiter_delay=self.waiter_delay,
+ waiter_max_attempts=self.waiter_max_attempts,
+ aws_conn_id=self.aws_conn_id,
+ ),
+ method_name="execute_complete",
+ )
+ elif self.wait_for_completion:
+ self.log.info("Waiting for %s", message_description)
+ self.hook.get_waiter("pii_entities_detection_job_complete").wait(
+ JobId=job_id,
+ WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
+ )
+
+ return job_id
+
+ def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+ event = validate_execute_complete_event(event)
+ if event["status"] != "success":
+ raise AirflowException(f"Error while running job: {event}")
+
+ self.log.info("Comprehend pii entities detection job `%s` complete.", event["job_id"])
+ return event["job_id"]
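
For illustration, a minimal sketch of wiring the new operator into a DAG; the bucket names, role ARN, and task id below are hypothetical placeholders, not values from this commit:

    start_job = ComprehendStartPiiEntitiesDetectionJobOperator(
        task_id="start_pii_entities_detection_job",
        input_data_config={"S3Uri": "s3://example-bucket/input/", "InputFormat": "ONE_DOC_PER_LINE"},
        output_data_config={"S3Uri": "s3://example-bucket/output/"},
        mode="ONLY_REDACTION",
        # ONLY_REDACTION requires a RedactionConfig, passed through start_pii_entities_kwargs
        start_pii_entities_kwargs={
            "RedactionConfig": {"PiiEntityTypes": ["NAME"], "MaskMode": "REPLACE_WITH_PII_ENTITY_TYPE"}
        },
        data_access_role_arn="arn:aws:iam::123456789012:role/ExampleComprehendRole",
        language_code="en",
        deferrable=True,  # hand the wait off to the triggerer instead of blocking a worker slot
    )
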
diff --git a/airflow/providers/amazon/aws/sensors/comprehend.py b/airflow/providers/amazon/aws/sensors/comprehend.py
new file mode 100644
index 0000000000..8f0e328cbc
--- /dev/null
+++ b/airflow/providers/amazon/aws/sensors/comprehend.py
@@ -0,0 +1,147 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import abc
+from typing import TYPE_CHECKING, Any, Sequence
+
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.triggers.comprehend import ComprehendPiiEntitiesDetectionJobCompletedTrigger
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
+
+if TYPE_CHECKING:
+ from airflow.utils.context import Context
+
+
+class ComprehendBaseSensor(AwsBaseSensor[ComprehendHook]):
+ """
+ General sensor behavior for Amazon Comprehend.
+
+ Subclasses must implement the following methods:
+ - ``get_state()``
+
+ Subclasses must set the following fields:
+ - ``INTERMEDIATE_STATES``
+ - ``FAILURE_STATES``
+ - ``SUCCESS_STATES``
+ - ``FAILURE_MESSAGE``
+
+ :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires the
+ aiobotocore module to be installed.
+ (default: False, but can be overridden in config file by setting default_deferrable to True)
+ """
+
+ aws_hook_class = ComprehendHook
+
+ INTERMEDIATE_STATES: tuple[str, ...] = ()
+ FAILURE_STATES: tuple[str, ...] = ()
+ SUCCESS_STATES: tuple[str, ...] = ()
+ FAILURE_MESSAGE = ""
+
+ ui_color = "#66c3ff"
+
+ def __init__(
+ self,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+ **kwargs: Any,
+ ):
+ super().__init__(**kwargs)
+ self.deferrable = deferrable
+
+ def poke(self, context: Context, **kwargs) -> bool:
+ state = self.get_state()
+ if state in self.FAILURE_STATES:
+ # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
+ if self.soft_fail:
+ raise AirflowSkipException(self.FAILURE_MESSAGE)
+ raise AirflowException(self.FAILURE_MESSAGE)
+
+ return state not in self.INTERMEDIATE_STATES
+
+ @abc.abstractmethod
+ def get_state(self) -> str:
+ """Implement in subclasses."""
+
+
+class ComprehendStartPiiEntitiesDetectionJobCompletedSensor(ComprehendBaseSensor):
+ """
+ Poll the state of the pii entities detection job until it reaches a completed state; fails if the job fails.
+
+ .. seealso::
+ For more information on how to use this sensor, take a look at the guide:
+ :ref:`howto/sensor:ComprehendStartPiiEntitiesDetectionJobCompletedSensor`
+
+ :param job_id: The id of the Comprehend pii entities detection job.
+
+ :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires the
+ aiobotocore module to be installed.
+ (default: False, but can be overridden in config file by setting default_deferrable to True)
+ :param poke_interval: Polling period in seconds to check for the status of the job. (default: 120)
+ :param max_retries: Number of times before returning the current state. (default: 75)
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
+ running Airflow in a distributed manner and aws_conn_id is None or
+ empty, then default boto3 configuration would be used (and must be
+ maintained on each worker node).
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+ https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+ """
+
+ INTERMEDIATE_STATES: tuple[str, ...] = ("IN_PROGRESS",)
+ FAILURE_STATES: tuple[str, ...] = ("FAILED", "STOP_REQUESTED", "STOPPED")
+ SUCCESS_STATES: tuple[str, ...] = ("COMPLETED",)
+ FAILURE_MESSAGE = "Comprehend start pii entities detection job sensor failed."
+
+ template_fields: Sequence[str] = aws_template_fields("job_id")
+
+ def __init__(
+ self,
+ *,
+ job_id: str,
+ max_retries: int = 75,
+ poke_interval: int = 120,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.job_id = job_id
+ self.max_retries = max_retries
+ self.poke_interval = poke_interval
+
+ def execute(self, context: Context) -> Any:
+ if self.deferrable:
+ self.defer(
+ trigger=ComprehendPiiEntitiesDetectionJobCompletedTrigger(
+ job_id=self.job_id,
+ waiter_delay=int(self.poke_interval),
+ waiter_max_attempts=self.max_retries,
+ aws_conn_id=self.aws_conn_id,
+ ),
+ method_name="poke",
+ )
+ else:
+ super().execute(context=context)
+
+ def get_state(self) -> str:
+ return self.hook.conn.describe_pii_entities_detection_job(JobId=self.job_id)[
+ "PiiEntitiesDetectionJobProperties"
+ ]["JobStatus"]
diff --git a/airflow/providers/amazon/aws/triggers/comprehend.py b/airflow/providers/amazon/aws/triggers/comprehend.py
new file mode 100644
index 0000000000..7de6650c87
--- /dev/null
+++ b/airflow/providers/amazon/aws/triggers/comprehend.py
@@ -0,0 +1,61 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+
+from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook
+from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
+
+
+class ComprehendPiiEntitiesDetectionJobCompletedTrigger(AwsBaseWaiterTrigger):
+ """
+ Trigger when a Comprehend pii entities detection job is complete.
+
+ :param job_id: The id of the Comprehend pii entities detection job.
+ :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 120)
+ :param waiter_max_attempts: The maximum number of attempts to be made. (default: 75)
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ """
+
+ def __init__(
+ self,
+ *,
+ job_id: str,
+ waiter_delay: int = 120,
+ waiter_max_attempts: int = 75,
+ aws_conn_id: str | None = "aws_default",
+ ) -> None:
+ super().__init__(
+ serialized_fields={"job_id": job_id},
+ waiter_name="pii_entities_detection_job_complete",
+ waiter_args={"JobId": job_id},
+ failure_message="Comprehend start pii entities detection job failed.",
+ status_message="Status of Comprehend start pii entities detection job is",
+ status_queries=["PiiEntitiesDetectionJobProperties.JobStatus"],
+ return_key="job_id",
+ return_value=job_id,
+ waiter_delay=waiter_delay,
+ waiter_max_attempts=waiter_max_attempts,
+ aws_conn_id=aws_conn_id,
+ )
+
+ def hook(self) -> AwsGenericHook:
+ return ComprehendHook(aws_conn_id=self.aws_conn_id)
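
The trigger carries only serializable state, which is what lets the triggerer process reconstruct it. A rough sketch of the round trip, assuming a placeholder job id (the serialized kwargs include job_id plus the waiter settings, per AwsBaseWaiterTrigger):

    trigger = ComprehendPiiEntitiesDetectionJobCompletedTrigger(job_id="job-1234")
    classpath, kwargs = trigger.serialize()
    # the triggerer imports classpath and rebuilds the trigger from kwargs,
    # then polls the custom waiter asynchronously until the job completes
    rebuilt = ComprehendPiiEntitiesDetectionJobCompletedTrigger(**kwargs)
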
diff --git a/airflow/providers/amazon/aws/waiters/comprehend.json b/airflow/providers/amazon/aws/waiters/comprehend.json
new file mode 100644
index 0000000000..9df82f319f
--- /dev/null
+++ b/airflow/providers/amazon/aws/waiters/comprehend.json
@@ -0,0 +1,49 @@
+{
+ "version": 2,
+ "waiters": {
+ "pii_entities_detection_job_complete": {
+ "delay": 120,
+ "maxAttempts": 75,
+ "operation": "DescribePiiEntitiesDetectionJob",
+ "acceptors": [
+ {
+ "matcher": "path",
+ "argument": "PiiEntitiesDetectionJobProperties.JobStatus",
+ "expected": "SUBMITTED",
+ "state": "retry"
+ },
+ {
+ "matcher": "path",
+ "argument": "PiiEntitiesDetectionJobProperties.JobStatus",
+ "expected": "IN_PROGRESS",
+ "state": "retry"
+ },
+ {
+ "matcher": "path",
+ "argument": "PiiEntitiesDetectionJobProperties.JobStatus",
+ "expected": "COMPLETED",
+ "state": "success"
+ },
+ {
+ "matcher": "path",
+ "argument": "PiiEntitiesDetectionJobProperties.JobStatus",
+ "expected": "FAILED",
+ "state": "failure"
+ },
+ {
+ "matcher": "path",
+ "argument": "PiiEntitiesDetectionJobProperties.JobStatus",
+ "expected": "STOP_REQUESTED",
+ "state": "failure"
+ },
+ {
+ "matcher": "path",
+ "argument": "PiiEntitiesDetectionJobProperties.JobStatus",
+ "expected": "STOPPED",
+ "state": "failure"
+ }
+
+ ]
+ }
+ }
+}
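
This custom waiter is looked up by name through the hook rather than botocore's built-in waiter models. A minimal sketch of polling it directly, with a placeholder job id (the operator above does the same thing internally when wait_for_completion is True):

    waiter = ComprehendHook().get_waiter("pii_entities_detection_job_complete")
    waiter.wait(
        JobId="job-1234",
        WaiterConfig={"Delay": 60, "MaxAttempts": 20},  # override the JSON defaults of 120s x 75 attempts
    )
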
diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml
index 641c315d62..7c06879143 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -166,6 +166,12 @@ integrations:
external-doc-url: https://aws.amazon.com/cloudwatch/
logo: /integration-logos/aws/[email protected]
tags: [aws]
+ - integration-name: Amazon Comprehend
+ external-doc-url: https://aws.amazon.com/comprehend/
+ logo: /integration-logos/aws/[email protected]
+ how-to-guide:
+ - /docs/apache-airflow-providers-amazon/operators/comprehend.rst
+ tags: [aws]
- integration-name: Amazon DataSync
external-doc-url: https://aws.amazon.com/datasync/
how-to-guide:
@@ -385,6 +391,9 @@ operators:
- integration-name: Amazon CloudFormation
python-modules:
- airflow.providers.amazon.aws.operators.cloud_formation
+ - integration-name: Amazon Comprehend
+ python-modules:
+ - airflow.providers.amazon.aws.operators.comprehend
- integration-name: Amazon DataSync
python-modules:
- airflow.providers.amazon.aws.operators.datasync
@@ -470,6 +479,9 @@ sensors:
- integration-name: Amazon CloudFormation
python-modules:
- airflow.providers.amazon.aws.sensors.cloud_formation
+ - integration-name: Amazon Comprehend
+ python-modules:
+ - airflow.providers.amazon.aws.sensors.comprehend
- integration-name: AWS Database Migration Service
python-modules:
- airflow.providers.amazon.aws.sensors.dms
@@ -545,6 +557,9 @@ hooks:
- integration-name: Amazon Chime
python-modules:
- airflow.providers.amazon.aws.hooks.chime
+ - integration-name: Amazon Comprehend
+ python-modules:
+ - airflow.providers.amazon.aws.hooks.comprehend
- integration-name: Amazon DynamoDB
python-modules:
- airflow.providers.amazon.aws.hooks.dynamodb
@@ -672,6 +687,9 @@ triggers:
- integration-name: Amazon Bedrock
python-modules:
- airflow.providers.amazon.aws.triggers.bedrock
+ - integration-name: Amazon Comprehend
+ python-modules:
+ - airflow.providers.amazon.aws.triggers.comprehend
- integration-name: Amazon EC2
python-modules:
- airflow.providers.amazon.aws.triggers.ec2
diff --git a/docs/apache-airflow-providers-amazon/operators/comprehend.rst b/docs/apache-airflow-providers-amazon/operators/comprehend.rst
new file mode 100644
index 0000000000..dd79e2df6a
--- /dev/null
+++ b/docs/apache-airflow-providers-amazon/operators/comprehend.rst
@@ -0,0 +1,74 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+=================
+Amazon Comprehend
+=================
+
+`Amazon Comprehend <https://aws.amazon.com/comprehend/>`__ uses natural language processing (NLP) to
+extract insights about the content of documents. It develops insights by recognizing the entities, key phrases,
+language, sentiments, and other common elements in a document.
+
+Prerequisite Tasks
+------------------
+
+.. include:: ../_partials/prerequisite_tasks.rst
+
+Generic Parameters
+------------------
+
+.. include:: ../_partials/generic_parameters.rst
+
+Operators
+---------
+
+.. _howto/operator:ComprehendStartPiiEntitiesDetectionJobOperator:
+
+Create an Amazon Comprehend Start PII Entities Detection Job
+============================================================
+
+To create an Amazon Comprehend Start PII Entities Detection Job, you can use
+:class:`~airflow.providers.amazon.aws.operators.comprehend.ComprehendStartPiiEntitiesDetectionJobOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_comprehend.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_start_pii_entities_detection_job]
+ :end-before: [END howto_operator_start_pii_entities_detection_job]
+
+Sensors
+-------
+
+.. _howto/sensor:ComprehendStartPiiEntitiesDetectionJobCompletedSensor:
+
+Wait for an Amazon Comprehend Start PII Entities Detection Job
+==============================================================
+
+To wait on the state of an Amazon Comprehend Start PII Entities Detection Job until it reaches a terminal
+state you can use
+:class:`~airflow.providers.amazon.aws.sensors.comprehend.ComprehendStartPiiEntitiesDetectionJobCompletedSensor`.
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_comprehend.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_sensor_start_pii_entities_detection_job]
+ :end-before: [END howto_sensor_start_pii_entities_detection_job]
+
+Reference
+---------
+
+* `AWS boto3 library documentation for Amazon Comprehend <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/comprehend.html>`__
diff --git a/docs/integration-logos/aws/[email protected] b/docs/integration-logos/aws/[email protected]
new file mode 100644
index 0000000000..24e8c34962
Binary files /dev/null and b/docs/integration-logos/aws/[email protected] differ
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index da04944231..40b744fc5d 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1177,6 +1177,8 @@ picklable
pid
pidbox
pigcmd
+Pii
+pii
pinecone
pinodb
Pinot
diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py
index 7ad127ac3b..428570ee58 100644
--- a/tests/always/test_project_structure.py
+++ b/tests/always/test_project_structure.py
@@ -522,6 +522,8 @@ class TestAmazonProviderProjectStructure(ExampleCoverageTest):
"airflow.providers.amazon.aws.sensors.ecs.EcsBaseSensor",
"airflow.providers.amazon.aws.sensors.eks.EksBaseSensor",
"airflow.providers.amazon.aws.transfers.base.AwsToAwsBaseOperator",
+ "airflow.providers.amazon.aws.operators.comprehend.ComprehendBaseOperator",
+ "airflow.providers.amazon.aws.sensors.comprehend.ComprehendBaseSensor",
}
MISSING_EXAMPLES_FOR_CLASSES = {
diff --git a/tests/providers/amazon/aws/hooks/test_comprehend.py b/tests/providers/amazon/aws/hooks/test_comprehend.py
new file mode 100644
index 0000000000..fded25e446
--- /dev/null
+++ b/tests/providers/amazon/aws/hooks/test_comprehend.py
@@ -0,0 +1,31 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook
+
+
+class TestComprehendHook:
+ @pytest.mark.parametrize(
+ "test_hook, service_name",
+ [pytest.param(ComprehendHook(), "comprehend", id="comprehend")],
+ )
+ def test_comprehend_hook(self, test_hook, service_name):
+ assert test_hook.client_type == service_name
+ assert test_hook.conn is not None
diff --git a/tests/providers/amazon/aws/operators/test_comprehend.py b/tests/providers/amazon/aws/operators/test_comprehend.py
new file mode 100644
index 0000000000..b970b590ad
--- /dev/null
+++ b/tests/providers/amazon/aws/operators/test_comprehend.py
@@ -0,0 +1,163 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Generator
+from unittest import mock
+
+import pytest
+from moto import mock_aws
+
+from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook
+from airflow.providers.amazon.aws.operators.comprehend import (
+ ComprehendBaseOperator,
+ ComprehendStartPiiEntitiesDetectionJobOperator,
+)
+from airflow.utils.types import NOTSET
+
+if TYPE_CHECKING:
+ from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection
+
+INPUT_DATA_CONFIG = {
+ "S3Uri": "s3://input-data-comprehend/sample_data.txt",
+ "InputFormat": "ONE_DOC_PER_LINE",
+}
+OUTPUT_DATA_CONFIG = {"S3Uri": "s3://output-data-comprehend/redacted_output/"}
+LANGUAGE_CODE = "en"
+ROLE_ARN = "role_arn"
+
+
+class TestComprehendBaseOperator:
+ @pytest.mark.parametrize("aws_conn_id", [None, NOTSET, "aws_test_conn"])
+ @pytest.mark.parametrize("region_name", [None, NOTSET, "ca-central-1"])
+ def test_initialize_comprehend_base_operator(self, aws_conn_id, region_name):
+ op_kw = {"aws_conn_id": aws_conn_id, "region_name": region_name}
+ op_kw = {k: v for k, v in op_kw.items() if v is not NOTSET}
+
+ comprehend_base_op = ComprehendBaseOperator(
+ task_id="comprehend_base_operator",
+ input_data_config=INPUT_DATA_CONFIG,
+ output_data_config=OUTPUT_DATA_CONFIG,
+ language_code=LANGUAGE_CODE,
+ data_access_role_arn=ROLE_ARN,
+ **op_kw,
+ )
+
+ assert comprehend_base_op.aws_conn_id == (aws_conn_id if aws_conn_id is not NOTSET else "aws_default")
+ assert comprehend_base_op.region_name == (region_name if region_name is not NOTSET else None)
+
+ @mock.patch.object(ComprehendBaseOperator, "hook", new_callable=mock.PropertyMock)
+ def test_initialize_comprehend_base_operator_hook(self, comprehend_base_operator_mock_hook):
+ comprehend_base_op = ComprehendBaseOperator(
+ task_id="comprehend_base_operator",
+ input_data_config=INPUT_DATA_CONFIG,
+ output_data_config=OUTPUT_DATA_CONFIG,
+ language_code=LANGUAGE_CODE,
+ data_access_role_arn=ROLE_ARN,
+ )
+ mocked_hook = mock.MagicMock(name="MockHook")
+ mocked_client = mock.MagicMock(name="MockClient")
+ mocked_hook.conn = mocked_client
+ comprehend_base_operator_mock_hook.return_value = mocked_hook
+ assert comprehend_base_op.client == mocked_client
+ comprehend_base_operator_mock_hook.assert_called_once()
+
+
+class TestComprehendStartPiiEntitiesDetectionJobOperator:
+ JOB_ID = "random-job-id-1234567"
+ MODE = "ONLY_REDACTION"
+ JOB_NAME = "TEST_START_PII_ENTITIES_DETECTION_JOB-1"
+ DEFAULT_JOB_NAME_STARTS_WITH = "start_pii_entities_detection_job"
+ REDACTION_CONFIG = {"PiiEntityTypes": ["NAME", "ADDRESS"], "MaskMode": "REPLACE_WITH_PII_ENTITY_TYPE"}
+
+ @pytest.fixture
+ def mock_conn(self) -> Generator[BaseAwsConnection, None, None]:
+ with mock.patch.object(ComprehendHook, "conn") as _conn:
+ _conn.start_pii_entities_detection_job.return_value = {"JobId": self.JOB_ID}
+ yield _conn
+
+ @pytest.fixture
+ def comprehend_hook(self) -> Generator[ComprehendHook, None, None]:
+ with mock_aws():
+ hook = ComprehendHook(aws_conn_id="aws_default")
+ yield hook
+
+ def setup_method(self):
+ self.operator = ComprehendStartPiiEntitiesDetectionJobOperator(
+ task_id="start_pii_entities_detection_job",
+ input_data_config=INPUT_DATA_CONFIG,
+ output_data_config=OUTPUT_DATA_CONFIG,
+ data_access_role_arn=ROLE_ARN,
+ mode=self.MODE,
+ language_code=LANGUAGE_CODE,
+ start_pii_entities_kwargs={"JobName": self.JOB_NAME, "RedactionConfig": self.REDACTION_CONFIG},
+ )
+ self.operator.defer = mock.MagicMock()
+
+ def test_init(self):
+ assert self.operator.input_data_config == INPUT_DATA_CONFIG
+ assert self.operator.output_data_config == OUTPUT_DATA_CONFIG
+ assert self.operator.data_access_role_arn == ROLE_ARN
+ assert self.operator.mode == self.MODE
+ assert self.operator.language_code == LANGUAGE_CODE
+ assert self.operator.start_pii_entities_kwargs.get("JobName") == self.JOB_NAME
+ assert self.operator.start_pii_entities_kwargs.get("RedactionConfig") == self.REDACTION_CONFIG
+
+ @mock.patch.object(ComprehendHook, "conn")
+ def test_start_pii_entities_detection_job_name_starts_with_service_name(self, comprehend_mock_conn):
+ self.op = ComprehendStartPiiEntitiesDetectionJobOperator(
+ task_id="start_pii_entities_detection_job",
+ input_data_config=INPUT_DATA_CONFIG,
+ output_data_config=OUTPUT_DATA_CONFIG,
+ data_access_role_arn=ROLE_ARN,
+ mode=self.MODE,
+ language_code=LANGUAGE_CODE,
+ start_pii_entities_kwargs={"RedactionConfig": self.REDACTION_CONFIG},
+ )
+ self.op.wait_for_completion = False
+ self.op.execute({})
+ assert self.op.start_pii_entities_kwargs.get("JobName").startswith(self.DEFAULT_JOB_NAME_STARTS_WITH)
+ comprehend_mock_conn.start_pii_entities_detection_job.assert_called_once_with(
+ InputDataConfig=INPUT_DATA_CONFIG,
+ OutputDataConfig=OUTPUT_DATA_CONFIG,
+ Mode=self.MODE,
+ DataAccessRoleArn=ROLE_ARN,
+ LanguageCode=LANGUAGE_CODE,
+ RedactionConfig=self.REDACTION_CONFIG,
+ JobName=self.op.start_pii_entities_kwargs.get("JobName"),
+ )
+
+ @pytest.mark.parametrize(
+ "wait_for_completion, deferrable",
+ [
+ pytest.param(False, False, id="no_wait"),
+ pytest.param(True, False, id="wait"),
+ pytest.param(False, True, id="defer"),
+ ],
+ )
+ @mock.patch.object(ComprehendHook, "get_waiter")
+ def test_start_pii_entities_detection_job_wait_combinations(
+ self, _, wait_for_completion, deferrable, mock_conn, comprehend_hook
+ ):
+ self.operator.wait_for_completion = wait_for_completion
+ self.operator.deferrable = deferrable
+
+ response = self.operator.execute({})
+
+ assert response == self.JOB_ID
+ assert comprehend_hook.get_waiter.call_count == wait_for_completion
+ assert self.operator.defer.call_count == deferrable
diff --git a/tests/providers/amazon/aws/sensors/test_comprehend.py b/tests/providers/amazon/aws/sensors/test_comprehend.py
new file mode 100644
index 0000000000..e066349031
--- /dev/null
+++ b/tests/providers/amazon/aws/sensors/test_comprehend.py
@@ -0,0 +1,94 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook
+from airflow.providers.amazon.aws.sensors.comprehend import (
+ ComprehendStartPiiEntitiesDetectionJobCompletedSensor,
+)
+
+
+class TestComprehendStartPiiEntitiesDetectionJobCompletedSensor:
+ SENSOR = ComprehendStartPiiEntitiesDetectionJobCompletedSensor
+
+ def setup_method(self):
+ self.default_op_kwargs = dict(
+ task_id="test_pii_entities_detection_job_sensor",
+ job_id="job_id",
+ poke_interval=5,
+ max_retries=1,
+ )
+ self.sensor = self.SENSOR(**self.default_op_kwargs, aws_conn_id=None)
+
+ def test_base_aws_op_attributes(self):
+ op = self.SENSOR(**self.default_op_kwargs)
+ assert op.hook.aws_conn_id == "aws_default"
+ assert op.hook._region_name is None
+ assert op.hook._verify is None
+ assert op.hook._config is None
+
+ op = self.SENSOR(
+ **self.default_op_kwargs,
+ aws_conn_id="aws-test-custom-conn",
+ region_name="eu-west-1",
+ verify=False,
+ botocore_config={"read_timeout": 42},
+ )
+ assert op.hook.aws_conn_id == "aws-test-custom-conn"
+ assert op.hook._region_name == "eu-west-1"
+ assert op.hook._verify is False
+ assert op.hook._config is not None
+ assert op.hook._config.read_timeout == 42
+
+ @pytest.mark.parametrize("state", SENSOR.SUCCESS_STATES)
+ @mock.patch.object(ComprehendHook, "conn")
+ def test_poke_success_state(self, mock_conn, state):
+ mock_conn.describe_pii_entities_detection_job.return_value = {
+ "PiiEntitiesDetectionJobProperties": {"JobStatus": state}
+ }
+ assert self.sensor.poke({}) is True
+
+ @pytest.mark.parametrize("state", SENSOR.INTERMEDIATE_STATES)
+ @mock.patch.object(ComprehendHook, "conn")
+ def test_intermediate_state(self, mock_conn, state):
+ mock_conn.describe_pii_entities_detection_job.return_value = {
+ "PiiEntitiesDetectionJobProperties": {"JobStatus": state}
+ }
+ assert self.sensor.poke({}) is False
+
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception",
+ [
+ pytest.param(False, AirflowException, id="not-soft-fail"),
+ pytest.param(True, AirflowSkipException, id="soft-fail"),
+ ],
+ )
+ @pytest.mark.parametrize("state", SENSOR.FAILURE_STATES)
+ @mock.patch.object(ComprehendHook, "conn")
+ def test_poke_failure_states(self, mock_conn, state, soft_fail, expected_exception):
+ mock_conn.describe_pii_entities_detection_job.return_value = {
+ "PiiEntitiesDetectionJobProperties": {"JobStatus": state}
+ }
+ sensor = self.SENSOR(**self.default_op_kwargs, aws_conn_id=None, soft_fail=soft_fail)
+
+ with pytest.raises(expected_exception, match=sensor.FAILURE_MESSAGE):
+ sensor.poke({})
diff --git a/tests/providers/amazon/aws/triggers/test_comprehend.py b/tests/providers/amazon/aws/triggers/test_comprehend.py
new file mode 100644
index 0000000000..1c52aa8810
--- /dev/null
+++ b/tests/providers/amazon/aws/triggers/test_comprehend.py
@@ -0,0 +1,67 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+from unittest.mock import AsyncMock
+
+import pytest
+
+from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook
+from airflow.providers.amazon.aws.triggers.comprehend import ComprehendPiiEntitiesDetectionJobCompletedTrigger
+from airflow.triggers.base import TriggerEvent
+from tests.providers.amazon.aws.utils.test_waiter import assert_expected_waiter_type
+
+BASE_TRIGGER_CLASSPATH = "airflow.providers.amazon.aws.triggers.comprehend."
+
+
+class TestBaseComprehendTrigger:
+ EXPECTED_WAITER_NAME: str | None = None
+ JOB_ID: str | None = None
+
+ def test_setup(self):
+ # Ensure that all subclasses have an expected waiter name set.
+ if self.__class__.__name__ != "TestBaseComprehendTrigger":
+ assert isinstance(self.EXPECTED_WAITER_NAME, str)
+ assert isinstance(self.JOB_ID, str)
+
+
+class TestComprehendPiiEntitiesDetectionJobCompletedTrigger(TestBaseComprehendTrigger):
+ EXPECTED_WAITER_NAME = "pii_entities_detection_job_complete"
+ JOB_ID = "job_id"
+
+ def test_serialization(self):
+ """Assert that arguments and classpath are correctly serialized."""
+ trigger = ComprehendPiiEntitiesDetectionJobCompletedTrigger(job_id=self.JOB_ID)
+ classpath, kwargs = trigger.serialize()
+ assert classpath == BASE_TRIGGER_CLASSPATH + "ComprehendPiiEntitiesDetectionJobCompletedTrigger"
+ assert kwargs.get("job_id") == self.JOB_ID
+
+ @pytest.mark.asyncio
+ @mock.patch.object(ComprehendHook, "get_waiter")
+ @mock.patch.object(ComprehendHook, "async_conn")
+ async def test_run_success(self, mock_async_conn, mock_get_waiter):
+ mock_async_conn.__aenter__.return_value = mock.MagicMock()
+ mock_get_waiter().wait = AsyncMock()
+ trigger = ComprehendPiiEntitiesDetectionJobCompletedTrigger(job_id=self.JOB_ID)
+
+ generator = trigger.run()
+ response = await generator.asend(None)
+
+ assert response == TriggerEvent({"status": "success", "job_id": self.JOB_ID})
+ assert_expected_waiter_type(mock_get_waiter, self.EXPECTED_WAITER_NAME)
+ mock_get_waiter().wait.assert_called_once()
diff --git a/tests/providers/amazon/aws/waiters/test_comprehend.py b/tests/providers/amazon/aws/waiters/test_comprehend.py
new file mode 100644
index 0000000000..a514ea198f
--- /dev/null
+++ b/tests/providers/amazon/aws/waiters/test_comprehend.py
@@ -0,0 +1,71 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import boto3
+import botocore
+import pytest
+
+from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook
+from airflow.providers.amazon.aws.sensors.comprehend import (
+ ComprehendStartPiiEntitiesDetectionJobCompletedSensor,
+)
+
+
+class TestComprehendCustomWaiters:
+ def test_service_waiters(self):
+ assert "pii_entities_detection_job_complete" in
ComprehendHook().list_waiters()
+
+
+class TestComprehendCustomWaitersBase:
+ @pytest.fixture(autouse=True)
+ def mock_conn(self, monkeypatch):
+ self.client = boto3.client("comprehend")
+ monkeypatch.setattr(ComprehendHook, "conn", self.client)
+
+
+class TestComprehendStartPiiEntitiesDetectionJobCompleteWaiter(TestComprehendCustomWaitersBase):
+ WAITER_NAME = "pii_entities_detection_job_complete"
+
+ @pytest.fixture
+ def mock_get_job(self):
+ with mock.patch.object(self.client, "describe_pii_entities_detection_job") as mock_getter:
+ yield mock_getter
+
+ @pytest.mark.parametrize("state", ComprehendStartPiiEntitiesDetectionJobCompletedSensor.SUCCESS_STATES)
+ def test_pii_entities_detection_job_complete(self, state, mock_get_job):
+ mock_get_job.return_value = {"PiiEntitiesDetectionJobProperties": {"JobStatus": state}}
+
+ ComprehendHook().get_waiter(self.WAITER_NAME).wait(JobId="job_id")
+
+ @pytest.mark.parametrize("state", ComprehendStartPiiEntitiesDetectionJobCompletedSensor.FAILURE_STATES)
+ def test_pii_entities_detection_job_failed(self, state, mock_get_job):
+ mock_get_job.return_value = {"PiiEntitiesDetectionJobProperties": {"JobStatus": state}}
+
+ with pytest.raises(botocore.exceptions.WaiterError):
+ ComprehendHook().get_waiter(self.WAITER_NAME).wait(JobId="job_id")
+
+ def test_pii_entities_detection_job_wait(self, mock_get_job):
+ wait = {"PiiEntitiesDetectionJobProperties": {"JobStatus": "IN_PROGRESS"}}
+ success = {"PiiEntitiesDetectionJobProperties": {"JobStatus": "COMPLETED"}}
+ mock_get_job.side_effect = [wait, wait, success]
+
+ ComprehendHook().get_waiter(self.WAITER_NAME).wait(
+ JobId="job_id", WaiterConfig={"Delay": 0.01, "MaxAttempts": 3}
+ )
diff --git a/tests/system/providers/amazon/aws/example_comprehend.py b/tests/system/providers/amazon/aws/example_comprehend.py
new file mode 100644
index 0000000000..58e34329b6
--- /dev/null
+++ b/tests/system/providers/amazon/aws/example_comprehend.py
@@ -0,0 +1,137 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import json
+from datetime import datetime
+
+from airflow import DAG
+from airflow.decorators import task_group
+from airflow.models.baseoperator import chain
+from airflow.providers.amazon.aws.operators.comprehend import ComprehendStartPiiEntitiesDetectionJobOperator
+from airflow.providers.amazon.aws.operators.s3 import (
+ S3CreateBucketOperator,
+ S3CreateObjectOperator,
+ S3DeleteBucketOperator,
+)
+from airflow.providers.amazon.aws.sensors.comprehend import (
+ ComprehendStartPiiEntitiesDetectionJobCompletedSensor,
+)
+from airflow.utils.trigger_rule import TriggerRule
+from tests.system.providers.amazon.aws.utils import SystemTestContextBuilder
+
+ROLE_ARN_KEY = "ROLE_ARN"
+sys_test_context_task = SystemTestContextBuilder().add_variable(ROLE_ARN_KEY).build()
+
+DAG_ID = "example_comprehend"
+INPUT_S3_KEY_START_PII_ENTITIES_DETECTION_JOB = "start-pii-entities-detection-job/sample_data.txt"
+
+SAMPLE_DATA = {
+ "username": "bob1234",
+ "name": "Bob",
+ "sex": "M",
+ "address": "1773 Raymond Ville Suite 682",
+ "mail": "[email protected]",
+}
+
+
+@task_group
+def pii_entities_detection_job_workflow():
+ # [START howto_operator_start_pii_entities_detection_job]
+ start_pii_entities_detection_job = ComprehendStartPiiEntitiesDetectionJobOperator(
+ task_id="start_pii_entities_detection_job",
+ input_data_config=input_data_configurations,
+ output_data_config=output_data_configurations,
+ mode="ONLY_REDACTION",
+ data_access_role_arn=test_context[ROLE_ARN_KEY],
+ language_code="en",
+ start_pii_entities_kwargs=pii_entities_kwargs,
+ )
+ # [END howto_operator_start_pii_entities_detection_job]
+ start_pii_entities_detection_job.wait_for_completion = False
+
+ # [START howto_sensor_start_pii_entities_detection_job]
+ await_start_pii_entities_detection_job = ComprehendStartPiiEntitiesDetectionJobCompletedSensor(
+ task_id="await_start_pii_entities_detection_job", job_id=start_pii_entities_detection_job.output
+ )
+ # [END howto_sensor_start_pii_entities_detection_job]
+
+ chain(start_pii_entities_detection_job, await_start_pii_entities_detection_job)
+
+
+with DAG(
+ dag_id=DAG_ID,
+ schedule="@once",
+ start_date=datetime(2021, 1, 1),
+ tags=["example"],
+ catchup=False,
+) as dag:
+ test_context = sys_test_context_task()
+ env_id = test_context["ENV_ID"]
+ bucket_name = f"{env_id}-comprehend"
+ input_data_configurations = {
+ "S3Uri":
f"s3://{bucket_name}/{INPUT_S3_KEY_START_PII_ENTITIES_DETECTION_JOB}",
+ "InputFormat": "ONE_DOC_PER_LINE",
+ }
+ output_data_configurations = {"S3Uri": f"s3://{bucket_name}/redacted_output/"}
+ pii_entities_kwargs = {
+ "RedactionConfig": {
+ "PiiEntityTypes": ["NAME", "ADDRESS"],
+ "MaskMode": "REPLACE_WITH_PII_ENTITY_TYPE",
+ }
+ }
+
+ create_bucket = S3CreateBucketOperator(
+ task_id="create_bucket",
+ bucket_name=bucket_name,
+ )
+
+ upload_sample_data = S3CreateObjectOperator(
+ task_id="upload_sample_data",
+ s3_bucket=bucket_name,
+ s3_key=INPUT_S3_KEY_START_PII_ENTITIES_DETECTION_JOB,
+ data=json.dumps(SAMPLE_DATA),
+ )
+
+ delete_bucket = S3DeleteBucketOperator(
+ task_id="delete_bucket",
+ trigger_rule=TriggerRule.ALL_DONE,
+ bucket_name=bucket_name,
+ force_delete=True,
+ )
+
+ chain(
+ # TEST SETUP
+ test_context,
+ create_bucket,
+ upload_sample_data,
+ # TEST BODY
+ pii_entities_detection_job_workflow(),
+ # TEST TEARDOWN
+ delete_bucket,
+ )
+
+ from tests.system.utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "tearDown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)