This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 05c0841880 custom waiters with dynamic values, applied to appflow 
(#29911)
05c0841880 is described below

commit 05c0841880ccfc25c9e525cafde3e46d7c6f9fce
Author: Raphaël Vandon <[email protected]>
AuthorDate: Tue Mar 21 06:50:39 2023 -0700

    custom waiters with dynamic values, applied to appflow (#29911)
---
 airflow/providers/amazon/aws/hooks/appflow.py      | 28 +++++-------
 airflow/providers/amazon/aws/hooks/base_aws.py     | 28 ++++++++++--
 .../providers/amazon/aws/hooks/batch_waiters.py    |  4 +-
 airflow/providers/amazon/aws/waiters/appflow.json  | 30 ++++++++++++
 tests/providers/amazon/aws/hooks/test_appflow.py   | 52 ++++++++++-----------
 tests/providers/amazon/aws/hooks/test_base_aws.py  | 53 +++++++++++++++++++++-
 .../providers/amazon/aws/operators/test_appflow.py | 16 +++++--
 tests/providers/amazon/aws/waiters/test.json       | 44 ++++++++++++++++++
 8 files changed, 201 insertions(+), 54 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/appflow.py 
b/airflow/providers/amazon/aws/hooks/appflow.py
index 10de8c0030..14dee2ef10 100644
--- a/airflow/providers/amazon/aws/hooks/appflow.py
+++ b/airflow/providers/amazon/aws/hooks/appflow.py
@@ -16,8 +16,6 @@
 # under the License.
 from __future__ import annotations
 
-import json
-from time import sleep
 from typing import TYPE_CHECKING
 
 from airflow.compat.functools import cached_property
@@ -64,24 +62,20 @@ class AppflowHook(AwsBaseHook):
         self.log.info("executionId: %s", execution_id)
 
         if wait_for_completion:
-            last_execs: dict = {}
-            self.log.info("Waiting for flow run to complete...")
-            while (
-                execution_id not in last_execs or 
last_execs[execution_id]["executionStatus"] == "InProgress"
-            ):
-                sleep(poll_interval)
-                # queries the last 20 runs, which should contain ours.
-                response_desc = 
self.conn.describe_flow_execution_records(flowName=flow_name)
-                last_execs = {fe["executionId"]: fe for fe in 
response_desc["flowExecutions"]}
-
-            exec_details = last_execs[execution_id]
-            self.log.info("Run complete, execution details: %s", exec_details)
-
-            if exec_details["executionStatus"] == "Error":
-                raise Exception(f"Flow error:\n{json.dumps(exec_details, 
default=str)}")
+            self.get_waiter("run_complete", {"EXECUTION_ID": 
execution_id}).wait(
+                flowName=flow_name,
+                WaiterConfig={"Delay": poll_interval},
+            )
+            self._log_execution_description(flow_name, execution_id)
 
         return execution_id
 
+    def _log_execution_description(self, flow_name: str, execution_id: str):
+        response_desc = 
self.conn.describe_flow_execution_records(flowName=flow_name)
+        last_execs = {fe["executionId"]: fe for fe in 
response_desc["flowExecutions"]}
+        exec_details = last_execs[execution_id]
+        self.log.info("Run complete, execution details: %s", exec_details)
+
     def update_flow_filter(
         self, flow_name: str, filter_tasks: list[TaskTypeDef], 
set_trigger_ondemand: bool = False
     ) -> None:
diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py 
b/airflow/providers/amazon/aws/hooks/base_aws.py
index d797653476..3395990fc3 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -40,6 +40,7 @@ from typing import TYPE_CHECKING, Any, Callable, Generic, 
TypeVar, Union
 import boto3
 import botocore
 import botocore.session
+import jinja2
 import requests
 import tenacity
 from botocore.client import ClientMeta
@@ -796,7 +797,7 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
         path = 
Path(__file__).parents[1].joinpath(f"waiters/{self.client_type}.json").resolve()
         return path if path.exists() else None
 
-    def get_waiter(self, waiter_name: str) -> Waiter:
+    def get_waiter(self, waiter_name: str, parameters: dict[str, str] | None = 
None) -> Waiter:
         """
         First checks if there is a custom waiter with the provided waiter_name 
and
         uses that if it exists, otherwise it will check the service client for 
a
@@ -804,17 +805,38 @@ class AwsGenericHook(BaseHook, 
Generic[BaseAwsConnection]):
 
         :param waiter_name: The name of the waiter.  The name should exactly 
match the
             name of the key in the waiter model file (typically this is 
CamelCase).
+        :param parameters: will scan the waiter config for the keys of that 
dict, and replace them with the
+            corresponding value. If a custom waiter has such keys to be 
expanded, they need to be provided
+            here.
         """
         if self.waiter_path and (waiter_name in self._list_custom_waiters()):
             # Technically if waiter_name is in custom_waiters then 
self.waiter_path must
             # exist but MyPy doesn't like the fact that self.waiter_path could 
be None.
             with open(self.waiter_path) as config_file:
-                config = json.load(config_file)
-                return BaseBotoWaiter(client=self.conn, 
model_config=config).waiter(waiter_name)
+                config = json.loads(config_file.read())
+
+            config = self._apply_parameters_value(config, waiter_name, 
parameters)
+            return BaseBotoWaiter(client=self.conn, 
model_config=config).waiter(waiter_name)
         # If there is no custom waiter found for the provided name,
         # then try checking the service's official waiters.
         return self.conn.get_waiter(waiter_name)
 
+    @staticmethod
+    def _apply_parameters_value(config: dict, waiter_name: str, parameters: 
dict[str, str] | None) -> dict:
+        """Replaces potential jinja templates in acceptors definition"""
+        # only process the waiter we're going to use to not raise errors for 
missing params for other waiters.
+        acceptors = config["waiters"][waiter_name]["acceptors"]
+        for a in acceptors:
+            arg = a["argument"]
+            template = jinja2.Template(arg, autoescape=False, 
undefined=jinja2.StrictUndefined)
+            try:
+                a["argument"] = template.render(parameters or {})
+            except jinja2.UndefinedError as e:
+                raise AirflowException(
+                    f"Parameter was not supplied for templated waiter's 
acceptor '{arg}'", e
+                )
+        return config
+
     def list_waiters(self) -> list[str]:
         """Returns a list containing the names of all waiters for the service, 
official and custom."""
         return [*self._list_official_waiters(), *self._list_custom_waiters()]
diff --git a/airflow/providers/amazon/aws/hooks/batch_waiters.py 
b/airflow/providers/amazon/aws/hooks/batch_waiters.py
index 0bbb982e41..dcf111591c 100644
--- a/airflow/providers/amazon/aws/hooks/batch_waiters.py
+++ b/airflow/providers/amazon/aws/hooks/batch_waiters.py
@@ -138,7 +138,7 @@ class BatchWaitersHook(BatchClientHook):
         """
         return self._waiter_model
 
-    def get_waiter(self, waiter_name: str) -> botocore.waiter.Waiter:
+    def get_waiter(self, waiter_name: str, _: dict[str, str] | None = None) -> 
botocore.waiter.Waiter:
         """
         Get an AWS Batch service waiter, using the configured 
``.waiter_model``.
 
@@ -168,6 +168,8 @@ class BatchWaitersHook(BatchClientHook):
             the name (including the casing) of the key name in the waiter
             model file (typically this is CamelCasing); see ``.list_waiters``.
 
+        :param _: unused, just here to match the method signature in base_aws
+
         :return: a waiter object for the named AWS Batch service
         """
         return botocore.waiter.create_waiter_with_client(waiter_name, 
self.waiter_model, self.client)
diff --git a/airflow/providers/amazon/aws/waiters/appflow.json 
b/airflow/providers/amazon/aws/waiters/appflow.json
new file mode 100644
index 0000000000..f45c427467
--- /dev/null
+++ b/airflow/providers/amazon/aws/waiters/appflow.json
@@ -0,0 +1,30 @@
+{
+    "version": 2,
+    "waiters": {
+        "run_complete": {
+            "operation": "DescribeFlowExecutionRecords",
+            "delay": 15,
+            "maxAttempts": 60,
+            "acceptors": [
+                {
+                    "expected": "Successful",
+                    "matcher": "path",
+                    "state": "success",
+                    "argument": 
"flowExecutions[?executionId=='{{EXECUTION_ID}}'].executionStatus"
+                },
+                {
+                    "expected": "Error",
+                    "matcher": "path",
+                    "state": "failure",
+                    "argument": 
"flowExecutions[?executionId=='{{EXECUTION_ID}}'].executionStatus"
+                },
+                {
+                    "expected": true,
+                    "matcher": "path",
+                    "state": "failure",
+                    "argument": 
"length(flowExecutions[?executionId=='{{EXECUTION_ID}}']) > `1`"
+                }
+            ]
+        }
+    }
+}
diff --git a/tests/providers/amazon/aws/hooks/test_appflow.py 
b/tests/providers/amazon/aws/hooks/test_appflow.py
index bbc587f79c..71949ae4b2 100644
--- a/tests/providers/amazon/aws/hooks/test_appflow.py
+++ b/tests/providers/amazon/aws/hooks/test_appflow.py
@@ -35,31 +35,30 @@ AWS_CONN_ID = "aws_default"
 
 @pytest.fixture
 def hook():
-    with 
mock.patch("airflow.providers.amazon.aws.hooks.appflow.AppflowHook.__init__", 
return_value=None):
-        with 
mock.patch("airflow.providers.amazon.aws.hooks.appflow.AppflowHook.conn") as 
mock_conn:
-            mock_conn.describe_flow.return_value = {
-                "sourceFlowConfig": {"connectorType": CONNECTION_TYPE},
-                "tasks": [],
-                "triggerConfig": {"triggerProperties": None},
-                "flowName": FLOW_NAME,
-                "destinationFlowConfigList": {},
-                "lastRunExecutionDetails": {
-                    "mostRecentExecutionStatus": "Successful",
-                    "mostRecentExecutionTime": datetime(3000, 1, 1, 
tzinfo=timezone.utc),
-                },
-            }
-            mock_conn.update_flow.return_value = {}
-            mock_conn.start_flow.return_value = {"executionId": EXECUTION_ID}
-            mock_conn.describe_flow_execution_records.return_value = {
-                "flowExecutions": [
-                    {
-                        "executionId": EXECUTION_ID,
-                        "executionResult": {"recordsProcessed": 1},
-                        "executionStatus": "Successful",
-                    }
-                ]
-            }
-            yield AppflowHook(aws_conn_id=AWS_CONN_ID, region_name=REGION_NAME)
+    with 
mock.patch("airflow.providers.amazon.aws.hooks.appflow.AppflowHook.conn") as 
mock_conn:
+        mock_conn.describe_flow.return_value = {
+            "sourceFlowConfig": {"connectorType": CONNECTION_TYPE},
+            "tasks": [],
+            "triggerConfig": {"triggerProperties": None},
+            "flowName": FLOW_NAME,
+            "destinationFlowConfigList": {},
+            "lastRunExecutionDetails": {
+                "mostRecentExecutionStatus": "Successful",
+                "mostRecentExecutionTime": datetime(3000, 1, 1, 
tzinfo=timezone.utc),
+            },
+        }
+        mock_conn.update_flow.return_value = {}
+        mock_conn.start_flow.return_value = {"executionId": EXECUTION_ID}
+        mock_conn.describe_flow_execution_records.return_value = {
+            "flowExecutions": [
+                {
+                    "executionId": EXECUTION_ID,
+                    "executionResult": {"recordsProcessed": 1},
+                    "executionStatus": "Successful",
+                }
+            ]
+        }
+        yield AppflowHook(aws_conn_id=AWS_CONN_ID, region_name=REGION_NAME)
 
 
 def test_conn_attributes(hook):
@@ -69,7 +68,8 @@ def test_conn_attributes(hook):
 
 
 def test_run_flow(hook):
-    hook.run_flow(flow_name=FLOW_NAME, poll_interval=0)
+    with 
mock.patch("airflow.providers.amazon.aws.waiters.base_waiter.BaseBotoWaiter.waiter"):
+        hook.run_flow(flow_name=FLOW_NAME, poll_interval=0)
     
hook.conn.describe_flow_execution_records.assert_called_with(flowName=FLOW_NAME)
     assert hook.conn.describe_flow_execution_records.call_count == 1
     hook.conn.start_flow.assert_called_once_with(flowName=FLOW_NAME)
diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py 
b/tests/providers/amazon/aws/hooks/test_base_aws.py
index 4e6087ae50..6540e7b69b 100644
--- a/tests/providers/amazon/aws/hooks/test_base_aws.py
+++ b/tests/providers/amazon/aws/hooks/test_base_aws.py
@@ -21,11 +21,13 @@ import json
 import os
 from base64 import b64encode
 from datetime import datetime, timedelta, timezone
+from pathlib import Path
 from unittest import mock
-from unittest.mock import mock_open
+from unittest.mock import MagicMock, PropertyMock, mock_open
 from uuid import UUID
 
 import boto3
+import jinja2
 import pytest
 from botocore.config import Config
 from botocore.credentials import ReadOnlyCredentials
@@ -34,9 +36,11 @@ from botocore.utils import FileWebIdentityTokenLoader
 from moto import mock_dynamodb, mock_emr, mock_iam, mock_sts
 from moto.core import DEFAULT_ACCOUNT_ID
 
+from airflow import AirflowException
 from airflow.models.connection import Connection
 from airflow.providers.amazon.aws.hooks.base_aws import (
     AwsBaseHook,
+    AwsGenericHook,
     BaseSessionFactory,
     resolve_session_factory,
 )
@@ -47,7 +51,6 @@ MOCK_AWS_CONN_ID = "mock-conn-id"
 MOCK_CONN_TYPE = "aws"
 MOCK_BOTO3_SESSION = mock.MagicMock(return_value="Mock boto3.session.Session")
 
-
 SAML_ASSERTION = """
 <?xml version="1.0"?>
 <samlp:Response xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol" 
ID="_00000000-0000-0000-0000-000000000000" Version="2.0" 
IssueInstant="2012-01-01T12:00:00.000Z" 
Destination="https://signin.aws.amazon.com/saml" 
Consent="urn:oasis:names:tc:SAML:2.0:consent:unspecified">
@@ -978,3 +981,49 @@ def 
test_raise_no_creds_default_credentials_strategy(tmp_path_factory, monkeypat
         # In normal circumstances lines below should not execute.
         # We want to show additional information why this test not passed
         assert not result, f"Credentials Method: 
{hook.get_session().get_credentials().method}"
+
+
+TEST_WAITER_CONFIG_LOCATION = 
Path(__file__).parents[1].joinpath("waiters/test.json")
+
+
+@mock.patch.object(AwsGenericHook, "waiter_path", new_callable=PropertyMock)
+def test_waiter_config_params_not_provided(waiter_path_mock: MagicMock, 
caplog):
+    waiter_path_mock.return_value = TEST_WAITER_CONFIG_LOCATION
+    hook = AwsBaseHook(client_type="mwaa")  # needs to be a real client type
+
+    with pytest.raises(AirflowException) as ae:
+        hook.get_waiter("wait_for_test")
+
+    # should warn about missing param
+    assert "PARAM_1" in str(ae.value)
+
+
+@mock.patch.object(AwsGenericHook, "waiter_path", new_callable=PropertyMock)
+def test_waiter_config_no_params_needed(waiter_path_mock: MagicMock, caplog):
+    waiter_path_mock.return_value = TEST_WAITER_CONFIG_LOCATION
+    hook = AwsBaseHook(client_type="mwaa")  # needs to be a real client type
+
+    with caplog.at_level("WARN"):
+        hook.get_waiter("other_wait")
+
+    # other waiters in the json need params, but not this one, so we shouldn't 
warn about it.
+    assert len(caplog.text) == 0
+
+
+@mock.patch.object(AwsGenericHook, "waiter_path", new_callable=PropertyMock)
+def test_waiter_config_with_parameters_specified(waiter_path_mock: MagicMock):
+    waiter_path_mock.return_value = TEST_WAITER_CONFIG_LOCATION
+    hook = AwsBaseHook(client_type="mwaa")  # needs to be a real client type
+
+    waiter = hook.get_waiter("wait_for_test", {"PARAM_1": "hello", "PARAM_2": 
"world"})
+
+    assert waiter.config.acceptors[0].argument == "'hello' == 'world'"
+
+
+@mock.patch.object(AwsGenericHook, "waiter_path", new_callable=PropertyMock)
+def test_waiter_config_param_wrong_format(waiter_path_mock: MagicMock):
+    waiter_path_mock.return_value = TEST_WAITER_CONFIG_LOCATION
+    hook = AwsBaseHook(client_type="mwaa")  # needs to be a real client type
+
+    with pytest.raises(jinja2.TemplateSyntaxError):
+        hook.get_waiter("bad_param_wait")
diff --git a/tests/providers/amazon/aws/operators/test_appflow.py 
b/tests/providers/amazon/aws/operators/test_appflow.py
index 23387d268a..2308810662 100644
--- a/tests/providers/amazon/aws/operators/test_appflow.py
+++ b/tests/providers/amazon/aws/operators/test_appflow.py
@@ -90,6 +90,12 @@ def appflow_conn():
         yield mock_conn
 
 
+@pytest.fixture
+def waiter_mock():
+    with 
mock.patch("airflow.providers.amazon.aws.waiters.base_waiter.BaseBotoWaiter.waiter")
 as waiter:
+        yield waiter
+
+
 def run_assertions_base(appflow_conn, tasks):
     appflow_conn.describe_flow.assert_called_with(flowName=FLOW_NAME)
     assert appflow_conn.describe_flow.call_count == 2
@@ -105,7 +111,7 @@ def run_assertions_base(appflow_conn, tasks):
     appflow_conn.start_flow.assert_called_once_with(flowName=FLOW_NAME)
 
 
-def test_run(appflow_conn, ctx):
+def test_run(appflow_conn, ctx, waiter_mock):
     operator = AppflowRunOperator(**DUMP_COMMON_ARGS)
     operator.execute(ctx)  # type: ignore
     appflow_conn.describe_flow.assert_called_once_with(flowName=FLOW_NAME)
@@ -113,13 +119,13 @@ def test_run(appflow_conn, ctx):
     appflow_conn.start_flow.assert_called_once_with(flowName=FLOW_NAME)
 
 
-def test_run_full(appflow_conn, ctx):
+def test_run_full(appflow_conn, ctx, waiter_mock):
     operator = AppflowRunFullOperator(**DUMP_COMMON_ARGS)
     operator.execute(ctx)  # type: ignore
     run_assertions_base(appflow_conn, [])
 
 
-def test_run_after(appflow_conn, ctx):
+def test_run_after(appflow_conn, ctx, waiter_mock):
     operator = AppflowRunAfterOperator(
         source_field="col0", filter_date="2022-05-26T00:00+00:00", 
**DUMP_COMMON_ARGS
     )
@@ -137,7 +143,7 @@ def test_run_after(appflow_conn, ctx):
     )
 
 
-def test_run_before(appflow_conn, ctx):
+def test_run_before(appflow_conn, ctx, waiter_mock):
     operator = AppflowRunBeforeOperator(
         source_field="col0", filter_date="2022-05-26T00:00+00:00", 
**DUMP_COMMON_ARGS
     )
@@ -155,7 +161,7 @@ def test_run_before(appflow_conn, ctx):
     )
 
 
