This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 6fd8f36893 Unit tests: Reorganize AWS custom waiter (#38785)
6fd8f36893 is described below
commit 6fd8f368937711b4edd3b8b0ecf08080afee27c1
Author: D. Ferruzzi <[email protected]>
AuthorDate: Sat Apr 6 01:59:45 2024 -0700
Unit tests: Reorganize AWS custom waiter (#38785)
---
tests/providers/amazon/aws/waiters/test_batch.py | 78 +++++
.../amazon/aws/waiters/test_custom_waiters.py | 341 ---------------------
tests/providers/amazon/aws/waiters/test_dynamo.py | 82 +++++
tests/providers/amazon/aws/waiters/test_ecs.py | 139 +++++++++
tests/providers/amazon/aws/waiters/test_eks.py | 56 ++++
tests/providers/amazon/aws/waiters/test_emr.py | 95 ++++++
6 files changed, 450 insertions(+), 341 deletions(-)
diff --git a/tests/providers/amazon/aws/waiters/test_batch.py
b/tests/providers/amazon/aws/waiters/test_batch.py
new file mode 100644
index 0000000000..f71caf9eda
--- /dev/null
+++ b/tests/providers/amazon/aws/waiters/test_batch.py
@@ -0,0 +1,78 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import boto3
+import pytest
+from botocore.exceptions import WaiterError
+
+from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
+
+
+class TestCustomBatchServiceWaiters:
+ JOB_ID = "test_job_id"
+
+ @pytest.fixture(autouse=True)
+ def setup_test_cases(self, monkeypatch):
+ self.client = boto3.client("batch", region_name="eu-west-3")
+ monkeypatch.setattr(BatchClientHook, "conn", self.client)
+
+ @pytest.fixture
+ def mock_describe_jobs(self):
+ """Mock ``BatchClientHook.Client.describe_jobs`` method."""
+ with mock.patch.object(self.client, "describe_jobs") as m:
+ yield m
+
+ def test_service_waiters(self):
+ hook_waiters = BatchClientHook(aws_conn_id=None).list_waiters()
+ assert "batch_job_complete" in hook_waiters
+
+ @staticmethod
+ def describe_jobs(status: str):
+ """
+ Helper function for generate minimal DescribeJobs response for a
single job.
+
https://docs.aws.amazon.com/batch/latest/APIReference/API_DescribeJobs.html
+ """
+ return {
+ "jobs": [
+ {
+ "status": status,
+ },
+ ],
+ }
+
+ def test_job_succeeded(self, mock_describe_jobs):
+ """Test job succeeded"""
+ mock_describe_jobs.side_effect = [
+ self.describe_jobs(BatchClientHook.RUNNING_STATE),
+ self.describe_jobs(BatchClientHook.SUCCESS_STATE),
+ ]
+ waiter =
BatchClientHook(aws_conn_id=None).get_waiter("batch_job_complete")
+ waiter.wait(jobs=[self.JOB_ID], WaiterConfig={"Delay": 0.01,
"MaxAttempts": 2})
+
+ def test_job_failed(self, mock_describe_jobs):
+ """Test job failed"""
+ mock_describe_jobs.side_effect = [
+ self.describe_jobs(BatchClientHook.RUNNING_STATE),
+ self.describe_jobs(BatchClientHook.FAILURE_STATE),
+ ]
+ waiter =
BatchClientHook(aws_conn_id=None).get_waiter("batch_job_complete")
+
+ with pytest.raises(WaiterError, match="Waiter encountered a terminal
failure state"):
+ waiter.wait(jobs=[self.JOB_ID], WaiterConfig={"Delay": 0.01,
"MaxAttempts": 2})
diff --git a/tests/providers/amazon/aws/waiters/test_custom_waiters.py
b/tests/providers/amazon/aws/waiters/test_custom_waiters.py
index 3272d603c4..82e0594337 100644
--- a/tests/providers/amazon/aws/waiters/test_custom_waiters.py
+++ b/tests/providers/amazon/aws/waiters/test_custom_waiters.py
@@ -17,22 +17,13 @@
from __future__ import annotations
-import json
-from typing import Sequence
from unittest import mock
import boto3
import pytest
-from botocore.exceptions import WaiterError
from botocore.waiter import WaiterModel
-from moto import mock_aws
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
-from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
-from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook
-from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsHook,
EcsTaskDefinitionStates
-from airflow.providers.amazon.aws.hooks.eks import EksHook
-from airflow.providers.amazon.aws.hooks.emr import EmrHook
from airflow.providers.amazon.aws.waiters.base_waiter import BaseBotoWaiter
@@ -92,335 +83,3 @@ class TestBaseWaiter:
with mock.patch("botocore.client.BaseClient.get_waiter") as m:
hook.get_waiter(waiter_name="FooBar")
m.assert_called_once_with("FooBar")
-
-
-class TestCustomEKSServiceWaiters:
- def test_service_waiters(self):
- hook = EksHook()
- with open(hook.waiter_path) as config_file:
- expected_waiters = json.load(config_file)["waiters"]
-
- for waiter in list(expected_waiters.keys()):
- assert waiter in hook.list_waiters()
- assert waiter in hook._list_custom_waiters()
-
- @mock_aws
- def test_existing_waiter_inherited(self):
- """
- AwsBaseHook::get_waiter will first check if there is a custom waiter
with the
- provided name and pass that through is it exists, otherwise it will
check the
- custom waiters for the given service. This test checks to make sure
that the
- waiter is the same whichever way you get it and no modifications are
made.
- """
- hook_waiter = EksHook().get_waiter("cluster_active")
- client_waiter = EksHook().conn.get_waiter("cluster_active")
- boto_waiter = boto3.client("eks").get_waiter("cluster_active")
-
- assert_all_match(hook_waiter.name, client_waiter.name,
boto_waiter.name)
- assert_all_match(len(hook_waiter.__dict__),
len(client_waiter.__dict__), len(boto_waiter.__dict__))
- for attr in hook_waiter.__dict__:
- # Not all attributes in a Waiter are directly comparable
- # so the best we can do it make sure the same attrs exist.
- assert hasattr(boto_waiter, attr)
- assert hasattr(client_waiter, attr)
-
-
-class TestCustomECSServiceWaiters:
- """Test waiters from ``amazon/aws/waiters/ecs.json``."""
-
- @pytest.fixture(autouse=True)
- def setup_test_cases(self, monkeypatch):
- self.client = boto3.client("ecs", region_name="eu-west-3")
- monkeypatch.setattr(EcsHook, "conn", self.client)
-
- @pytest.fixture
- def mock_describe_clusters(self):
- """Mock ``ECS.Client.describe_clusters`` method."""
- with mock.patch.object(self.client, "describe_clusters") as m:
- yield m
-
- @pytest.fixture
- def mock_describe_task_definition(self):
- """Mock ``ECS.Client.describe_task_definition`` method."""
- with mock.patch.object(self.client, "describe_task_definition") as m:
- yield m
-
- def test_service_waiters(self):
- hook_waiters = EcsHook(aws_conn_id=None).list_waiters()
- assert "cluster_active" in hook_waiters
- assert "cluster_inactive" in hook_waiters
-
- @staticmethod
- def describe_clusters(
- status: str | EcsClusterStates, cluster_name: str = "spam-egg",
failures: dict | list | None = None
- ):
- """
- Helper function for generate minimal DescribeClusters response for
single job.
-
https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_DescribeClusters.html
- """
- if isinstance(status, EcsClusterStates):
- status = status.value
- else:
- assert status in EcsClusterStates.__members__.values()
-
- failures = failures or []
- if isinstance(failures, dict):
- failures = [failures]
-
- return {"clusters": [{"clusterName": cluster_name, "status": status}],
"failures": failures}
-
- def test_cluster_active(self, mock_describe_clusters):
- """Test cluster reach Active state during creation."""
- mock_describe_clusters.side_effect = [
- self.describe_clusters(EcsClusterStates.DEPROVISIONING),
- self.describe_clusters(EcsClusterStates.PROVISIONING),
- self.describe_clusters(EcsClusterStates.ACTIVE),
- ]
- waiter = EcsHook(aws_conn_id=None).get_waiter("cluster_active")
- waiter.wait(clusters=["spam-egg"], WaiterConfig={"Delay": 0.01,
"MaxAttempts": 3})
-
- @pytest.mark.parametrize("state", ["FAILED", "INACTIVE"])
- def test_cluster_active_failure_states(self, mock_describe_clusters,
state):
- """Test cluster reach inactive state during creation."""
- mock_describe_clusters.side_effect = [
- self.describe_clusters(EcsClusterStates.PROVISIONING),
- self.describe_clusters(state),
- ]
- waiter = EcsHook(aws_conn_id=None).get_waiter("cluster_active")
- with pytest.raises(WaiterError, match=f'matched expected path:
"{state}"'):
- waiter.wait(clusters=["spam-egg"], WaiterConfig={"Delay": 0.01,
"MaxAttempts": 3})
-
- def test_cluster_active_failure_reasons(self, mock_describe_clusters):
- """Test cluster reach failure state during creation."""
- mock_describe_clusters.side_effect = [
- self.describe_clusters(EcsClusterStates.PROVISIONING),
- self.describe_clusters(EcsClusterStates.PROVISIONING,
failures={"reason": "MISSING"}),
- ]
- waiter = EcsHook(aws_conn_id=None).get_waiter("cluster_active")
- with pytest.raises(WaiterError, match='matched expected path:
"MISSING"'):
- waiter.wait(clusters=["spam-egg"], WaiterConfig={"Delay": 0.01,
"MaxAttempts": 3})
-
- def test_cluster_inactive(self, mock_describe_clusters):
- """Test cluster reach Inactive state during deletion."""
- mock_describe_clusters.side_effect = [
- self.describe_clusters(EcsClusterStates.ACTIVE),
- self.describe_clusters(EcsClusterStates.ACTIVE),
- self.describe_clusters(EcsClusterStates.INACTIVE),
- ]
- waiter = EcsHook(aws_conn_id=None).get_waiter("cluster_inactive")
- waiter.wait(clusters=["spam-egg"], WaiterConfig={"Delay": 0.01,
"MaxAttempts": 3})
-
- def test_cluster_inactive_failure_reasons(self, mock_describe_clusters):
- """Test cluster reach failure state during deletion."""
- mock_describe_clusters.side_effect = [
- self.describe_clusters(EcsClusterStates.ACTIVE),
- self.describe_clusters(EcsClusterStates.DEPROVISIONING),
- self.describe_clusters(EcsClusterStates.DEPROVISIONING,
failures={"reason": "MISSING"}),
- ]
- waiter = EcsHook(aws_conn_id=None).get_waiter("cluster_inactive")
- waiter.wait(clusters=["spam-egg"], WaiterConfig={"Delay": 0.01,
"MaxAttempts": 3})
-
- @staticmethod
- def describe_task_definition(status: str | EcsTaskDefinitionStates,
task_definition: str = "spam-egg"):
- """
- Helper function for generate minimal DescribeTaskDefinition response
for single job.
-
https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_DescribeTaskDefinition.html
- """
- if isinstance(status, EcsTaskDefinitionStates):
- status = status.value
- else:
- assert status in EcsTaskDefinitionStates.__members__.values()
-
- return {
- "taskDefinition": {
- "taskDefinitionArn": (
-
f"arn:aws:ecs:eu-west-3:123456789012:task-definition/{task_definition}:42"
- ),
- "status": status,
- }
- }
-
-
-class TestCustomDynamoDBServiceWaiters:
- """Test waiters from ``amazon/aws/waiters/dynamodb.json``."""
-
- STATUS_COMPLETED = "COMPLETED"
- STATUS_FAILED = "FAILED"
- STATUS_IN_PROGRESS = "IN_PROGRESS"
-
- @pytest.fixture(autouse=True)
- def setup_test_cases(self, monkeypatch):
- self.resource = boto3.resource("dynamodb", region_name="eu-west-3")
- monkeypatch.setattr(DynamoDBHook, "conn", self.resource)
- self.client = self.resource.meta.client
-
- @pytest.fixture
- def mock_describe_export(self):
- """Mock ``DynamoDBHook.Client.describe_export`` method."""
- with mock.patch.object(self.client, "describe_export") as m:
- yield m
-
- def test_service_waiters(self):
- hook_waiters = DynamoDBHook(aws_conn_id=None).list_waiters()
- assert "export_table" in hook_waiters
-
- @staticmethod
- def describe_export(status: str):
- """
- Helper function for generate minimal DescribeExport response for
single job.
-
https://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_DescribeExport.html
- """
- return {"ExportDescription": {"ExportStatus": status}}
-
- def test_export_table_to_point_in_time_completed(self,
mock_describe_export):
- """Test state transition from `in progress` to `completed` during
init."""
- waiter = DynamoDBHook(aws_conn_id=None).get_waiter("export_table")
- mock_describe_export.side_effect = [
- self.describe_export(self.STATUS_IN_PROGRESS),
- self.describe_export(self.STATUS_COMPLETED),
- ]
- waiter.wait(
-
ExportArn="LoremIpsumissimplydummytextoftheprintingandtypesettingindustry",
- WaiterConfig={"Delay": 0.01, "MaxAttempts": 3},
- )
-
- def test_export_table_to_point_in_time_failed(self, mock_describe_export):
- """Test state transition from `in progress` to `failed` during init."""
- with mock.patch("boto3.client") as client:
- client.return_value = self.client
- mock_describe_export.side_effect = [
- self.describe_export(self.STATUS_IN_PROGRESS),
- self.describe_export(self.STATUS_FAILED),
- ]
- waiter = DynamoDBHook(aws_conn_id=None).get_waiter("export_table",
client=self.client)
- with pytest.raises(WaiterError, match='we matched expected path:
"FAILED"'):
- waiter.wait(
-
ExportArn="LoremIpsumissimplydummytextoftheprintingandtypesettingindustry",
- WaiterConfig={"Delay": 0.01, "MaxAttempts": 3},
- )
-
-
-class TestCustomBatchServiceWaiters:
- """Test waiters from ``amazon/aws/waiters/batch.json``."""
-
- JOB_ID = "test_job_id"
-
- @pytest.fixture(autouse=True)
- def setup_test_cases(self, monkeypatch):
- self.client = boto3.client("batch", region_name="eu-west-3")
- monkeypatch.setattr(BatchClientHook, "conn", self.client)
-
- @pytest.fixture
- def mock_describe_jobs(self):
- """Mock ``BatchClientHook.Client.describe_jobs`` method."""
- with mock.patch.object(self.client, "describe_jobs") as m:
- yield m
-
- def test_service_waiters(self):
- hook_waiters = BatchClientHook(aws_conn_id=None).list_waiters()
- assert "batch_job_complete" in hook_waiters
-
- @staticmethod
- def describe_jobs(status: str):
- """
- Helper function for generate minimal DescribeJobs response for a
single job.
-
https://docs.aws.amazon.com/batch/latest/APIReference/API_DescribeJobs.html
- """
- return {
- "jobs": [
- {
- "status": status,
- },
- ],
- }
-
- def test_job_succeeded(self, mock_describe_jobs):
- """Test job succeeded"""
- mock_describe_jobs.side_effect = [
- self.describe_jobs(BatchClientHook.RUNNING_STATE),
- self.describe_jobs(BatchClientHook.SUCCESS_STATE),
- ]
- waiter =
BatchClientHook(aws_conn_id=None).get_waiter("batch_job_complete")
- waiter.wait(jobs=[self.JOB_ID], WaiterConfig={"Delay": 0.01,
"MaxAttempts": 2})
-
- def test_job_failed(self, mock_describe_jobs):
- """Test job failed"""
- mock_describe_jobs.side_effect = [
- self.describe_jobs(BatchClientHook.RUNNING_STATE),
- self.describe_jobs(BatchClientHook.FAILURE_STATE),
- ]
- waiter =
BatchClientHook(aws_conn_id=None).get_waiter("batch_job_complete")
-
- with pytest.raises(WaiterError, match="Waiter encountered a terminal
failure state"):
- waiter.wait(jobs=[self.JOB_ID], WaiterConfig={"Delay": 0.01,
"MaxAttempts": 2})
-
-
-class TestCustomEmrServiceWaiters:
- """Test waiters from ``amazon/aws/waiters/emr.json``."""
-
- JOBFLOW_ID = "test_jobflow_id"
- STEP_ID1 = "test_step_id_1"
- STEP_ID2 = "test_step_id_2"
-
- @pytest.fixture(autouse=True)
- def setup_test_cases(self, monkeypatch):
- self.client = boto3.client("emr", region_name="eu-west-3")
- monkeypatch.setattr(EmrHook, "conn", self.client)
-
- @pytest.fixture
- def mock_list_steps(self):
- """Mock ``EmrHook.Client.list_steps`` method."""
- with mock.patch.object(self.client, "list_steps") as m:
- yield m
-
- def test_service_waiters(self):
- hook_waiters = EmrHook(aws_conn_id=None).list_waiters()
- assert "steps_wait_for_terminal" in hook_waiters
-
- @staticmethod
- def list_steps(step_records: Sequence[tuple[str, str]]):
- """
- Helper function to generate minimal ListSteps response.
- https://docs.aws.amazon.com/emr/latest/APIReference/API_ListSteps.html
- """
- return {
- "Steps": [
- {
- "Id": step_record[0],
- "Status": {
- "State": step_record[1],
- },
- }
- for step_record in step_records
- ],
- }
-
- def test_steps_succeeded(self, mock_list_steps):
- """Test steps succeeded"""
- mock_list_steps.side_effect = [
- self.list_steps([(self.STEP_ID1, "PENDING"), (self.STEP_ID2,
"RUNNING")]),
- self.list_steps([(self.STEP_ID1, "RUNNING"), (self.STEP_ID2,
"COMPLETED")]),
- self.list_steps([(self.STEP_ID1, "COMPLETED"), (self.STEP_ID2,
"COMPLETED")]),
- ]
- waiter =
EmrHook(aws_conn_id=None).get_waiter("steps_wait_for_terminal")
- waiter.wait(
- ClusterId=self.JOBFLOW_ID,
- StepIds=[self.STEP_ID1, self.STEP_ID2],
- WaiterConfig={"Delay": 0.01, "MaxAttempts": 3},
- )
-
- def test_steps_failed(self, mock_list_steps):
- """Test steps failed"""
- mock_list_steps.side_effect = [
- self.list_steps([(self.STEP_ID1, "PENDING"), (self.STEP_ID2,
"RUNNING")]),
- self.list_steps([(self.STEP_ID1, "RUNNING"), (self.STEP_ID2,
"COMPLETED")]),
- self.list_steps([(self.STEP_ID1, "FAILED"), (self.STEP_ID2,
"COMPLETED")]),
- ]
- waiter =
EmrHook(aws_conn_id=None).get_waiter("steps_wait_for_terminal")
-
- with pytest.raises(WaiterError, match="Waiter encountered a terminal
failure state"):
- waiter.wait(
- ClusterId=self.JOBFLOW_ID,
- StepIds=[self.STEP_ID1, self.STEP_ID2],
- WaiterConfig={"Delay": 0.01, "MaxAttempts": 3},
- )
diff --git a/tests/providers/amazon/aws/waiters/test_dynamo.py
b/tests/providers/amazon/aws/waiters/test_dynamo.py
new file mode 100644
index 0000000000..be94f68081
--- /dev/null
+++ b/tests/providers/amazon/aws/waiters/test_dynamo.py
@@ -0,0 +1,82 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import boto3
+import pytest
+from botocore.exceptions import WaiterError
+
+from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook
+
+
+class TestCustomDynamoDBServiceWaiters:
+ STATUS_COMPLETED = "COMPLETED"
+ STATUS_FAILED = "FAILED"
+ STATUS_IN_PROGRESS = "IN_PROGRESS"
+
+ @pytest.fixture(autouse=True)
+ def setup_test_cases(self, monkeypatch):
+ self.resource = boto3.resource("dynamodb", region_name="eu-west-3")
+ monkeypatch.setattr(DynamoDBHook, "conn", self.resource)
+ self.client = self.resource.meta.client
+
+ @pytest.fixture
+ def mock_describe_export(self):
+ """Mock ``DynamoDBHook.Client.describe_export`` method."""
+ with mock.patch.object(self.client, "describe_export") as m:
+ yield m
+
+ def test_service_waiters(self):
+ hook_waiters = DynamoDBHook(aws_conn_id=None).list_waiters()
+ assert "export_table" in hook_waiters
+
+ @staticmethod
+ def describe_export(status: str):
+ """
+ Helper function for generate minimal DescribeExport response for
single job.
+
https://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_DescribeExport.html
+ """
+ return {"ExportDescription": {"ExportStatus": status}}
+
+ def test_export_table_to_point_in_time_completed(self,
mock_describe_export):
+ """Test state transition from `in progress` to `completed` during
init."""
+ waiter = DynamoDBHook(aws_conn_id=None).get_waiter("export_table")
+ mock_describe_export.side_effect = [
+ self.describe_export(self.STATUS_IN_PROGRESS),
+ self.describe_export(self.STATUS_COMPLETED),
+ ]
+ waiter.wait(
+
ExportArn="LoremIpsumissimplydummytextoftheprintingandtypesettingindustry",
+ WaiterConfig={"Delay": 0.01, "MaxAttempts": 3},
+ )
+
+ def test_export_table_to_point_in_time_failed(self, mock_describe_export):
+ """Test state transition from `in progress` to `failed` during init."""
+ with mock.patch("boto3.client") as client:
+ client.return_value = self.client
+ mock_describe_export.side_effect = [
+ self.describe_export(self.STATUS_IN_PROGRESS),
+ self.describe_export(self.STATUS_FAILED),
+ ]
+ waiter = DynamoDBHook(aws_conn_id=None).get_waiter("export_table",
client=self.client)
+ with pytest.raises(WaiterError, match='we matched expected path:
"FAILED"'):
+ waiter.wait(
+
ExportArn="LoremIpsumissimplydummytextoftheprintingandtypesettingindustry",
+ WaiterConfig={"Delay": 0.01, "MaxAttempts": 3},
+ )
diff --git a/tests/providers/amazon/aws/waiters/test_ecs.py
b/tests/providers/amazon/aws/waiters/test_ecs.py
new file mode 100644
index 0000000000..5742cd59b1
--- /dev/null
+++ b/tests/providers/amazon/aws/waiters/test_ecs.py
@@ -0,0 +1,139 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import boto3
+import pytest
+from botocore.exceptions import WaiterError
+
+from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsHook,
EcsTaskDefinitionStates
+
+
+class TestCustomECSServiceWaiters:
+ @pytest.fixture(autouse=True)
+ def setup_test_cases(self, monkeypatch):
+ self.client = boto3.client("ecs", region_name="eu-west-3")
+ monkeypatch.setattr(EcsHook, "conn", self.client)
+
+ @pytest.fixture
+ def mock_describe_clusters(self):
+ """Mock ``ECS.Client.describe_clusters`` method."""
+ with mock.patch.object(self.client, "describe_clusters") as m:
+ yield m
+
+ @pytest.fixture
+ def mock_describe_task_definition(self):
+ """Mock ``ECS.Client.describe_task_definition`` method."""
+ with mock.patch.object(self.client, "describe_task_definition") as m:
+ yield m
+
+ def test_service_waiters(self):
+ hook_waiters = EcsHook(aws_conn_id=None).list_waiters()
+ assert "cluster_active" in hook_waiters
+ assert "cluster_inactive" in hook_waiters
+
+ @staticmethod
+ def describe_clusters(
+ status: str | EcsClusterStates, cluster_name: str = "spam-egg",
failures: dict | list | None = None
+ ):
+ """
+ Helper function for generate minimal DescribeClusters response for
single job.
+
https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_DescribeClusters.html
+ """
+ if isinstance(status, EcsClusterStates):
+ status = status.value
+ else:
+ assert status in EcsClusterStates.__members__.values()
+
+ failures = failures or []
+ if isinstance(failures, dict):
+ failures = [failures]
+
+ return {"clusters": [{"clusterName": cluster_name, "status": status}],
"failures": failures}
+
+ def test_cluster_active(self, mock_describe_clusters):
+ """Test cluster reach Active state during creation."""
+ mock_describe_clusters.side_effect = [
+ self.describe_clusters(EcsClusterStates.DEPROVISIONING),
+ self.describe_clusters(EcsClusterStates.PROVISIONING),
+ self.describe_clusters(EcsClusterStates.ACTIVE),
+ ]
+ waiter = EcsHook(aws_conn_id=None).get_waiter("cluster_active")
+ waiter.wait(clusters=["spam-egg"], WaiterConfig={"Delay": 0.01,
"MaxAttempts": 3})
+
+ @pytest.mark.parametrize("state", ["FAILED", "INACTIVE"])
+ def test_cluster_active_failure_states(self, mock_describe_clusters,
state):
+ """Test cluster reach inactive state during creation."""
+ mock_describe_clusters.side_effect = [
+ self.describe_clusters(EcsClusterStates.PROVISIONING),
+ self.describe_clusters(state),
+ ]
+ waiter = EcsHook(aws_conn_id=None).get_waiter("cluster_active")
+ with pytest.raises(WaiterError, match=f'matched expected path:
"{state}"'):
+ waiter.wait(clusters=["spam-egg"], WaiterConfig={"Delay": 0.01,
"MaxAttempts": 3})
+
+ def test_cluster_active_failure_reasons(self, mock_describe_clusters):
+ """Test cluster reach failure state during creation."""
+ mock_describe_clusters.side_effect = [
+ self.describe_clusters(EcsClusterStates.PROVISIONING),
+ self.describe_clusters(EcsClusterStates.PROVISIONING,
failures={"reason": "MISSING"}),
+ ]
+ waiter = EcsHook(aws_conn_id=None).get_waiter("cluster_active")
+ with pytest.raises(WaiterError, match='matched expected path:
"MISSING"'):
+ waiter.wait(clusters=["spam-egg"], WaiterConfig={"Delay": 0.01,
"MaxAttempts": 3})
+
+ def test_cluster_inactive(self, mock_describe_clusters):
+ """Test cluster reach Inactive state during deletion."""
+ mock_describe_clusters.side_effect = [
+ self.describe_clusters(EcsClusterStates.ACTIVE),
+ self.describe_clusters(EcsClusterStates.ACTIVE),
+ self.describe_clusters(EcsClusterStates.INACTIVE),
+ ]
+ waiter = EcsHook(aws_conn_id=None).get_waiter("cluster_inactive")
+ waiter.wait(clusters=["spam-egg"], WaiterConfig={"Delay": 0.01,
"MaxAttempts": 3})
+
+ def test_cluster_inactive_failure_reasons(self, mock_describe_clusters):
+ """Test cluster reach failure state during deletion."""
+ mock_describe_clusters.side_effect = [
+ self.describe_clusters(EcsClusterStates.ACTIVE),
+ self.describe_clusters(EcsClusterStates.DEPROVISIONING),
+ self.describe_clusters(EcsClusterStates.DEPROVISIONING,
failures={"reason": "MISSING"}),
+ ]
+ waiter = EcsHook(aws_conn_id=None).get_waiter("cluster_inactive")
+ waiter.wait(clusters=["spam-egg"], WaiterConfig={"Delay": 0.01,
"MaxAttempts": 3})
+
+ @staticmethod
+ def describe_task_definition(status: str | EcsTaskDefinitionStates,
task_definition: str = "spam-egg"):
+ """
+ Helper function for generate minimal DescribeTaskDefinition response
for single job.
+
https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_DescribeTaskDefinition.html
+ """
+ if isinstance(status, EcsTaskDefinitionStates):
+ status = status.value
+ else:
+ assert status in EcsTaskDefinitionStates.__members__.values()
+
+ return {
+ "taskDefinition": {
+ "taskDefinitionArn": (
+
f"arn:aws:ecs:eu-west-3:123456789012:task-definition/{task_definition}:42"
+ ),
+ "status": status,
+ }
+ }
diff --git a/tests/providers/amazon/aws/waiters/test_eks.py
b/tests/providers/amazon/aws/waiters/test_eks.py
new file mode 100644
index 0000000000..9013c8b7c6
--- /dev/null
+++ b/tests/providers/amazon/aws/waiters/test_eks.py
@@ -0,0 +1,56 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import json
+
+import boto3
+from moto import mock_aws
+
+from airflow.providers.amazon.aws.hooks.eks import EksHook
+from tests.providers.amazon.aws.waiters.test_custom_waiters import
assert_all_match
+
+
+class TestCustomEKSServiceWaiters:
+ def test_service_waiters(self):
+ hook = EksHook()
+ with open(hook.waiter_path) as config_file:
+ expected_waiters = json.load(config_file)["waiters"]
+
+ for waiter in list(expected_waiters.keys()):
+ assert waiter in hook.list_waiters()
+ assert waiter in hook._list_custom_waiters()
+
+ @mock_aws
+ def test_existing_waiter_inherited(self):
+ """
+ AwsBaseHook::get_waiter will first check if there is a custom waiter
with the
+ provided name and pass that through is it exists, otherwise it will
check the
+ custom waiters for the given service. This test checks to make sure
that the
+ waiter is the same whichever way you get it and no modifications are
made.
+ """
+ hook_waiter = EksHook().get_waiter("cluster_active")
+ client_waiter = EksHook().conn.get_waiter("cluster_active")
+ boto_waiter = boto3.client("eks").get_waiter("cluster_active")
+
+ assert_all_match(hook_waiter.name, client_waiter.name,
boto_waiter.name)
+ assert_all_match(len(hook_waiter.__dict__),
len(client_waiter.__dict__), len(boto_waiter.__dict__))
+ for attr in hook_waiter.__dict__:
+ # Not all attributes in a Waiter are directly comparable
+ # so the best we can do it make sure the same attrs exist.
+ assert hasattr(boto_waiter, attr)
+ assert hasattr(client_waiter, attr)
diff --git a/tests/providers/amazon/aws/waiters/test_emr.py
b/tests/providers/amazon/aws/waiters/test_emr.py
new file mode 100644
index 0000000000..a02327f004
--- /dev/null
+++ b/tests/providers/amazon/aws/waiters/test_emr.py
@@ -0,0 +1,95 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import Sequence
+from unittest import mock
+
+import boto3
+import pytest
+from botocore.exceptions import WaiterError
+
+from airflow.providers.amazon.aws.hooks.emr import EmrHook
+
+
+class TestCustomEmrServiceWaiters:
+ JOBFLOW_ID = "test_jobflow_id"
+ STEP_ID1 = "test_step_id_1"
+ STEP_ID2 = "test_step_id_2"
+
+ @pytest.fixture(autouse=True)
+ def setup_test_cases(self, monkeypatch):
+ self.client = boto3.client("emr", region_name="eu-west-3")
+ monkeypatch.setattr(EmrHook, "conn", self.client)
+
+ @pytest.fixture
+ def mock_list_steps(self):
+ """Mock ``EmrHook.Client.list_steps`` method."""
+ with mock.patch.object(self.client, "list_steps") as m:
+ yield m
+
+ def test_service_waiters(self):
+ hook_waiters = EmrHook(aws_conn_id=None).list_waiters()
+ assert "steps_wait_for_terminal" in hook_waiters
+
+ @staticmethod
+ def list_steps(step_records: Sequence[tuple[str, str]]):
+ """
+ Helper function to generate minimal ListSteps response.
+ https://docs.aws.amazon.com/emr/latest/APIReference/API_ListSteps.html
+ """
+ return {
+ "Steps": [
+ {
+ "Id": step_record[0],
+ "Status": {
+ "State": step_record[1],
+ },
+ }
+ for step_record in step_records
+ ],
+ }
+
+ def test_steps_succeeded(self, mock_list_steps):
+ """Test steps succeeded"""
+ mock_list_steps.side_effect = [
+ self.list_steps([(self.STEP_ID1, "PENDING"), (self.STEP_ID2,
"RUNNING")]),
+ self.list_steps([(self.STEP_ID1, "RUNNING"), (self.STEP_ID2,
"COMPLETED")]),
+ self.list_steps([(self.STEP_ID1, "COMPLETED"), (self.STEP_ID2,
"COMPLETED")]),
+ ]
+ waiter =
EmrHook(aws_conn_id=None).get_waiter("steps_wait_for_terminal")
+ waiter.wait(
+ ClusterId=self.JOBFLOW_ID,
+ StepIds=[self.STEP_ID1, self.STEP_ID2],
+ WaiterConfig={"Delay": 0.01, "MaxAttempts": 3},
+ )
+
+ def test_steps_failed(self, mock_list_steps):
+ """Test steps failed"""
+ mock_list_steps.side_effect = [
+ self.list_steps([(self.STEP_ID1, "PENDING"), (self.STEP_ID2,
"RUNNING")]),
+ self.list_steps([(self.STEP_ID1, "RUNNING"), (self.STEP_ID2,
"COMPLETED")]),
+ self.list_steps([(self.STEP_ID1, "FAILED"), (self.STEP_ID2,
"COMPLETED")]),
+ ]
+ waiter =
EmrHook(aws_conn_id=None).get_waiter("steps_wait_for_terminal")
+
+ with pytest.raises(WaiterError, match="Waiter encountered a terminal
failure state"):
+ waiter.wait(
+ ClusterId=self.JOBFLOW_ID,
+ StepIds=[self.STEP_ID1, self.STEP_ID2],
+ WaiterConfig={"Delay": 0.01, "MaxAttempts": 3},
+ )