This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new bc4a22c6bd Use base aws classes in Amazon AppFlow Operators (#35082)
bc4a22c6bd is described below

commit bc4a22c6bd8096e7b62147031035cb14896fe934
Author: Andrey Anshin <andrey.ans...@taragol.is>
AuthorDate: Mon Oct 23 12:47:31 2023 +0400

    Use base aws classes in Amazon AppFlow Operators (#35082)
---
 airflow/providers/amazon/aws/hooks/appflow.py      |  14 +--
 airflow/providers/amazon/aws/operators/appflow.py  | 108 +++++++++------------
 airflow/providers/amazon/aws/operators/base_aws.py |   2 +-
 .../operators/appflow.rst                          |   5 +
 .../providers/amazon/aws/operators/test_appflow.py |  54 +++++++++++
 5 files changed, 111 insertions(+), 72 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/appflow.py 
b/airflow/providers/amazon/aws/hooks/appflow.py
index 3962a6cf20..f60b5eea1a 100644
--- a/airflow/providers/amazon/aws/hooks/appflow.py
+++ b/airflow/providers/amazon/aws/hooks/appflow.py
@@ -16,19 +16,18 @@
 # under the License.
 from __future__ import annotations
 
-from functools import cached_property
 from typing import TYPE_CHECKING
 
-from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
 from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
 
 if TYPE_CHECKING:
-    from mypy_boto3_appflow.client import AppflowClient
+    from mypy_boto3_appflow.client import AppflowClient  # noqa
 
 
-class AppflowHook(AwsBaseHook):
+class AppflowHook(AwsGenericHook["AppflowClient"]):
     """
-    Interact with Amazon Appflow.
+    Interact with Amazon AppFlow.
 
     Provide thin wrapper around 
:external+boto3:py:class:`boto3.client("appflow") <Appflow.Client>`.
 
@@ -44,11 +43,6 @@ class AppflowHook(AwsBaseHook):
         kwargs["client_type"] = "appflow"
         super().__init__(*args, **kwargs)
 
-    @cached_property
-    def conn(self) -> AppflowClient:
-        """Get the underlying boto3 Appflow client (cached)."""
-        return super().conn
-
     def run_flow(
         self,
         flow_name: str,
diff --git a/airflow/providers/amazon/aws/operators/appflow.py 
b/airflow/providers/amazon/aws/operators/appflow.py
index 184fc7fab1..2eb8f704ca 100644
--- a/airflow/providers/amazon/aws/operators/appflow.py
+++ b/airflow/providers/amazon/aws/operators/appflow.py
@@ -19,14 +19,14 @@ from __future__ import annotations
 import time
 import warnings
 from datetime import datetime, timedelta
-from functools import cached_property
 from typing import TYPE_CHECKING, cast
 
 from airflow.exceptions import AirflowException, 
AirflowProviderDeprecationWarning
-from airflow.models import BaseOperator
 from airflow.operators.python import ShortCircuitOperator
 from airflow.providers.amazon.aws.hooks.appflow import AppflowHook
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
 from airflow.providers.amazon.aws.utils import datetime_to_epoch_ms
+from airflow.providers.amazon.aws.utils.mixins import AwsBaseHookMixin, 
AwsHookParams, aws_template_fields
 
 if TYPE_CHECKING:
     from mypy_boto3_appflow.type_defs import (
@@ -42,9 +42,9 @@ MANDATORY_FILTER_DATE_MSG = "The filter_date argument is 
mandatory for {entity}!
 NOT_SUPPORTED_SOURCE_MSG = "Source {source} is not supported for {entity}!"
 
 
-class AppflowBaseOperator(BaseOperator):
+class AppflowBaseOperator(AwsBaseOperator[AppflowHook]):
     """
-    Amazon Appflow Base Operator class (not supposed to be used directly in 
DAGs).
+    Amazon AppFlow Base Operator class (not supposed to be used directly in 
DAGs).
 
     :param source: The source name (Supported: salesforce, zendesk)
     :param flow_name: The flow name
@@ -53,14 +53,22 @@ class AppflowBaseOperator(BaseOperator):
     :param filter_date: The date value (or template) to be used in filters.
     :param poll_interval: how often in seconds to check the query status
     :param max_attempts: how many times to check for status before timing out
-    :param aws_conn_id: aws connection to use
-    :param region: aws region to use
     :param wait_for_completion: whether to wait for the run to end to return
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. 
If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default 
boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore 
client. See:
+        
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
+    aws_hook_class = AppflowHook
     ui_color = "#2bccbd"
-
-    template_fields = ("flow_name", "source", "source_field", "filter_date")
+    template_fields = aws_template_fields("flow_name", "source", 
"source_field", "filter_date")
 
     UPDATE_PROPAGATION_TIME: int = 15
 
@@ -73,8 +81,6 @@ class AppflowBaseOperator(BaseOperator):
         filter_date: str | None = None,
         poll_interval: int = 20,
         max_attempts: int = 60,
-        aws_conn_id: str = "aws_default",
-        region: str | None = None,
         wait_for_completion: bool = True,
         **kwargs,
     ) -> None:
@@ -87,16 +93,9 @@ class AppflowBaseOperator(BaseOperator):
         self.source_field = source_field
         self.poll_interval = poll_interval
         self.max_attempts = max_attempts
-        self.aws_conn_id = aws_conn_id
-        self.region = region
         self.flow_update = flow_update
         self.wait_for_completion = wait_for_completion
 
-    @cached_property
-    def hook(self) -> AppflowHook:
-        """Create and return an AppflowHook."""
-        return AppflowHook(aws_conn_id=self.aws_conn_id, 
region_name=self.region)
-
     def execute(self, context: Context) -> None:
         self.filter_date_parsed: datetime | None = (
             datetime.fromisoformat(self.filter_date) if self.filter_date else 
None
@@ -135,7 +134,7 @@ class AppflowBaseOperator(BaseOperator):
 
 class AppflowRunOperator(AppflowBaseOperator):
     """
-    Execute a Appflow run as is.
+    Execute an AppFlow run as is.
 
     .. seealso::
         For more information on how to use this operator, take a look at the 
guide:
@@ -154,8 +153,6 @@ class AppflowRunOperator(AppflowBaseOperator):
         flow_name: str,
         source: str | None = None,
         poll_interval: int = 20,
-        aws_conn_id: str = "aws_default",
-        region: str | None = None,
         wait_for_completion: bool = True,
         **kwargs,
     ) -> None:
@@ -171,8 +168,6 @@ class AppflowRunOperator(AppflowBaseOperator):
             source_field=None,
             filter_date=None,
             poll_interval=poll_interval,
-            aws_conn_id=aws_conn_id,
-            region=region,
             wait_for_completion=wait_for_completion,
             **kwargs,
         )
@@ -180,7 +175,7 @@ class AppflowRunOperator(AppflowBaseOperator):
 
 class AppflowRunFullOperator(AppflowBaseOperator):
     """
-    Execute a Appflow full run removing any filter.
+    Execute an AppFlow full run removing any filter.
 
     .. seealso::
         For more information on how to use this operator, take a look at the 
guide:
@@ -189,8 +184,6 @@ class AppflowRunFullOperator(AppflowBaseOperator):
     :param source: The source name (Supported: salesforce, zendesk)
     :param flow_name: The flow name
     :param poll_interval: how often in seconds to check the query status
-    :param aws_conn_id: aws connection to use
-    :param region: aws region to use
     :param wait_for_completion: whether to wait for the run to end to return
     """
 
@@ -199,8 +192,6 @@ class AppflowRunFullOperator(AppflowBaseOperator):
         source: str,
         flow_name: str,
         poll_interval: int = 20,
-        aws_conn_id: str = "aws_default",
-        region: str | None = None,
         wait_for_completion: bool = True,
         **kwargs,
     ) -> None:
@@ -213,8 +204,6 @@ class AppflowRunFullOperator(AppflowBaseOperator):
             source_field=None,
             filter_date=None,
             poll_interval=poll_interval,
-            aws_conn_id=aws_conn_id,
-            region=region,
             wait_for_completion=wait_for_completion,
             **kwargs,
         )
@@ -222,7 +211,7 @@ class AppflowRunFullOperator(AppflowBaseOperator):
 
 class AppflowRunBeforeOperator(AppflowBaseOperator):
     """
-    Execute a Appflow run after updating the filters to select only previous 
data.
+    Execute an AppFlow run after updating the filters to select only previous 
data.
 
     .. seealso::
         For more information on how to use this operator, take a look at the 
guide:
@@ -245,8 +234,6 @@ class AppflowRunBeforeOperator(AppflowBaseOperator):
         source_field: str,
         filter_date: str,
         poll_interval: int = 20,
-        aws_conn_id: str = "aws_default",
-        region: str | None = None,
         wait_for_completion: bool = True,
         **kwargs,
     ) -> None:
@@ -263,8 +250,6 @@ class AppflowRunBeforeOperator(AppflowBaseOperator):
             source_field=source_field,
             filter_date=filter_date,
             poll_interval=poll_interval,
-            aws_conn_id=aws_conn_id,
-            region=region,
             wait_for_completion=wait_for_completion,
             **kwargs,
         )
