This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 8ed38c1619 Add Appflow system test + improvements (#33614)
8ed38c1619 is described below

commit 8ed38c1619209aaae3cf900ed34d9a3b48a8bf8d
Author: Raphaël Vandon <[email protected]>
AuthorDate: Tue Aug 22 13:27:47 2023 -0700

    Add Appflow system test + improvements (#33614)
---
 airflow/providers/amazon/aws/hooks/appflow.py      |  24 ++-
 airflow/providers/amazon/aws/operators/appflow.py  |  37 ++--
 .../operators/appflow.rst                          |   4 +-
 .../providers/amazon/aws/operators/test_appflow.py |   3 +-
 .../system/providers/amazon/aws/example_appflow.py |  10 -
 .../providers/amazon/aws/example_appflow_run.py    | 204 +++++++++++++++++++++
 6 files changed, 248 insertions(+), 34 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/appflow.py 
b/airflow/providers/amazon/aws/hooks/appflow.py
index 71e7ddd1c7..c9741278c0 100644
--- a/airflow/providers/amazon/aws/hooks/appflow.py
+++ b/airflow/providers/amazon/aws/hooks/appflow.py
@@ -20,6 +20,7 @@ from functools import cached_property
 from typing import TYPE_CHECKING
 
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
 
 if TYPE_CHECKING:
     from mypy_boto3_appflow.client import AppflowClient
@@ -49,13 +50,20 @@ class AppflowHook(AwsBaseHook):
         """Get the underlying boto3 Appflow client (cached)."""
         return super().conn
 
-    def run_flow(self, flow_name: str, poll_interval: int = 20, 
wait_for_completion: bool = True) -> str:
+    def run_flow(
+        self,
+        flow_name: str,
+        poll_interval: int = 20,
+        wait_for_completion: bool = True,
+        max_attempts: int = 60,
+    ) -> str:
         """
         Execute an AppFlow run.
 
         :param flow_name: The flow name
         :param poll_interval: Time (seconds) to wait between two consecutive 
calls to check the run status
         :param wait_for_completion: whether to wait for the run to end to 
return
+        :param max_attempts: the number of polls to do before timing 
out/returning a failure.
         :return: The run execution ID
         """
         response_start = self.conn.start_flow(flowName=flow_name)
@@ -63,9 +71,17 @@ class AppflowHook(AwsBaseHook):
         self.log.info("executionId: %s", execution_id)
 
         if wait_for_completion:
-            self.get_waiter("run_complete", {"EXECUTION_ID": 
execution_id}).wait(
-                flowName=flow_name,
-                WaiterConfig={"Delay": poll_interval},
+            wait(
+                waiter=self.get_waiter("run_complete", {"EXECUTION_ID": 
execution_id}),
+                waiter_delay=poll_interval,
+                waiter_max_attempts=max_attempts,
+                args={"flowName": flow_name},
+                failure_message="error while waiting for flow to complete",
+                status_message="waiting for flow completion, status",
+                status_args=[
+                    
f"flowExecutions[?executionId=='{execution_id}'].executionStatus",
+                    
f"flowExecutions[?executionId=='{execution_id}'].executionResult.errorInfo",
+                ],
             )
             self._log_execution_description(flow_name, execution_id)
 
diff --git a/airflow/providers/amazon/aws/operators/appflow.py 
b/airflow/providers/amazon/aws/operators/appflow.py
index f2fe75f395..7b7402acde 100644
--- a/airflow/providers/amazon/aws/operators/appflow.py
+++ b/airflow/providers/amazon/aws/operators/appflow.py
@@ -16,12 +16,13 @@
 # under the License.
 from __future__ import annotations
 
+import warnings
 from datetime import datetime, timedelta
 from functools import cached_property
 from time import sleep
 from typing import TYPE_CHECKING, cast
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, 
AirflowProviderDeprecationWarning
 from airflow.models import BaseOperator
 from airflow.operators.python import ShortCircuitOperator
 from airflow.providers.amazon.aws.hooks.appflow import AppflowHook
@@ -51,6 +52,7 @@ class AppflowBaseOperator(BaseOperator):
     :param source_field: The field name to apply filters
     :param filter_date: The date value (or template) to be used in filters.
     :param poll_interval: how often in seconds to check the query status
+    :param max_attempts: how many times to check for status before timing out
     :param aws_conn_id: aws connection to use
     :param region: aws region to use
     :param wait_for_completion: whether to wait for the run to end to return
@@ -58,29 +60,33 @@ class AppflowBaseOperator(BaseOperator):
 
     ui_color = "#2bccbd"
 
+    template_fields = ("flow_name", "source", "source_field", "filter_date")
+
     UPDATE_PROPAGATION_TIME: int = 15
 
     def __init__(
         self,
-        source: str,
         flow_name: str,
         flow_update: bool,
+        source: str | None = None,
         source_field: str | None = None,
         filter_date: str | None = None,
         poll_interval: int = 20,
+        max_attempts: int = 60,
         aws_conn_id: str = "aws_default",
         region: str | None = None,
         wait_for_completion: bool = True,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
-        if source not in SUPPORTED_SOURCES:
+        if source is not None and source not in SUPPORTED_SOURCES:
             raise ValueError(f"{source} is not a supported source (options: 
{SUPPORTED_SOURCES})!")
         self.filter_date = filter_date
         self.flow_name = flow_name
         self.source = source
         self.source_field = source_field
         self.poll_interval = poll_interval
+        self.max_attempts = max_attempts
         self.aws_conn_id = aws_conn_id
         self.region = region
         self.flow_update = flow_update
@@ -95,7 +101,8 @@ class AppflowBaseOperator(BaseOperator):
         self.filter_date_parsed: datetime | None = (
             datetime.fromisoformat(self.filter_date) if self.filter_date else 
None
         )
-        self.connector_type = self._get_connector_type()
+        if self.source is not None:
+            self.connector_type = self._get_connector_type()
         if self.flow_update:
             self._update_flow()
             # while schedule flows will pick up the update right away, 
on-demand flows might use out of date
@@ -118,6 +125,7 @@ class AppflowBaseOperator(BaseOperator):
         execution_id = self.hook.run_flow(
             flow_name=self.flow_name,
             poll_interval=self.poll_interval,
+            max_attempts=self.max_attempts,
             wait_for_completion=self.wait_for_completion,
         )
         task_instance = context["task_instance"]
@@ -127,13 +135,13 @@ class AppflowBaseOperator(BaseOperator):
 
 class AppflowRunOperator(AppflowBaseOperator):
     """
-    Execute a Appflow run with filters as is.
+    Execute a Appflow run as is.
 
     .. seealso::
         For more information on how to use this operator, take a look at the 
guide:
         :ref:`howto/operator:AppflowRunOperator`
 
-    :param source: The source name (Supported: salesforce, zendesk)
+    :param source: Obsolete, unnecessary for this operator
     :param flow_name: The flow name
     :param poll_interval: how often in seconds to check the query status
     :param aws_conn_id: aws connection to use
@@ -143,18 +151,21 @@ class AppflowRunOperator(AppflowBaseOperator):
 
     def __init__(
         self,
-        source: str,
         flow_name: str,
+        source: str | None = None,
         poll_interval: int = 20,
         aws_conn_id: str = "aws_default",
         region: str | None = None,
         wait_for_completion: bool = True,
         **kwargs,
     ) -> None:
-        if source not in {"salesforce", "zendesk"}:
-            raise ValueError(NOT_SUPPORTED_SOURCE_MSG.format(source=source, 
entity="AppflowRunOperator"))
+        if source is not None:
+            warnings.warn(
+                "The `source` parameter is unused when simply running a flow, 
please remove it.",
+                AirflowProviderDeprecationWarning,
+                stacklevel=2,
+            )
         super().__init__(
-            source=source,
             flow_name=flow_name,
             flow_update=False,
             source_field=None,
@@ -227,8 +238,6 @@ class AppflowRunBeforeOperator(AppflowBaseOperator):
     :param wait_for_completion: whether to wait for the run to end to return
     """
 
-    template_fields = ("filter_date",)
-
     def __init__(
         self,
         source: str,
@@ -297,8 +306,6 @@ class AppflowRunAfterOperator(AppflowBaseOperator):
     :param wait_for_completion: whether to wait for the run to end to return
     """
 
-    template_fields = ("filter_date",)
-
     def __init__(
         self,
         source: str,
@@ -365,8 +372,6 @@ class AppflowRunDailyOperator(AppflowBaseOperator):
     :param wait_for_completion: whether to wait for the run to end to return
     """
 
-    template_fields = ("filter_date",)
-
     def __init__(
         self,
         source: str,
diff --git a/docs/apache-airflow-providers-amazon/operators/appflow.rst 
b/docs/apache-airflow-providers-amazon/operators/appflow.rst
index 8ef7e9c288..14d3a5e5ca 100644
--- a/docs/apache-airflow-providers-amazon/operators/appflow.rst
+++ b/docs/apache-airflow-providers-amazon/operators/appflow.rst
@@ -41,10 +41,10 @@ Operators
 Run Flow
 ========
 
-To run an AppFlow flow keeping all filters as is, use:
+To run an AppFlow flow as is, use:
 :class:`~airflow.providers.amazon.aws.operators.appflow.AppflowRunOperator`.
 
-.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_appflow.py
+.. exampleinclude:: 
/../../tests/system/providers/amazon/aws/example_appflow_run.py
     :language: python
     :dedent: 4
     :start-after: [START howto_operator_appflow_run]
diff --git a/tests/providers/amazon/aws/operators/test_appflow.py 
b/tests/providers/amazon/aws/operators/test_appflow.py
index 2308810662..ce36faad5f 100644
--- a/tests/providers/amazon/aws/operators/test_appflow.py
+++ b/tests/providers/amazon/aws/operators/test_appflow.py
@@ -114,9 +114,8 @@ def run_assertions_base(appflow_conn, tasks):
 def test_run(appflow_conn, ctx, waiter_mock):
     operator = AppflowRunOperator(**DUMP_COMMON_ARGS)
     operator.execute(ctx)  # type: ignore
-    appflow_conn.describe_flow.assert_called_once_with(flowName=FLOW_NAME)
-    appflow_conn.describe_flow_execution_records.assert_called_once()
     appflow_conn.start_flow.assert_called_once_with(flowName=FLOW_NAME)
+    appflow_conn.describe_flow_execution_records.assert_called_once()
 
 
 def test_run_full(appflow_conn, ctx, waiter_mock):
diff --git a/tests/system/providers/amazon/aws/example_appflow.py 
b/tests/system/providers/amazon/aws/example_appflow.py
index 4469c0290b..8a6458aa56 100644
--- a/tests/system/providers/amazon/aws/example_appflow.py
+++ b/tests/system/providers/amazon/aws/example_appflow.py
@@ -27,7 +27,6 @@ from airflow.providers.amazon.aws.operators.appflow import (
     AppflowRunBeforeOperator,
     AppflowRunDailyOperator,
     AppflowRunFullOperator,
-    AppflowRunOperator,
 )
 from tests.system.providers.amazon.aws.utils import SystemTestContextBuilder
 
@@ -48,14 +47,6 @@ with DAG(
     source_name = "salesforce"
     flow_name = f"{env_id}-salesforce-campaign"
 
-    # [START howto_operator_appflow_run]
-    campaign_dump = AppflowRunOperator(
-        task_id="campaign_dump",
-        source=source_name,
-        flow_name=flow_name,
-    )
-    # [END howto_operator_appflow_run]
-
     # [START howto_operator_appflow_run_full]
     campaign_dump_full = AppflowRunFullOperator(
         task_id="campaign_dump_full",
@@ -111,7 +102,6 @@ with DAG(
         # TEST SETUP
         test_context,
         # TEST BODY
-        campaign_dump,
         campaign_dump_full,
         campaign_dump_daily,
         campaign_dump_before,
diff --git a/tests/system/providers/amazon/aws/example_appflow_run.py 
b/tests/system/providers/amazon/aws/example_appflow_run.py
new file mode 100644
index 0000000000..dacc549815
--- /dev/null
+++ b/tests/system/providers/amazon/aws/example_appflow_run.py
@@ -0,0 +1,204 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import json
+from datetime import datetime
+
+import boto3
+
+from airflow import DAG
+from airflow.decorators import task
+from airflow.models.baseoperator import chain
+from airflow.providers.amazon.aws.operators.appflow import (
+    AppflowRunOperator,
+)
+from airflow.providers.amazon.aws.operators.s3 import (
+    S3CreateBucketOperator,
+    S3CreateObjectOperator,
+    S3DeleteBucketOperator,
+)
+from airflow.utils.trigger_rule import TriggerRule
+from tests.system.providers.amazon.aws.utils import SystemTestContextBuilder
+
+sys_test_context_task = SystemTestContextBuilder().build()
+
+DAG_ID = "example_appflow_run"
+
+
+@task
+def create_s3_to_s3_flow(flow_name: str, bucket_name: str, source_folder: str):
+    """creates a flow that takes a CSV and converts it to a json containing 
the same data"""
+    client = boto3.client("appflow")
+    client.create_flow(
+        flowName=flow_name,
+        triggerConfig={"triggerType": "OnDemand"},
+        sourceFlowConfig={
+            "connectorType": "S3",
+            "sourceConnectorProperties": {
+                "S3": {
+                    "bucketName": bucket_name,
+                    "bucketPrefix": source_folder,
+                    "s3InputFormatConfig": {"s3InputFileType": "CSV"},
+                },
+            },
+        },
+        destinationFlowConfigList=[
+            {
+                "connectorType": "S3",
+                "destinationConnectorProperties": {
+                    "S3": {
+                        "bucketName": bucket_name,
+                        "s3OutputFormatConfig": {
+                            "fileType": "JSON",
+                            "aggregationConfig": {
+                                "aggregationType": "None",
+                            },
+                        },
+                    }
+                },
+            },
+        ],
+        tasks=[
+            {
+                "sourceFields": ["col1", "col2"],
+                "connectorOperator": {"S3": "PROJECTION"},
+                "taskType": "Filter",
+            },
+            {
+                "sourceFields": ["col1"],
+                "connectorOperator": {"S3": "NO_OP"},
+                "destinationField": "col1",
+                "taskType": "Map",
+                "taskProperties": {"DESTINATION_DATA_TYPE": "string", 
"SOURCE_DATA_TYPE": "string"},
+            },
+            {
+                "sourceFields": ["col2"],
+                "connectorOperator": {"S3": "NO_OP"},
+                "destinationField": "col2",
+                "taskType": "Map",
+                "taskProperties": {"DESTINATION_DATA_TYPE": "string", 
"SOURCE_DATA_TYPE": "string"},
+            },
+        ],
+    )
+
+
+@task
+def setup_bucket_permissions(bucket_name):
+    s3 = boto3.client("s3")
+    s3.put_bucket_policy(
+        Bucket=bucket_name,
+        Policy=json.dumps(
+            {
+                "Version": "2008-10-17",
+                "Statement": [
+                    {
+                        "Sid": "AllowAppFlowSourceActions",
+                        "Effect": "Allow",
+                        "Principal": {"Service": "appflow.amazonaws.com"},
+                        "Action": ["s3:ListBucket", "s3:GetObject"],
+                        "Resource": [f"arn:aws:s3:::{bucket_name}", 
f"arn:aws:s3:::{bucket_name}/*"],
+                    },
+                    {
+                        "Sid": "AllowAppFlowDestinationActions",
+                        "Effect": "Allow",
+                        "Principal": {"Service": "appflow.amazonaws.com"},
+                        "Action": [
+                            "s3:PutObject",
+                            "s3:AbortMultipartUpload",
+                            "s3:ListMultipartUploadParts",
+                            "s3:ListBucketMultipartUploads",
+                            "s3:GetBucketAcl",
+                            "s3:PutObjectAcl",
+                        ],
+                        "Resource": [f"arn:aws:s3:::{bucket_name}", 
f"arn:aws:s3:::{bucket_name}/*"],
+                    },
+                ],
+            }
+        ),
+    )
+
+
+@task(trigger_rule=TriggerRule.ALL_DONE)
+def delete_flow(flow_name: str):
+    client = boto3.client("appflow")
+    client.delete_flow(flowName=flow_name, forceDelete=True)
+
+
+with DAG(
+    dag_id=DAG_ID,
+    schedule="@once",
+    start_date=datetime(2022, 1, 1),
+    catchup=False,
+    tags=["example"],
+) as dag:
+    test_context = sys_test_context_task()
+    env_id = test_context["ENV_ID"]
+
+    flow_name = f"{env_id}-flow"
+    bucket_name = f"{env_id}-for-appflow"
+    source_folder = "source"
+
+    create_bucket = S3CreateBucketOperator(task_id="create_bucket", 
bucket_name=bucket_name)
+
+    upload_csv = S3CreateObjectOperator(
+        task_id="upload_csv",
+        s3_bucket=bucket_name,
+        s3_key="source_folder/data.csv",
+        data="""col1,col2\n"data1","data2"\n""",
+        replace=True,
+    )
+
+    # [START howto_operator_appflow_run]
+    run_flow = AppflowRunOperator(
+        task_id="run_flow",
+        flow_name=flow_name,
+    )
+    # [END howto_operator_appflow_run]
+    run_flow.poll_interval = 1
+
+    delete_bucket = S3DeleteBucketOperator(
+        task_id="delete_bucket",
+        trigger_rule=TriggerRule.ALL_DONE,
+        bucket_name=bucket_name,
+        force_delete=True,
+    )
+
+    chain(
+        # TEST SETUP
+        test_context,
+        create_bucket,
+        setup_bucket_permissions(bucket_name),
+        upload_csv,
+        create_s3_to_s3_flow(flow_name, bucket_name, source_folder),
+        # TEST BODY
+        run_flow,
+        # TEARDOWN
+        delete_flow(flow_name),
+        delete_bucket,
+    )
+
+    from tests.system.utils.watcher import watcher
+
+    # This test needs watcher in order to properly mark success/failure
+    # when "tearDown" task with trigger rule is part of the DAG
+    list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: 
tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)

Reply via email to