-def test_run_daily(appflow_conn, ctx):
+def test_run_daily(appflow_conn, ctx, waiter_mock):
     operator = AppflowRunDailyOperator(
         source_field="col0", filter_date="2022-05-26T00:00+00:00", 
**DUMP_COMMON_ARGS
     )
diff --git a/tests/providers/amazon/aws/waiters/test.json 
b/tests/providers/amazon/aws/waiters/test.json
new file mode 100644
index 0000000000..30dac9842a
--- /dev/null
+++ b/tests/providers/amazon/aws/waiters/test.json
@@ -0,0 +1,44 @@
+{
+    "version": 2,
+    "waiters": {
+        "wait_for_test": {
+            "operation": "GetEnvironment",
+            "delay": 15,
+            "maxAttempts": 60,
+            "acceptors": [
+                {
+                    "expected": true,
+                    "matcher": "path",
+                    "state": "success",
+                    "argument": "'{{PARAM_1}}' == '{{PARAM_2}}'"
+                }
+            ]
+        },
+        "other_wait": {
+            "operation": "GetEnvironment",
+            "delay": 15,
+            "maxAttempts": 60,
+            "acceptors": [
+                {
+                    "expected": "blah",
+                    "matcher": "path",
+                    "state": "success",
+                    "argument": "blah"
+                }
+            ]
+        },
+        "bad_param_wait": {
+            "operation": "GetEnvironment",
+            "delay": 15,
+            "maxAttempts": 60,
+            "acceptors": [
+                {
+                    "expected": "blah",
+                    "matcher": "path",
+                    "state": "success",
+                    "argument": "{{not a valid jinja template 💀}}"
+                }
+            ]
+        }
+    }
+}

Reply via email to