@@ -290,7 +275,7 @@ class AppflowRunBeforeOperator(AppflowBaseOperator):
 
 class AppflowRunAfterOperator(AppflowBaseOperator):
     """
-    Execute a Appflow run after updating the filters to select only future 
data.
+    Execute an AppFlow run after updating the filters to select only future 
data.
 
     .. seealso::
         For more information on how to use this operator, take a look at the 
guide:
@@ -301,8 +286,6 @@ class AppflowRunAfterOperator(AppflowBaseOperator):
     :param source_field: The field name to apply filters
     :param filter_date: The date value (or template) to be used in filters.
     :param poll_interval: how often in seconds to check the query status
-    :param aws_conn_id: aws connection to use
-    :param region: aws region to use
     :param wait_for_completion: whether to wait for the run to end to return
     """
 
@@ -313,8 +296,6 @@ class AppflowRunAfterOperator(AppflowBaseOperator):
         source_field: str,
         filter_date: str,
         poll_interval: int = 20,
-        aws_conn_id: str = "aws_default",
-        region: str | None = None,
         wait_for_completion: bool = True,
         **kwargs,
     ) -> None:
@@ -329,8 +310,6 @@ class AppflowRunAfterOperator(AppflowBaseOperator):
             source_field=source_field,
             filter_date=filter_date,
             poll_interval=poll_interval,
