This is an automated email from the ASF dual-hosted git repository. uranusjr pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new bc4a22c6bd Use base aws classes in Amazon AppFlow Operators (#35082) bc4a22c6bd is described below commit bc4a22c6bd8096e7b62147031035cb14896fe934 Author: Andrey Anshin <andrey.ans...@taragol.is> AuthorDate: Mon Oct 23 12:47:31 2023 +0400 Use base aws classes in Amazon AppFlow Operators (#35082) --- airflow/providers/amazon/aws/hooks/appflow.py | 14 +-- airflow/providers/amazon/aws/operators/appflow.py | 108 +++++++++------------ airflow/providers/amazon/aws/operators/base_aws.py | 2 +- .../operators/appflow.rst | 5 + .../providers/amazon/aws/operators/test_appflow.py | 54 +++++++++++ 5 files changed, 111 insertions(+), 72 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/appflow.py b/airflow/providers/amazon/aws/hooks/appflow.py index 3962a6cf20..f60b5eea1a 100644 --- a/airflow/providers/amazon/aws/hooks/appflow.py +++ b/airflow/providers/amazon/aws/hooks/appflow.py @@ -16,19 +16,18 @@ # under the License. from __future__ import annotations -from functools import cached_property from typing import TYPE_CHECKING -from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook from airflow.providers.amazon.aws.utils.waiter_with_logging import wait if TYPE_CHECKING: - from mypy_boto3_appflow.client import AppflowClient + from mypy_boto3_appflow.client import AppflowClient # noqa -class AppflowHook(AwsBaseHook): +class AppflowHook(AwsGenericHook["AppflowClient"]): """ - Interact with Amazon Appflow. + Interact with Amazon AppFlow. Provide thin wrapper around :external+boto3:py:class:`boto3.client("appflow") <Appflow.Client>`. @@ -44,11 +43,6 @@ class AppflowHook(AwsBaseHook): kwargs["client_type"] = "appflow" super().__init__(*args, **kwargs) - @cached_property - def conn(self) -> AppflowClient: - """Get the underlying boto3 Appflow client (cached).""" - return super().conn - def run_flow( self, flow_name: str, diff --git a/airflow/providers/amazon/aws/operators/appflow.py b/airflow/providers/amazon/aws/operators/appflow.py index 184fc7fab1..2eb8f704ca 100644 --- a/airflow/providers/amazon/aws/operators/appflow.py +++ b/airflow/providers/amazon/aws/operators/appflow.py @@ -19,14 +19,14 @@ from __future__ import annotations import time import warnings from datetime import datetime, timedelta -from functools import cached_property from typing import TYPE_CHECKING, cast from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning -from airflow.models import BaseOperator from airflow.operators.python import ShortCircuitOperator from airflow.providers.amazon.aws.hooks.appflow import AppflowHook +from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator from airflow.providers.amazon.aws.utils import datetime_to_epoch_ms +from airflow.providers.amazon.aws.utils.mixins import AwsBaseHookMixin, AwsHookParams, aws_template_fields if TYPE_CHECKING: from mypy_boto3_appflow.type_defs import ( @@ -42,9 +42,9 @@ MANDATORY_FILTER_DATE_MSG = "The filter_date argument is mandatory for {entity}! NOT_SUPPORTED_SOURCE_MSG = "Source {source} is not supported for {entity}!" -class AppflowBaseOperator(BaseOperator): +class AppflowBaseOperator(AwsBaseOperator[AppflowHook]): """ - Amazon Appflow Base Operator class (not supposed to be used directly in DAGs). + Amazon AppFlow Base Operator class (not supposed to be used directly in DAGs). :param source: The source name (Supported: salesforce, zendesk) :param flow_name: The flow name @@ -53,14 +53,22 @@ class AppflowBaseOperator(BaseOperator): :param filter_date: The date value (or template) to be used in filters. :param poll_interval: how often in seconds to check the query status :param max_attempts: how many times to check for status before timing out - :param aws_conn_id: aws connection to use - :param region: aws region to use :param wait_for_completion: whether to wait for the run to end to return + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ + aws_hook_class = AppflowHook ui_color = "#2bccbd" - - template_fields = ("flow_name", "source", "source_field", "filter_date") + template_fields = aws_template_fields("flow_name", "source", "source_field", "filter_date") UPDATE_PROPAGATION_TIME: int = 15 @@ -73,8 +81,6 @@ class AppflowBaseOperator(BaseOperator): filter_date: str | None = None, poll_interval: int = 20, max_attempts: int = 60, - aws_conn_id: str = "aws_default", - region: str | None = None, wait_for_completion: bool = True, **kwargs, ) -> None: @@ -87,16 +93,9 @@ class AppflowBaseOperator(BaseOperator): self.source_field = source_field self.poll_interval = poll_interval self.max_attempts = max_attempts - self.aws_conn_id = aws_conn_id - self.region = region self.flow_update = flow_update self.wait_for_completion = wait_for_completion - @cached_property - def hook(self) -> AppflowHook: - """Create and return an AppflowHook.""" - return AppflowHook(aws_conn_id=self.aws_conn_id, region_name=self.region) - def execute(self, context: Context) -> None: self.filter_date_parsed: datetime | None = ( datetime.fromisoformat(self.filter_date) if self.filter_date else None @@ -135,7 +134,7 @@ class AppflowBaseOperator(BaseOperator): class AppflowRunOperator(AppflowBaseOperator): """ - Execute a Appflow run as is. + Execute an AppFlow run as is. .. seealso:: For more information on how to use this operator, take a look at the guide: @@ -154,8 +153,6 @@ class AppflowRunOperator(AppflowBaseOperator): flow_name: str, source: str | None = None, poll_interval: int = 20, - aws_conn_id: str = "aws_default", - region: str | None = None, wait_for_completion: bool = True, **kwargs, ) -> None: @@ -171,8 +168,6 @@ class AppflowRunOperator(AppflowBaseOperator): source_field=None, filter_date=None, poll_interval=poll_interval, - aws_conn_id=aws_conn_id, - region=region, wait_for_completion=wait_for_completion, **kwargs, ) @@ -180,7 +175,7 @@ class AppflowRunOperator(AppflowBaseOperator): class AppflowRunFullOperator(AppflowBaseOperator): """ - Execute a Appflow full run removing any filter. + Execute an AppFlow full run removing any filter. .. seealso:: For more information on how to use this operator, take a look at the guide: @@ -189,8 +184,6 @@ class AppflowRunFullOperator(AppflowBaseOperator): :param source: The source name (Supported: salesforce, zendesk) :param flow_name: The flow name :param poll_interval: how often in seconds to check the query status - :param aws_conn_id: aws connection to use - :param region: aws region to use :param wait_for_completion: whether to wait for the run to end to return """ @@ -199,8 +192,6 @@ class AppflowRunFullOperator(AppflowBaseOperator): source: str, flow_name: str, poll_interval: int = 20, - aws_conn_id: str = "aws_default", - region: str | None = None, wait_for_completion: bool = True, **kwargs, ) -> None: @@ -213,8 +204,6 @@ class AppflowRunFullOperator(AppflowBaseOperator): source_field=None, filter_date=None, poll_interval=poll_interval, - aws_conn_id=aws_conn_id, - region=region, wait_for_completion=wait_for_completion, **kwargs, ) @@ -222,7 +211,7 @@ class AppflowRunFullOperator(AppflowBaseOperator): class AppflowRunBeforeOperator(AppflowBaseOperator): """ - Execute a Appflow run after updating the filters to select only previous data. + Execute an AppFlow run after updating the filters to select only previous data. .. seealso:: For more information on how to use this operator, take a look at the guide: @@ -245,8 +234,6 @@ class AppflowRunBeforeOperator(AppflowBaseOperator): source_field: str, filter_date: str, poll_interval: int = 20, - aws_conn_id: str = "aws_default", - region: str | None = None, wait_for_completion: bool = True, **kwargs, ) -> None: @@ -263,8 +250,6 @@ class AppflowRunBeforeOperator(AppflowBaseOperator): source_field=source_field, filter_date=filter_date, poll_interval=poll_interval, - aws_conn_id=aws_conn_id, - region=region, wait_for_completion=wait_for_completion, **kwargs, ) @@ -290,7 +275,7 @@ class AppflowRunBeforeOperator(AppflowBaseOperator): class AppflowRunAfterOperator(AppflowBaseOperator): """ - Execute a Appflow run after updating the filters to select only future data. + Execute an AppFlow run after updating the filters to select only future data. .. seealso:: For more information on how to use this operator, take a look at the guide: @@ -301,8 +286,6 @@ class AppflowRunAfterOperator(AppflowBaseOperator): :param source_field: The field name to apply filters :param filter_date: The date value (or template) to be used in filters. :param poll_interval: how often in seconds to check the query status - :param aws_conn_id: aws connection to use - :param region: aws region to use :param wait_for_completion: whether to wait for the run to end to return """ @@ -313,8 +296,6 @@ class AppflowRunAfterOperator(AppflowBaseOperator): source_field: str, filter_date: str, poll_interval: int = 20, - aws_conn_id: str = "aws_default", - region: str | None = None, wait_for_completion: bool = True, **kwargs, ) -> None: @@ -329,8 +310,6 @@ class AppflowRunAfterOperator(AppflowBaseOperator): source_field=source_field, filter_date=filter_date, poll_interval=poll_interval, - aws_conn_id=aws_conn_id, - region=region, wait_for_completion=wait_for_completion, **kwargs, ) @@ -356,7 +335,7 @@ class AppflowRunAfterOperator(AppflowBaseOperator): class AppflowRunDailyOperator(AppflowBaseOperator): """ - Execute a Appflow run after updating the filters to select only a single day. + Execute an AppFlow run after updating the filters to select only a single day. .. seealso:: For more information on how to use this operator, take a look at the guide: @@ -367,8 +346,6 @@ class AppflowRunDailyOperator(AppflowBaseOperator): :param source_field: The field name to apply filters :param filter_date: The date value (or template) to be used in filters. :param poll_interval: how often in seconds to check the query status - :param aws_conn_id: aws connection to use - :param region: aws region to use :param wait_for_completion: whether to wait for the run to end to return """ @@ -379,8 +356,6 @@ class AppflowRunDailyOperator(AppflowBaseOperator): source_field: str, filter_date: str, poll_interval: int = 20, - aws_conn_id: str = "aws_default", - region: str | None = None, wait_for_completion: bool = True, **kwargs, ) -> None: @@ -395,8 +370,6 @@ class AppflowRunDailyOperator(AppflowBaseOperator): source_field=source_field, filter_date=filter_date, poll_interval=poll_interval, - aws_conn_id=aws_conn_id, - region=region, wait_for_completion=wait_for_completion, **kwargs, ) @@ -423,9 +396,9 @@ class AppflowRunDailyOperator(AppflowBaseOperator): ) -class AppflowRecordsShortCircuitOperator(ShortCircuitOperator): +class AppflowRecordsShortCircuitOperator(ShortCircuitOperator, AwsBaseHookMixin[AppflowHook]): """ - Short-circuit in case of a empty Appflow's run. + Short-circuit in case of an empty AppFlow's run. .. seealso:: For more information on how to use this operator, take a look at the guide: @@ -434,10 +407,20 @@ class AppflowRecordsShortCircuitOperator(ShortCircuitOperator): :param flow_name: The flow name :param appflow_run_task_id: Run task ID from where this operator should extract the execution ID :param ignore_downstream_trigger_rules: Ignore downstream trigger rules - :param aws_conn_id: aws connection to use - :param region: aws region to use + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ + aws_hook_class = AppflowHook + template_fields = aws_template_fields() ui_color = "#33ffec" # Light blue def __init__( @@ -446,10 +429,15 @@ class AppflowRecordsShortCircuitOperator(ShortCircuitOperator): flow_name: str, appflow_run_task_id: str, ignore_downstream_trigger_rules: bool = True, - aws_conn_id: str = "aws_default", - region: str | None = None, + aws_conn_id: str | None = "aws_default", + region_name: str | None = None, + verify: bool | str | None = None, + botocore_config: dict | None = None, **kwargs, ) -> None: + hook_params = AwsHookParams.from_constructor( + aws_conn_id, region_name, verify, botocore_config, additional_params=kwargs + ) super().__init__( python_callable=self._has_new_records_func, op_kwargs={ @@ -459,8 +447,11 @@ class AppflowRecordsShortCircuitOperator(ShortCircuitOperator): ignore_downstream_trigger_rules=ignore_downstream_trigger_rules, **kwargs, ) - self.aws_conn_id = aws_conn_id - self.region = region + self.aws_conn_id = hook_params.aws_conn_id + self.region_name = hook_params.region_name + self.verify = hook_params.verify + self.botocore_config = hook_params.botocore_config + self.validate_attributes() @staticmethod def _get_target_execution_id( @@ -471,11 +462,6 @@ class AppflowRecordsShortCircuitOperator(ShortCircuitOperator): return record return None - @cached_property - def hook(self) -> AppflowHook: - """Create and return an AppflowHook.""" - return AppflowHook(aws_conn_id=self.aws_conn_id, region_name=self.region) - def _has_new_records_func(self, **kwargs) -> bool: appflow_task_id = kwargs["appflow_run_task_id"] self.log.info("appflow_task_id: %s", appflow_task_id) diff --git a/airflow/providers/amazon/aws/operators/base_aws.py b/airflow/providers/amazon/aws/operators/base_aws.py index aaa5059afb..8cc1b35302 100644 --- a/airflow/providers/amazon/aws/operators/base_aws.py +++ b/airflow/providers/amazon/aws/operators/base_aws.py @@ -64,7 +64,7 @@ class AwsBaseOperator(BaseOperator, AwsBaseHookMixin[AwsHookType]): pass :param aws_conn_id: The Airflow connection used for AWS credentials. - If this is None or empty then the default boto3 behaviour is used. If + If this is ``None`` or empty then the default boto3 behaviour is used. If running Airflow in a distributed manner and aws_conn_id is None or empty, then default boto3 configuration would be used (and must be maintained on each worker node). diff --git a/docs/apache-airflow-providers-amazon/operators/appflow.rst b/docs/apache-airflow-providers-amazon/operators/appflow.rst index 14d3a5e5ca..c28cfb001b 100644 --- a/docs/apache-airflow-providers-amazon/operators/appflow.rst +++ b/docs/apache-airflow-providers-amazon/operators/appflow.rst @@ -35,6 +35,11 @@ Prerequisite Tasks .. include:: ../_partials/prerequisite_tasks.rst +Generic Parameters +------------------ + +.. include:: ../_partials/generic_parameters.rst + Operators --------- diff --git a/tests/providers/amazon/aws/operators/test_appflow.py b/tests/providers/amazon/aws/operators/test_appflow.py index ce36faad5f..dc3c2fddca 100644 --- a/tests/providers/amazon/aws/operators/test_appflow.py +++ b/tests/providers/amazon/aws/operators/test_appflow.py @@ -196,3 +196,57 @@ def test_short_circuit(appflow_conn, ctx): flowName=FLOW_NAME, maxResults=100 ) mock_xcom_push.assert_called_with("records_processed", 1) + + +@pytest.mark.parametrize( + "op_class, op_base_args", + [ + pytest.param( + AppflowRunAfterOperator, + dict(**DUMP_COMMON_ARGS, source_field="col0", filter_date="2022-05-26T00:00+00:00"), + id="run-after-op", + ), + pytest.param( + AppflowRunBeforeOperator, + dict(**DUMP_COMMON_ARGS, source_field="col1", filter_date="2077-10-23T00:03+00:00"), + id="run-before-op", + ), + pytest.param( + AppflowRunDailyOperator, + dict(**DUMP_COMMON_ARGS, source_field="col2", filter_date="2023-10-20T12:22+00:00"), + id="run-daily-op", + ), + pytest.param(AppflowRunFullOperator, DUMP_COMMON_ARGS, id="run-full-op"), + pytest.param(AppflowRunOperator, DUMP_COMMON_ARGS, id="run-op"), + pytest.param( + AppflowRecordsShortCircuitOperator, + dict(task_id=SHORT_CIRCUIT_TASK_ID, flow_name=FLOW_NAME, appflow_run_task_id=TASK_ID), + id="records-short-circuit", + ), + ], +) +def test_base_aws_op_attributes(op_class, op_base_args): + op = op_class(**op_base_args) + hook = op.hook + assert hook is op.hook + assert hook.aws_conn_id == CONN_ID + assert hook._region_name is None + assert hook._verify is None + assert hook._config is None + + op = op_class(**op_base_args, region_name="eu-west-1", verify=False, botocore_config={"read_timeout": 42}) + hook = op.hook + assert hook is op.hook + assert hook.aws_conn_id == CONN_ID + assert hook._region_name == "eu-west-1" + assert hook._verify is False + assert hook._config.read_timeout == 42 + + # Compatibility check: previously Appflow Operators use `region` instead of `region_name` + warning_message = "`region` is deprecated and will be removed in the future" + with pytest.warns(DeprecationWarning, match=warning_message): + op = op_class(**op_base_args, region="us-west-1") + assert op.region_name == "us-west-1" + + with pytest.warns(DeprecationWarning, match=warning_message): + assert op.region == "us-west-1"