This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
     new 05c0841880 custom waiters with dynamic values, applied to appflow (#29911)
05c0841880 is described below
commit 05c0841880ccfc25c9e525cafde3e46d7c6f9fce
Author: Raphaël Vandon <[email protected]>
AuthorDate: Tue Mar 21 06:50:39 2023 -0700
custom waiters with dynamic values, applied to appflow (#29911)
---
airflow/providers/amazon/aws/hooks/appflow.py | 28 +++++-------
airflow/providers/amazon/aws/hooks/base_aws.py | 28 ++++++++++--
.../providers/amazon/aws/hooks/batch_waiters.py | 4 +-
airflow/providers/amazon/aws/waiters/appflow.json | 30 ++++++++++++
tests/providers/amazon/aws/hooks/test_appflow.py | 52 ++++++++++-----------
tests/providers/amazon/aws/hooks/test_base_aws.py | 53 +++++++++++++++++++++-
.../providers/amazon/aws/operators/test_appflow.py | 16 +++++--
tests/providers/amazon/aws/waiters/test.json | 44 ++++++++++++++++++
8 files changed, 201 insertions(+), 54 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/appflow.py b/airflow/providers/amazon/aws/hooks/appflow.py
index 10de8c0030..14dee2ef10 100644
--- a/airflow/providers/amazon/aws/hooks/appflow.py
+++ b/airflow/providers/amazon/aws/hooks/appflow.py
@@ -16,8 +16,6 @@
# under the License.
from __future__ import annotations
-import json
-from time import sleep
from typing import TYPE_CHECKING
from airflow.compat.functools import cached_property
@@ -64,24 +62,20 @@ class AppflowHook(AwsBaseHook):
self.log.info("executionId: %s", execution_id)
if wait_for_completion:
- last_execs: dict = {}
- self.log.info("Waiting for flow run to complete...")
- while (
- execution_id not in last_execs or
last_execs[execution_id]["executionStatus"] == "InProgress"
- ):
- sleep(poll_interval)
- # queries the last 20 runs, which should contain ours.
- response_desc =
self.conn.describe_flow_execution_records(flowName=flow_name)
- last_execs = {fe["executionId"]: fe for fe in
response_desc["flowExecutions"]}
-
- exec_details = last_execs[execution_id]
- self.log.info("Run complete, execution details: %s", exec_details)
-
- if exec_details["executionStatus"] == "Error":
- raise Exception(f"Flow error:\n{json.dumps(exec_details,
default=str)}")
+ self.get_waiter("run_complete", {"EXECUTION_ID":
execution_id}).wait(
+ flowName=flow_name,
+ WaiterConfig={"Delay": poll_interval},
+ )
+ self._log_execution_description(flow_name, execution_id)
return execution_id
+ def _log_execution_description(self, flow_name: str, execution_id: str):
+ response_desc =
self.conn.describe_flow_execution_records(flowName=flow_name)
+ last_execs = {fe["executionId"]: fe for fe in
response_desc["flowExecutions"]}
+ exec_details = last_execs[execution_id]
+ self.log.info("Run complete, execution details: %s", exec_details)
+
def update_flow_filter(
self, flow_name: str, filter_tasks: list[TaskTypeDef],
set_trigger_ondemand: bool = False
) -> None:
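
With this change, AppflowHook.run_flow() no longer polls describe_flow_execution_records in a loop; it hands the waiting off to the custom "run_complete" waiter. A minimal usage sketch, assuming an existing AppFlow flow (the flow name, connection id and region below are illustrative placeholders, not values from this commit):

    from airflow.providers.amazon.aws.hooks.appflow import AppflowHook

    hook = AppflowHook(aws_conn_id="aws_default", region_name="us-east-1")
    # run_flow() starts the flow; when wait_for_completion is left enabled it blocks
    # on the "run_complete" waiter from waiters/appflow.json, polling every 30 seconds.
    execution_id = hook.run_flow(flow_name="my_flow", poll_interval=30)
    print("AppFlow run finished:", execution_id)
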
diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py
index d797653476..3395990fc3 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -40,6 +40,7 @@ from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union
import boto3
import botocore
import botocore.session
+import jinja2
import requests
import tenacity
from botocore.client import ClientMeta
@@ -796,7 +797,7 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
         path = Path(__file__).parents[1].joinpath(f"waiters/{self.client_type}.json").resolve()
         return path if path.exists() else None

-    def get_waiter(self, waiter_name: str) -> Waiter:
+    def get_waiter(self, waiter_name: str, parameters: dict[str, str] | None = None) -> Waiter:
         """
         First checks if there is a custom waiter with the provided waiter_name and
         uses that if it exists, otherwise it will check the service client for a
@@ -804,17 +805,38 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):

         :param waiter_name: The name of the waiter. The name should exactly match the
             name of the key in the waiter model file (typically this is CamelCase).
+        :param parameters: will scan the waiter config for the keys of that dict, and replace them with the
+            corresponding value. If a custom waiter has such keys to be expanded, they need to be provided
+            here.
         """
         if self.waiter_path and (waiter_name in self._list_custom_waiters()):
             # Technically if waiter_name is in custom_waiters then self.waiter_path must
             # exist but MyPy doesn't like the fact that self.waiter_path could be None.
             with open(self.waiter_path) as config_file:
-                config = json.load(config_file)
-                return BaseBotoWaiter(client=self.conn, model_config=config).waiter(waiter_name)
+                config = json.loads(config_file.read())
+
+            config = self._apply_parameters_value(config, waiter_name, parameters)
+            return BaseBotoWaiter(client=self.conn, model_config=config).waiter(waiter_name)
         # If there is no custom waiter found for the provided name,
         # then try checking the service's official waiters.
         return self.conn.get_waiter(waiter_name)

+    @staticmethod
+    def _apply_parameters_value(config: dict, waiter_name: str, parameters: dict[str, str] | None) -> dict:
+        """Replaces potential jinja templates in acceptors definition"""
+        # only process the waiter we're going to use to not raise errors for missing params for other waiters.
+        acceptors = config["waiters"][waiter_name]["acceptors"]
+        for a in acceptors:
+            arg = a["argument"]
+            template = jinja2.Template(arg, autoescape=False, undefined=jinja2.StrictUndefined)
+            try:
+                a["argument"] = template.render(parameters or {})
+            except jinja2.UndefinedError as e:
+                raise AirflowException(
+                    f"Parameter was not supplied for templated waiter's acceptor '{arg}'", e
+                )
+        return config
+
     def list_waiters(self) -> list[str]:
         """Returns a list containing the names of all waiters for the service, official and custom."""
         return [*self._list_official_waiters(), *self._list_custom_waiters()]
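
The substitution in _apply_parameters_value() is plain Jinja rendering of each acceptor's "argument" string, with StrictUndefined so that a missing key fails loudly instead of silently leaving the placeholder in place. A standalone sketch of that rendering step (not Airflow code, just the jinja2 calls the hook relies on):

    import jinja2

    # An acceptor argument shaped like the ones in waiters/appflow.json below.
    arg = "flowExecutions[?executionId=='{{EXECUTION_ID}}'].executionStatus"
    template = jinja2.Template(arg, autoescape=False, undefined=jinja2.StrictUndefined)
    print(template.render({"EXECUTION_ID": "abc-123"}))
    # -> flowExecutions[?executionId=='abc-123'].executionStatus
    # Omitting EXECUTION_ID raises jinja2.UndefinedError, which get_waiter()
    # re-raises as an AirflowException.
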
diff --git a/airflow/providers/amazon/aws/hooks/batch_waiters.py b/airflow/providers/amazon/aws/hooks/batch_waiters.py
index 0bbb982e41..dcf111591c 100644
--- a/airflow/providers/amazon/aws/hooks/batch_waiters.py
+++ b/airflow/providers/amazon/aws/hooks/batch_waiters.py
@@ -138,7 +138,7 @@ class BatchWaitersHook(BatchClientHook):
"""
return self._waiter_model
-    def get_waiter(self, waiter_name: str) -> botocore.waiter.Waiter:
+    def get_waiter(self, waiter_name: str, _: dict[str, str] | None = None) -> botocore.waiter.Waiter:
"""
         Get an AWS Batch service waiter, using the configured ``.waiter_model``.
@@ -168,6 +168,8 @@ class BatchWaitersHook(BatchClientHook):
             the name (including the casing) of the key name in the waiter
             model file (typically this is CamelCasing); see ``.list_waiters``.

+        :param _: unused, just here to match the method signature in base_aws
+
         :return: a waiter object for the named AWS Batch service
         """
         return botocore.waiter.create_waiter_with_client(waiter_name, self.waiter_model, self.client)
diff --git a/airflow/providers/amazon/aws/waiters/appflow.json b/airflow/providers/amazon/aws/waiters/appflow.json
new file mode 100644
index 0000000000..f45c427467
--- /dev/null
+++ b/airflow/providers/amazon/aws/waiters/appflow.json
@@ -0,0 +1,30 @@
+{
+    "version": 2,
+    "waiters": {
+        "run_complete": {
+            "operation": "DescribeFlowExecutionRecords",
+            "delay": 15,
+            "maxAttempts": 60,
+            "acceptors": [
+                {
+                    "expected": "Successful",
+                    "matcher": "path",
+                    "state": "success",
+                    "argument": "flowExecutions[?executionId=='{{EXECUTION_ID}}'].executionStatus"
+                },
+                {
+                    "expected": "Error",
+                    "matcher": "path",
+                    "state": "failure",
+                    "argument": "flowExecutions[?executionId=='{{EXECUTION_ID}}'].executionStatus"
+                },
+                {
+                    "expected": true,
+                    "matcher": "path",
+                    "state": "failure",
+                    "argument": "length(flowExecutions[?executionId=='{{EXECUTION_ID}}']) > `1`"
+                }
+            ]
+        }
+    }
+}
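
The {{EXECUTION_ID}} placeholders above are what the new parameters argument of get_waiter() fills in; the hook change earlier in this commit consumes this file as:

    # From AppflowHook.run_flow(), with execution_id returned by start_flow():
    self.get_waiter("run_complete", {"EXECUTION_ID": execution_id}).wait(
        flowName=flow_name,
        WaiterConfig={"Delay": poll_interval},
    )
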
diff --git a/tests/providers/amazon/aws/hooks/test_appflow.py b/tests/providers/amazon/aws/hooks/test_appflow.py
index bbc587f79c..71949ae4b2 100644
--- a/tests/providers/amazon/aws/hooks/test_appflow.py
+++ b/tests/providers/amazon/aws/hooks/test_appflow.py
@@ -35,31 +35,30 @@ AWS_CONN_ID = "aws_default"

 @pytest.fixture
 def hook():
-    with mock.patch("airflow.providers.amazon.aws.hooks.appflow.AppflowHook.__init__", return_value=None):
-        with mock.patch("airflow.providers.amazon.aws.hooks.appflow.AppflowHook.conn") as mock_conn:
-            mock_conn.describe_flow.return_value = {
-                "sourceFlowConfig": {"connectorType": CONNECTION_TYPE},
-                "tasks": [],
-                "triggerConfig": {"triggerProperties": None},
-                "flowName": FLOW_NAME,
-                "destinationFlowConfigList": {},
-                "lastRunExecutionDetails": {
-                    "mostRecentExecutionStatus": "Successful",
-                    "mostRecentExecutionTime": datetime(3000, 1, 1, tzinfo=timezone.utc),
-                },
-            }
-            mock_conn.update_flow.return_value = {}
-            mock_conn.start_flow.return_value = {"executionId": EXECUTION_ID}
-            mock_conn.describe_flow_execution_records.return_value = {
-                "flowExecutions": [
-                    {
-                        "executionId": EXECUTION_ID,
-                        "executionResult": {"recordsProcessed": 1},
-                        "executionStatus": "Successful",
-                    }
-                ]
-            }
-            yield AppflowHook(aws_conn_id=AWS_CONN_ID, region_name=REGION_NAME)
+    with mock.patch("airflow.providers.amazon.aws.hooks.appflow.AppflowHook.conn") as mock_conn:
+        mock_conn.describe_flow.return_value = {
+            "sourceFlowConfig": {"connectorType": CONNECTION_TYPE},
+            "tasks": [],
+            "triggerConfig": {"triggerProperties": None},
+            "flowName": FLOW_NAME,
+            "destinationFlowConfigList": {},
+            "lastRunExecutionDetails": {
+                "mostRecentExecutionStatus": "Successful",
+                "mostRecentExecutionTime": datetime(3000, 1, 1, tzinfo=timezone.utc),
+            },
+        }
+        mock_conn.update_flow.return_value = {}
+        mock_conn.start_flow.return_value = {"executionId": EXECUTION_ID}
+        mock_conn.describe_flow_execution_records.return_value = {
+            "flowExecutions": [
+                {
+                    "executionId": EXECUTION_ID,
+                    "executionResult": {"recordsProcessed": 1},
+                    "executionStatus": "Successful",
+                }
+            ]
+        }
+        yield AppflowHook(aws_conn_id=AWS_CONN_ID, region_name=REGION_NAME)


 def test_conn_attributes(hook):
@@ -69,7 +68,8 @@ def test_conn_attributes(hook):


 def test_run_flow(hook):
-    hook.run_flow(flow_name=FLOW_NAME, poll_interval=0)
+    with mock.patch("airflow.providers.amazon.aws.waiters.base_waiter.BaseBotoWaiter.waiter"):
+        hook.run_flow(flow_name=FLOW_NAME, poll_interval=0)
     hook.conn.describe_flow_execution_records.assert_called_with(flowName=FLOW_NAME)
     assert hook.conn.describe_flow_execution_records.call_count == 1
     hook.conn.start_flow.assert_called_once_with(flowName=FLOW_NAME)
diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py b/tests/providers/amazon/aws/hooks/test_base_aws.py
index 4e6087ae50..6540e7b69b 100644
--- a/tests/providers/amazon/aws/hooks/test_base_aws.py
+++ b/tests/providers/amazon/aws/hooks/test_base_aws.py
@@ -21,11 +21,13 @@ import json
import os
from base64 import b64encode
from datetime import datetime, timedelta, timezone
+from pathlib import Path
from unittest import mock
-from unittest.mock import mock_open
+from unittest.mock import MagicMock, PropertyMock, mock_open
from uuid import UUID
import boto3
+import jinja2
import pytest
from botocore.config import Config
from botocore.credentials import ReadOnlyCredentials
@@ -34,9 +36,11 @@ from botocore.utils import FileWebIdentityTokenLoader
from moto import mock_dynamodb, mock_emr, mock_iam, mock_sts
from moto.core import DEFAULT_ACCOUNT_ID
+from airflow import AirflowException
from airflow.models.connection import Connection
from airflow.providers.amazon.aws.hooks.base_aws import (
AwsBaseHook,
+ AwsGenericHook,
BaseSessionFactory,
resolve_session_factory,
)
@@ -47,7 +51,6 @@ MOCK_AWS_CONN_ID = "mock-conn-id"
MOCK_CONN_TYPE = "aws"
MOCK_BOTO3_SESSION = mock.MagicMock(return_value="Mock boto3.session.Session")
-
SAML_ASSERTION = """
<?xml version="1.0"?>
<samlp:Response xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
ID="_00000000-0000-0000-0000-000000000000" Version="2.0"
IssueInstant="2012-01-01T12:00:00.000Z"
Destination="https://signin.aws.amazon.com/saml"
Consent="urn:oasis:names:tc:SAML:2.0:consent:unspecified">
@@ -978,3 +981,49 @@ def test_raise_no_creds_default_credentials_strategy(tmp_path_factory, monkeypat
     # In normal circumstances lines below should not execute.
     # We want to show additional information why this test not passed
     assert not result, f"Credentials Method: {hook.get_session().get_credentials().method}"
+
+
+TEST_WAITER_CONFIG_LOCATION = Path(__file__).parents[1].joinpath("waiters/test.json")
+
+
[email protected](AwsGenericHook, "waiter_path", new_callable=PropertyMock)
+def test_waiter_config_params_not_provided(waiter_path_mock: MagicMock, caplog):
+    waiter_path_mock.return_value = TEST_WAITER_CONFIG_LOCATION
+    hook = AwsBaseHook(client_type="mwaa")  # needs to be a real client type
+
+    with pytest.raises(AirflowException) as ae:
+        hook.get_waiter("wait_for_test")
+
+    # should warn about missing param
+    assert "PARAM_1" in str(ae.value)
+
+
[email protected](AwsGenericHook, "waiter_path", new_callable=PropertyMock)
+def test_waiter_config_no_params_needed(waiter_path_mock: MagicMock, caplog):
+    waiter_path_mock.return_value = TEST_WAITER_CONFIG_LOCATION
+    hook = AwsBaseHook(client_type="mwaa")  # needs to be a real client type
+
+    with caplog.at_level("WARN"):
+        hook.get_waiter("other_wait")
+
+    # other waiters in the json need params, but not this one, so we shouldn't warn about it.
+    assert len(caplog.text) == 0
+
+
[email protected](AwsGenericHook, "waiter_path", new_callable=PropertyMock)
+def test_waiter_config_with_parameters_specified(waiter_path_mock: MagicMock):
+    waiter_path_mock.return_value = TEST_WAITER_CONFIG_LOCATION
+    hook = AwsBaseHook(client_type="mwaa")  # needs to be a real client type
+
+    waiter = hook.get_waiter("wait_for_test", {"PARAM_1": "hello", "PARAM_2": "world"})
+
+    assert waiter.config.acceptors[0].argument == "'hello' == 'world'"
+
+
[email protected](AwsGenericHook, "waiter_path", new_callable=PropertyMock)
+def test_waiter_config_param_wrong_format(waiter_path_mock: MagicMock):
+    waiter_path_mock.return_value = TEST_WAITER_CONFIG_LOCATION
+    hook = AwsBaseHook(client_type="mwaa")  # needs to be a real client type
+
+    with pytest.raises(jinja2.TemplateSyntaxError):
+        hook.get_waiter("bad_param_wait")
diff --git a/tests/providers/amazon/aws/operators/test_appflow.py b/tests/providers/amazon/aws/operators/test_appflow.py
index 23387d268a..2308810662 100644
--- a/tests/providers/amazon/aws/operators/test_appflow.py
+++ b/tests/providers/amazon/aws/operators/test_appflow.py
@@ -90,6 +90,12 @@ def appflow_conn():
         yield mock_conn


[email protected]
+def waiter_mock():
+    with mock.patch("airflow.providers.amazon.aws.waiters.base_waiter.BaseBotoWaiter.waiter") as waiter:
+        yield waiter
+
+
 def run_assertions_base(appflow_conn, tasks):
     appflow_conn.describe_flow.assert_called_with(flowName=FLOW_NAME)
     assert appflow_conn.describe_flow.call_count == 2
@@ -105,7 +111,7 @@ def run_assertions_base(appflow_conn, tasks):
appflow_conn.start_flow.assert_called_once_with(flowName=FLOW_NAME)
-def test_run(appflow_conn, ctx):
+def test_run(appflow_conn, ctx, waiter_mock):
operator = AppflowRunOperator(**DUMP_COMMON_ARGS)
operator.execute(ctx) # type: ignore
appflow_conn.describe_flow.assert_called_once_with(flowName=FLOW_NAME)
@@ -113,13 +119,13 @@ def test_run(appflow_conn, ctx):
appflow_conn.start_flow.assert_called_once_with(flowName=FLOW_NAME)
-def test_run_full(appflow_conn, ctx):
+def test_run_full(appflow_conn, ctx, waiter_mock):
operator = AppflowRunFullOperator(**DUMP_COMMON_ARGS)
operator.execute(ctx) # type: ignore
run_assertions_base(appflow_conn, [])
-def test_run_after(appflow_conn, ctx):
+def test_run_after(appflow_conn, ctx, waiter_mock):
operator = AppflowRunAfterOperator(
source_field="col0", filter_date="2022-05-26T00:00+00:00",
**DUMP_COMMON_ARGS
)
@@ -137,7 +143,7 @@ def test_run_after(appflow_conn, ctx):
)
-def test_run_before(appflow_conn, ctx):
+def test_run_before(appflow_conn, ctx, waiter_mock):
operator = AppflowRunBeforeOperator(
source_field="col0", filter_date="2022-05-26T00:00+00:00",
**DUMP_COMMON_ARGS
)
@@ -155,7 +161,7 @@ def test_run_before(appflow_conn, ctx):
)
-def test_run_daily(appflow_conn, ctx):
+def test_run_daily(appflow_conn, ctx, waiter_mock):
operator = AppflowRunDailyOperator(
source_field="col0", filter_date="2022-05-26T00:00+00:00",
**DUMP_COMMON_ARGS
)
diff --git a/tests/providers/amazon/aws/waiters/test.json b/tests/providers/amazon/aws/waiters/test.json
new file mode 100644
index 0000000000..30dac9842a
--- /dev/null
+++ b/tests/providers/amazon/aws/waiters/test.json
@@ -0,0 +1,44 @@
+{
+    "version": 2,
+    "waiters": {
+        "wait_for_test": {
+            "operation": "GetEnvironment",
+            "delay": 15,
+            "maxAttempts": 60,
+            "acceptors": [
+                {
+                    "expected": true,
+                    "matcher": "path",
+                    "state": "success",
+                    "argument": "'{{PARAM_1}}' == '{{PARAM_2}}'"
+                }
+            ]
+        },
+        "other_wait": {
+            "operation": "GetEnvironment",
+            "delay": 15,
+            "maxAttempts": 60,
+            "acceptors": [
+                {
+                    "expected": "blah",
+                    "matcher": "path",
+                    "state": "success",
+                    "argument": "blah"
+                }
+            ]
+        },
+        "bad_param_wait": {
+            "operation": "GetEnvironment",
+            "delay": 15,
+            "maxAttempts": 60,
+            "acceptors": [
+                {
+                    "expected": "blah",
+                    "matcher": "path",
+                    "state": "success",
+                    "argument": "{{not a valid jinja template 💀}}"
+                }
+            ]
+        }
+    }
+}