-            aws_conn_id=aws_conn_id,
-            region=region,
             wait_for_completion=wait_for_completion,
             **kwargs,
         )
@@ -356,7 +335,7 @@ class AppflowRunAfterOperator(AppflowBaseOperator):
 
 class AppflowRunDailyOperator(AppflowBaseOperator):
     """
-    Execute a Appflow run after updating the filters to select only a single 
day.
+    Execute an AppFlow run after updating the filters to select only a single 
day.
 
     .. seealso::
         For more information on how to use this operator, take a look at the 
guide:
@@ -367,8 +346,6 @@ class AppflowRunDailyOperator(AppflowBaseOperator):
     :param source_field: The field name to apply filters
     :param filter_date: The date value (or template) to be used in filters.
     :param poll_interval: how often in seconds to check the query status
-    :param aws_conn_id: aws connection to use
-    :param region: aws region to use
     :param wait_for_completion: whether to wait for the run to end to return
     """
 
@@ -379,8 +356,6 @@ class AppflowRunDailyOperator(AppflowBaseOperator):
         source_field: str,
         filter_date: str,
         poll_interval: int = 20,
-        aws_conn_id: str = "aws_default",
-        region: str | None = None,
         wait_for_completion: bool = True,
         **kwargs,
     ) -> None:
@@ -395,8 +370,6 @@ class AppflowRunDailyOperator(AppflowBaseOperator):
             source_field=source_field,
             filter_date=filter_date,
             poll_interval=poll_interval,
-            aws_conn_id=aws_conn_id,
-            region=region,
             wait_for_completion=wait_for_completion,
             **kwargs,
         )
@@ -423,9 +396,9 @@ class AppflowRunDailyOperator(AppflowBaseOperator):
         )
 
 
-class AppflowRecordsShortCircuitOperator(ShortCircuitOperator):
+class AppflowRecordsShortCircuitOperator(ShortCircuitOperator, 
AwsBaseHookMixin[AppflowHook]):
     """
-    Short-circuit in case of a empty Appflow's run.
+    Short-circuit in case of an empty AppFlow's run.
 
     .. seealso::
         For more information on how to use this operator, take a look at the 
guide:
@@ -434,10 +407,20 @@ class 
AppflowRecordsShortCircuitOperator(ShortCircuitOperator):
     :param flow_name: The flow name
     :param appflow_run_task_id: Run task ID from where this operator should 
extract the execution ID
     :param ignore_downstream_trigger_rules: Ignore downstream trigger rules
-    :param aws_conn_id: aws connection to use
-    :param region: aws region to use
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. 
If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default 
boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore 
client. See:
+        
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
+    aws_hook_class = AppflowHook
+    template_fields = aws_template_fields()
     ui_color = "#33ffec"  # Light blue
 
     def __init__(
@@ -446,10 +429,15 @@ class 
AppflowRecordsShortCircuitOperator(ShortCircuitOperator):
         flow_name: str,
         appflow_run_task_id: str,
         ignore_downstream_trigger_rules: bool = True,
-        aws_conn_id: str = "aws_default",
-        region: str | None = None,
+        aws_conn_id: str | None = "aws_default",
+        region_name: str | None = None,
+        verify: bool | str | None = None,
+        botocore_config: dict | None = None,
         **kwargs,
     ) -> None:
+        hook_params = AwsHookParams.from_constructor(
+            aws_conn_id, region_name, verify, botocore_config, 
additional_params=kwargs
+        )
         super().__init__(
             python_callable=self._has_new_records_func,
             op_kwargs={
@@ -459,8 +447,11 @@ class 
AppflowRecordsShortCircuitOperator(ShortCircuitOperator):
             ignore_downstream_trigger_rules=ignore_downstream_trigger_rules,
             **kwargs,
         )
-        self.aws_conn_id = aws_conn_id
-        self.region = region
+        self.aws_conn_id = hook_params.aws_conn_id
+        self.region_name = hook_params.region_name
+        self.verify = hook_params.verify
+        self.botocore_config = hook_params.botocore_config
+        self.validate_attributes()
 
     @staticmethod
     def _get_target_execution_id(
@@ -471,11 +462,6 @@ class 
AppflowRecordsShortCircuitOperator(ShortCircuitOperator):
                 return record
         return None
 
-    @cached_property
-    def hook(self) -> AppflowHook:
-        """Create and return an AppflowHook."""
-        return AppflowHook(aws_conn_id=self.aws_conn_id, 
region_name=self.region)
-
     def _has_new_records_func(self, **kwargs) -> bool:
         appflow_task_id = kwargs["appflow_run_task_id"]
         self.log.info("appflow_task_id: %s", appflow_task_id)
diff --git a/airflow/providers/amazon/aws/operators/base_aws.py 
b/airflow/providers/amazon/aws/operators/base_aws.py
index aaa5059afb..8cc1b35302 100644
--- a/airflow/providers/amazon/aws/operators/base_aws.py
+++ b/airflow/providers/amazon/aws/operators/base_aws.py
@@ -64,7 +64,7 @@ class AwsBaseOperator(BaseOperator, 
AwsBaseHookMixin[AwsHookType]):
                 pass
 
     :param aws_conn_id: The Airflow connection used for AWS credentials.
-        If this is None or empty then the default boto3 behaviour is used. If
+        If this is ``None`` or empty then the default boto3 behaviour is used. 
If
         running Airflow in a distributed manner and aws_conn_id is None or
         empty, then default boto3 configuration would be used (and must be
         maintained on each worker node).
diff --git a/docs/apache-airflow-providers-amazon/operators/appflow.rst 
b/docs/apache-airflow-providers-amazon/operators/appflow.rst
index 14d3a5e5ca..c28cfb001b 100644
--- a/docs/apache-airflow-providers-amazon/operators/appflow.rst
+++ b/docs/apache-airflow-providers-amazon/operators/appflow.rst
@@ -35,6 +35,11 @@ Prerequisite Tasks
 
 .. include:: ../_partials/prerequisite_tasks.rst
 
+Generic Parameters
+------------------
+
+.. include:: ../_partials/generic_parameters.rst
+
 Operators
 ---------
 
diff --git a/tests/providers/amazon/aws/operators/test_appflow.py 
b/tests/providers/amazon/aws/operators/test_appflow.py
index ce36faad5f..dc3c2fddca 100644
--- a/tests/providers/amazon/aws/operators/test_appflow.py
+++ b/tests/providers/amazon/aws/operators/test_appflow.py
@@ -196,3 +196,57 @@ def test_short_circuit(appflow_conn, ctx):
                 flowName=FLOW_NAME, maxResults=100
             )
             mock_xcom_push.assert_called_with("records_processed", 1)
+
+
+@pytest.mark.parametrize(
+    "op_class, op_base_args",
+    [
+        pytest.param(
+            AppflowRunAfterOperator,
+            dict(**DUMP_COMMON_ARGS, source_field="col0", 
filter_date="2022-05-26T00:00+00:00"),
+            id="run-after-op",
+        ),
+        pytest.param(
+            AppflowRunBeforeOperator,
+            dict(**DUMP_COMMON_ARGS, source_field="col1", 
filter_date="2077-10-23T00:03+00:00"),
+            id="run-before-op",
+        ),
+        pytest.param(
+            AppflowRunDailyOperator,
+            dict(**DUMP_COMMON_ARGS, source_field="col2", 
filter_date="2023-10-20T12:22+00:00"),
+            id="run-daily-op",
+        ),
+        pytest.param(AppflowRunFullOperator, DUMP_COMMON_ARGS, 
id="run-full-op"),
+        pytest.param(AppflowRunOperator, DUMP_COMMON_ARGS, id="run-op"),
+        pytest.param(
+            AppflowRecordsShortCircuitOperator,
+            dict(task_id=SHORT_CIRCUIT_TASK_ID, flow_name=FLOW_NAME, 
appflow_run_task_id=TASK_ID),
+            id="records-short-circuit",
+        ),
+    ],
+)
+def test_base_aws_op_attributes(op_class, op_base_args):
+    op = op_class(**op_base_args)
+    hook = op.hook
+    assert hook is op.hook
+    assert hook.aws_conn_id == CONN_ID
+    assert hook._region_name is None
+    assert hook._verify is None
+    assert hook._config is None
+
+    op = op_class(**op_base_args, region_name="eu-west-1", verify=False, 
botocore_config={"read_timeout": 42})
+    hook = op.hook
+    assert hook is op.hook
+    assert hook.aws_conn_id == CONN_ID
+    assert hook._region_name == "eu-west-1"
+    assert hook._verify is False
+    assert hook._config.read_timeout == 42
+
+    # Compatibility check: previously Appflow Operators use `region` instead 
of `region_name`
+    warning_message = "`region` is deprecated and will be removed in the 
future"
+    with pytest.warns(DeprecationWarning, match=warning_message):
+        op = op_class(**op_base_args, region="us-west-1")
+    assert op.region_name == "us-west-1"
+
+    with pytest.warns(DeprecationWarning, match=warning_message):
+        assert op.region == "us-west-1"

Reply via email to the commits mailing list.