This is an automated email from the ASF dual-hosted git repository. ephraimanierobi pushed a commit to branch v2-5-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 154ad9a02b126c2e09185fb2b1fcd2071699e9c1 Author: Adrian Castro <[email protected]> AuthorDate: Tue Dec 6 18:28:41 2022 +0100 Migrate amazon provider sensor tests from `unittests` to `pytest` (#28139) (cherry picked from commit b726d8eeb84871fafea3395764815b4ddc0c3216) --- tests/providers/amazon/aws/sensors/test_athena.py | 11 +++-- tests/providers/amazon/aws/sensors/test_batch.py | 49 +++++++++------------ .../amazon/aws/sensors/test_cloud_formation.py | 8 +--- .../providers/amazon/aws/sensors/test_dms_task.py | 5 +-- tests/providers/amazon/aws/sensors/test_eks.py | 12 ++--- .../providers/amazon/aws/sensors/test_emr_base.py | 4 +- .../amazon/aws/sensors/test_emr_containers.py | 5 +-- .../amazon/aws/sensors/test_emr_job_flow.py | 10 ++--- .../providers/amazon/aws/sensors/test_emr_step.py | 10 ++--- tests/providers/amazon/aws/sensors/test_glacier.py | 7 ++- tests/providers/amazon/aws/sensors/test_glue.py | 9 +--- .../amazon/aws/sensors/test_glue_crawler.py | 15 +++---- .../amazon/aws/sensors/test_quicksight.py | 51 ++++++++++------------ tests/providers/amazon/aws/sensors/test_s3_key.py | 16 +++---- .../amazon/aws/sensors/test_s3_keys_unchanged.py | 38 ++++++++++------ .../amazon/aws/sensors/test_sagemaker_base.py | 4 +- .../amazon/aws/sensors/test_sagemaker_endpoint.py | 3 +- .../amazon/aws/sensors/test_sagemaker_training.py | 3 +- .../amazon/aws/sensors/test_sagemaker_transform.py | 3 +- .../amazon/aws/sensors/test_sagemaker_tuning.py | 3 +- tests/providers/amazon/aws/sensors/test_sqs.py | 5 +-- .../amazon/aws/sensors/test_step_function.py | 10 ++--- 22 files changed, 123 insertions(+), 158 deletions(-) diff --git a/tests/providers/amazon/aws/sensors/test_athena.py b/tests/providers/amazon/aws/sensors/test_athena.py index c6019296dd..a9809be1d0 100644 --- a/tests/providers/amazon/aws/sensors/test_athena.py +++ b/tests/providers/amazon/aws/sensors/test_athena.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock import pytest @@ -27,8 +26,8 @@ from airflow.providers.amazon.aws.hooks.athena import AthenaHook from airflow.providers.amazon.aws.sensors.athena import AthenaSensor -class TestAthenaSensor(unittest.TestCase): - def setUp(self): +class TestAthenaSensor: + def setup_method(self): self.sensor = AthenaSensor( task_id="test_athena_sensor", query_execution_id="abc", @@ -39,15 +38,15 @@ class TestAthenaSensor(unittest.TestCase): @mock.patch.object(AthenaHook, "poll_query_status", side_effect=("SUCCEEDED",)) def test_poke_success(self, mock_poll_query_status): - assert self.sensor.poke({}) + assert self.sensor.poke({}) is True @mock.patch.object(AthenaHook, "poll_query_status", side_effect=("RUNNING",)) def test_poke_running(self, mock_poll_query_status): - assert not self.sensor.poke({}) + assert self.sensor.poke({}) is False @mock.patch.object(AthenaHook, "poll_query_status", side_effect=("QUEUED",)) def test_poke_queued(self, mock_poll_query_status): - assert not self.sensor.poke({}) + assert self.sensor.poke({}) is False @mock.patch.object(AthenaHook, "poll_query_status", side_effect=("FAILED",)) def test_poke_failed(self, mock_poll_query_status): diff --git a/tests/providers/amazon/aws/sensors/test_batch.py b/tests/providers/amazon/aws/sensors/test_batch.py index e7e20d65a2..d7905d563f 100644 --- a/tests/providers/amazon/aws/sensors/test_batch.py +++ b/tests/providers/amazon/aws/sensors/test_batch.py @@ -16,11 +16,9 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock import pytest -from parameterized import parameterized from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook @@ -34,8 +32,8 @@ TASK_ID = "batch_job_sensor" JOB_ID = "8222a1c2-b246-4e19-b1b8-0039bb4407c0" -class TestBatchSensor(unittest.TestCase): - def setUp(self): +class TestBatchSensor: + def setup_method(self): self.batch_sensor = BatchSensor( task_id="batch_job_sensor", job_id=JOB_ID, @@ -44,45 +42,38 @@ class TestBatchSensor(unittest.TestCase): @mock.patch.object(BatchClientHook, "get_job_description") def test_poke_on_success_state(self, mock_get_job_description): mock_get_job_description.return_value = {"status": "SUCCEEDED"} - self.assertTrue(self.batch_sensor.poke({})) + assert self.batch_sensor.poke({}) mock_get_job_description.assert_called_once_with(JOB_ID) @mock.patch.object(BatchClientHook, "get_job_description") def test_poke_on_failure_state(self, mock_get_job_description): mock_get_job_description.return_value = {"status": "FAILED"} - with self.assertRaises(AirflowException) as e: + with pytest.raises(AirflowException, match="Batch sensor failed. AWS Batch job status: FAILED"): self.batch_sensor.poke({}) - self.assertEqual("Batch sensor failed. AWS Batch job status: FAILED", str(e.exception)) mock_get_job_description.assert_called_once_with(JOB_ID) @mock.patch.object(BatchClientHook, "get_job_description") def test_poke_on_invalid_state(self, mock_get_job_description): mock_get_job_description.return_value = {"status": "INVALID"} - with self.assertRaises(AirflowException) as e: + with pytest.raises( + AirflowException, match="Batch sensor failed. Unknown AWS Batch job status: INVALID" + ): self.batch_sensor.poke({}) - self.assertEqual("Batch sensor failed. Unknown AWS Batch job status: INVALID", str(e.exception)) mock_get_job_description.assert_called_once_with(JOB_ID) - @parameterized.expand( - [ - ("SUBMITTED",), - ("PENDING",), - ("RUNNABLE",), - ("STARTING",), - ("RUNNING",), - ] - ) + @pytest.mark.parametrize("job_status", ["SUBMITTED", "PENDING", "RUNNABLE", "STARTING", "RUNNING"]) @mock.patch.object(BatchClientHook, "get_job_description") - def test_poke_on_intermediate_state(self, job_status, mock_get_job_description): + def test_poke_on_intermediate_state(self, mock_get_job_description, job_status): + print(job_status) mock_get_job_description.return_value = {"status": job_status} - self.assertFalse(self.batch_sensor.poke({})) + assert self.batch_sensor.poke({}) is False mock_get_job_description.assert_called_once_with(JOB_ID) -class TestBatchComputeEnvironmentSensor(unittest.TestCase): - def setUp(self): +class TestBatchComputeEnvironmentSensor: + def setup_method(self): self.environment_name = "environment_name" self.sensor = BatchComputeEnvironmentSensor( task_id="test_batch_compute_environment_sensor", @@ -104,7 +95,7 @@ class TestBatchComputeEnvironmentSensor(unittest.TestCase): mock_batch_client.describe_compute_environments.return_value = { "computeEnvironments": [{"status": "VALID"}] } - assert self.sensor.poke({}) + assert self.sensor.poke({}) is True mock_batch_client.describe_compute_environments.assert_called_once_with( computeEnvironments=[self.environment_name], ) @@ -118,7 +109,7 @@ class TestBatchComputeEnvironmentSensor(unittest.TestCase): } ] } - assert not self.sensor.poke({}) + assert self.sensor.poke({}) is False mock_batch_client.describe_compute_environments.assert_called_once_with( computeEnvironments=[self.environment_name], ) @@ -140,8 +131,8 @@ class TestBatchComputeEnvironmentSensor(unittest.TestCase): assert "AWS Batch compute environment failed" in str(ctx.value) -class TestBatchJobQueueSensor(unittest.TestCase): - def setUp(self): +class TestBatchJobQueueSensor: + def setup_method(self): self.job_queue = "job_queue" self.sensor = BatchJobQueueSensor( task_id="test_batch_job_queue_sensor", @@ -162,7 +153,7 @@ class TestBatchJobQueueSensor(unittest.TestCase): def test_poke_no_queue_with_treat_non_existing_as_deleted(self, mock_batch_client): self.sensor.treat_non_existing_as_deleted = True mock_batch_client.describe_job_queues.return_value = {"jobQueues": []} - assert self.sensor.poke({}) + assert self.sensor.poke({}) is True mock_batch_client.describe_job_queues.assert_called_once_with( jobQueues=[self.job_queue], ) @@ -170,7 +161,7 @@ class TestBatchJobQueueSensor(unittest.TestCase): @mock.patch.object(BatchClientHook, "client") def test_poke_valid(self, mock_batch_client): mock_batch_client.describe_job_queues.return_value = {"jobQueues": [{"status": "VALID"}]} - assert self.sensor.poke({}) + assert self.sensor.poke({}) is True mock_batch_client.describe_job_queues.assert_called_once_with( jobQueues=[self.job_queue], ) @@ -184,7 +175,7 @@ class TestBatchJobQueueSensor(unittest.TestCase): } ] } - assert not self.sensor.poke({}) + assert self.sensor.poke({}) is False mock_batch_client.describe_job_queues.assert_called_once_with( jobQueues=[self.job_queue], ) diff --git a/tests/providers/amazon/aws/sensors/test_cloud_formation.py b/tests/providers/amazon/aws/sensors/test_cloud_formation.py index 63a54c0bb2..14610df267 100644 --- a/tests/providers/amazon/aws/sensors/test_cloud_formation.py +++ b/tests/providers/amazon/aws/sensors/test_cloud_formation.py @@ -63,12 +63,10 @@ class TestCloudFormationCreateStackSensor: self.cloudformation_client_mock.describe_stacks.return_value = { "Stacks": [{"StackStatus": "bar"}] } - with pytest.raises(ValueError) as ctx: + with pytest.raises(ValueError, match="Stack foo in bad state: bar"): op = CloudFormationCreateStackSensor(task_id="task", stack_name="foo") op.poke({}) - assert "Stack foo in bad state: bar" == str(ctx.value) - class TestCloudFormationDeleteStackSensor: task_id = "test_cloudformation_cluster_delete_sensor" @@ -105,12 +103,10 @@ class TestCloudFormationDeleteStackSensor: self.cloudformation_client_mock.describe_stacks.return_value = { "Stacks": [{"StackStatus": "bar"}] } - with pytest.raises(ValueError) as ctx: + with pytest.raises(ValueError, match="Stack foo in bad state: bar"): op = CloudFormationDeleteStackSensor(task_id="task", stack_name="foo") op.poke({}) - assert "Stack foo in bad state: bar" == str(ctx.value) - @mock_cloudformation def test_poke_stack_does_not_exist(self): op = CloudFormationDeleteStackSensor(task_id="task", stack_name="foo") diff --git a/tests/providers/amazon/aws/sensors/test_dms_task.py b/tests/providers/amazon/aws/sensors/test_dms_task.py index e5770593fd..810510c80b 100644 --- a/tests/providers/amazon/aws/sensors/test_dms_task.py +++ b/tests/providers/amazon/aws/sensors/test_dms_task.py @@ -16,7 +16,6 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock import pytest @@ -26,8 +25,8 @@ from airflow.providers.amazon.aws.hooks.dms import DmsHook from airflow.providers.amazon.aws.sensors.dms import DmsTaskCompletedSensor -class TestDmsTaskCompletedSensor(unittest.TestCase): - def setUp(self): +class TestDmsTaskCompletedSensor: + def setup_method(self): self.sensor = DmsTaskCompletedSensor( task_id="test_dms_sensor", aws_conn_id="aws_default", diff --git a/tests/providers/amazon/aws/sensors/test_eks.py b/tests/providers/amazon/aws/sensors/test_eks.py index 1e3e93791e..fa5457f889 100644 --- a/tests/providers/amazon/aws/sensors/test_eks.py +++ b/tests/providers/amazon/aws/sensors/test_eks.py @@ -63,7 +63,7 @@ class TestEksClusterStateSensor: @mock.patch.object(EksHook, "get_cluster_state", return_value=ClusterStates.ACTIVE) def test_poke_reached_target_state(self, mock_get_cluster_state, setUp): - assert self.sensor.poke({}) + assert self.sensor.poke({}) is True mock_get_cluster_state.assert_called_once_with(clusterName=CLUSTER_NAME) @mock.patch("airflow.providers.amazon.aws.hooks.eks.EksHook.get_cluster_state") @@ -71,7 +71,7 @@ class TestEksClusterStateSensor: def test_poke_reached_pending_state(self, mock_get_cluster_state, setUp, pending_state): mock_get_cluster_state.return_value = pending_state - assert not self.sensor.poke({}) + assert self.sensor.poke({}) is False mock_get_cluster_state.assert_called_once_with(clusterName=CLUSTER_NAME) @mock.patch("airflow.providers.amazon.aws.hooks.eks.EksHook.get_cluster_state") @@ -104,7 +104,7 @@ class TestEksFargateProfileStateSensor: @mock.patch.object(EksHook, "get_fargate_profile_state", return_value=FargateProfileStates.ACTIVE) def test_poke_reached_target_state(self, mock_get_fargate_profile_state, setUp): - assert self.sensor.poke({}) + assert self.sensor.poke({}) is True mock_get_fargate_profile_state.assert_called_once_with( clusterName=CLUSTER_NAME, fargateProfileName=FARGATE_PROFILE_NAME ) @@ -114,7 +114,7 @@ class TestEksFargateProfileStateSensor: def test_poke_reached_pending_state(self, mock_get_fargate_profile_state, setUp, pending_state): mock_get_fargate_profile_state.return_value = pending_state - assert not self.sensor.poke({}) + assert self.sensor.poke({}) is False mock_get_fargate_profile_state.assert_called_once_with( clusterName=CLUSTER_NAME, fargateProfileName=FARGATE_PROFILE_NAME ) @@ -153,7 +153,7 @@ class TestEksNodegroupStateSensor: @mock.patch.object(EksHook, "get_nodegroup_state", return_value=NodegroupStates.ACTIVE) def test_poke_reached_target_state(self, mock_get_nodegroup_state, setUp): - assert self.sensor.poke({}) + assert self.sensor.poke({}) is True mock_get_nodegroup_state.assert_called_once_with( clusterName=CLUSTER_NAME, nodegroupName=NODEGROUP_NAME ) @@ -163,7 +163,7 @@ class TestEksNodegroupStateSensor: def test_poke_reached_pending_state(self, mock_get_nodegroup_state, setUp, pending_state): mock_get_nodegroup_state.return_value = pending_state - assert not self.sensor.poke({}) + assert self.sensor.poke({}) is False mock_get_nodegroup_state.assert_called_once_with( clusterName=CLUSTER_NAME, nodegroupName=NODEGROUP_NAME ) diff --git a/tests/providers/amazon/aws/sensors/test_emr_base.py b/tests/providers/amazon/aws/sensors/test_emr_base.py index 87650674b7..b0dfd66233 100644 --- a/tests/providers/amazon/aws/sensors/test_emr_base.py +++ b/tests/providers/amazon/aws/sensors/test_emr_base.py @@ -17,8 +17,6 @@ # under the License. from __future__ import annotations -import unittest - import pytest from airflow.exceptions import AirflowException @@ -60,7 +58,7 @@ class EmrBaseSensorSubclass(EmrBaseSensor): return None -class TestEmrBaseSensor(unittest.TestCase): +class TestEmrBaseSensor: def test_poke_returns_true_when_state_is_in_target_states(self): operator = EmrBaseSensorSubclass( task_id="test_task", diff --git a/tests/providers/amazon/aws/sensors/test_emr_containers.py b/tests/providers/amazon/aws/sensors/test_emr_containers.py index b12a69c789..38d7688f66 100644 --- a/tests/providers/amazon/aws/sensors/test_emr_containers.py +++ b/tests/providers/amazon/aws/sensors/test_emr_containers.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock import pytest @@ -27,8 +26,8 @@ from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook from airflow.providers.amazon.aws.sensors.emr import EmrContainerSensor -class TestEmrContainerSensor(unittest.TestCase): - def setUp(self): +class TestEmrContainerSensor: + def setup_method(self): self.sensor = EmrContainerSensor( task_id="test_emrcontainer_sensor", virtual_cluster_id="vzwemreks", diff --git a/tests/providers/amazon/aws/sensors/test_emr_job_flow.py b/tests/providers/amazon/aws/sensors/test_emr_job_flow.py index 79baee716c..87a80d6a01 100644 --- a/tests/providers/amazon/aws/sensors/test_emr_job_flow.py +++ b/tests/providers/amazon/aws/sensors/test_emr_job_flow.py @@ -18,7 +18,7 @@ from __future__ import annotations import datetime -import unittest +from unittest import mock from unittest.mock import MagicMock, patch import pytest @@ -188,8 +188,8 @@ DESCRIBE_CLUSTER_TERMINATED_WITH_ERRORS_RETURN = { } -class TestEmrJobFlowSensor(unittest.TestCase): - def setUp(self): +class TestEmrJobFlowSensor: + def setup_method(self): # Mock out the emr_client (moto has incorrect response) self.mock_emr_client = MagicMock() @@ -216,7 +216,7 @@ class TestEmrJobFlowSensor(unittest.TestCase): assert self.mock_emr_client.describe_cluster.call_count == 3 # make sure it was called with the job_flow_id - calls = [unittest.mock.call(ClusterId="j-8989898989")] + calls = [mock.call(ClusterId="j-8989898989")] self.mock_emr_client.describe_cluster.assert_has_calls(calls) def test_execute_calls_with_the_job_flow_id_until_it_reaches_failed_state_with_exception(self): @@ -262,5 +262,5 @@ class TestEmrJobFlowSensor(unittest.TestCase): assert self.mock_emr_client.describe_cluster.call_count == 3 # make sure it was called with the job_flow_id - calls = [unittest.mock.call(ClusterId="j-8989898989")] + calls = [mock.call(ClusterId="j-8989898989")] self.mock_emr_client.describe_cluster.assert_has_calls(calls) diff --git a/tests/providers/amazon/aws/sensors/test_emr_step.py b/tests/providers/amazon/aws/sensors/test_emr_step.py index 1fb5aab378..d053bda97c 100644 --- a/tests/providers/amazon/aws/sensors/test_emr_step.py +++ b/tests/providers/amazon/aws/sensors/test_emr_step.py @@ -17,8 +17,8 @@ # under the License. from __future__ import annotations -import unittest from datetime import datetime +from unittest import mock from unittest.mock import MagicMock, patch import pytest @@ -142,8 +142,8 @@ DESCRIBE_JOB_STEP_COMPLETED_RETURN = { } -class TestEmrStepSensor(unittest.TestCase): - def setUp(self): +class TestEmrStepSensor: + def setup_method(self): self.emr_client_mock = MagicMock() self.sensor = EmrStepSensor( task_id="test_task", @@ -170,8 +170,8 @@ class TestEmrStepSensor(unittest.TestCase): assert self.emr_client_mock.describe_step.call_count == 2 calls = [ - unittest.mock.call(ClusterId="j-8989898989", StepId="s-VK57YR1Z9Z5N"), - unittest.mock.call(ClusterId="j-8989898989", StepId="s-VK57YR1Z9Z5N"), + mock.call(ClusterId="j-8989898989", StepId="s-VK57YR1Z9Z5N"), + mock.call(ClusterId="j-8989898989", StepId="s-VK57YR1Z9Z5N"), ] self.emr_client_mock.describe_step.assert_has_calls(calls) diff --git a/tests/providers/amazon/aws/sensors/test_glacier.py b/tests/providers/amazon/aws/sensors/test_glacier.py index adac1b358a..20c4156e1b 100644 --- a/tests/providers/amazon/aws/sensors/test_glacier.py +++ b/tests/providers/amazon/aws/sensors/test_glacier.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock import pytest @@ -29,8 +28,8 @@ SUCCEEDED = "Succeeded" IN_PROGRESS = "InProgress" -class TestAmazonGlacierSensor(unittest.TestCase): - def setUp(self): +class TestAmazonGlacierSensor: + def setup_method(self): self.op = GlacierJobOperationSensor( task_id="test_athena_sensor", aws_conn_id="aws_default", @@ -63,7 +62,7 @@ class TestAmazonGlacierSensor(unittest.TestCase): assert "Sensor failed" in str(ctx.value) -class TestSensorJobDescription(unittest.TestCase): +class TestSensorJobDescription: def test_job_status_success(self): assert JobStatus.SUCCEEDED.value == SUCCEEDED diff --git a/tests/providers/amazon/aws/sensors/test_glue.py b/tests/providers/amazon/aws/sensors/test_glue.py index 1b1239c2cf..c8d593eed4 100644 --- a/tests/providers/amazon/aws/sensors/test_glue.py +++ b/tests/providers/amazon/aws/sensors/test_glue.py @@ -16,7 +16,6 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock from unittest.mock import ANY @@ -28,8 +27,8 @@ from airflow.providers.amazon.aws.hooks.glue import GlueJobHook from airflow.providers.amazon.aws.sensors.glue import GlueJobSensor -class TestGlueJobSensor(unittest.TestCase): - def setUp(self): +class TestGlueJobSensor: + def setup_method(self): conf.load_test_config() @mock.patch.object(GlueJobHook, "print_job_logs") @@ -142,7 +141,3 @@ class TestGlueJobSensor(unittest.TestCase): filter_pattern="?ERROR ?Exception", next_token=ANY, ) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/providers/amazon/aws/sensors/test_glue_crawler.py b/tests/providers/amazon/aws/sensors/test_glue_crawler.py index 17a2953f54..6a6ee5ae89 100644 --- a/tests/providers/amazon/aws/sensors/test_glue_crawler.py +++ b/tests/providers/amazon/aws/sensors/test_glue_crawler.py @@ -16,15 +16,14 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock from airflow.providers.amazon.aws.hooks.glue_crawler import GlueCrawlerHook from airflow.providers.amazon.aws.sensors.glue_crawler import GlueCrawlerSensor -class TestGlueCrawlerSensor(unittest.TestCase): - def setUp(self): +class TestGlueCrawlerSensor: + def setup_method(self): self.sensor = GlueCrawlerSensor( task_id="test_glue_crawler_sensor", crawler_name="aws_test_glue_crawler", @@ -36,21 +35,17 @@ class TestGlueCrawlerSensor(unittest.TestCase): @mock.patch.object(GlueCrawlerHook, "get_crawler") def test_poke_success(self, mock_get_crawler): mock_get_crawler.return_value["LastCrawl"]["Status"] = "SUCCEEDED" - self.assertFalse(self.sensor.poke({})) + assert self.sensor.poke({}) is False mock_get_crawler.assert_called_once_with("aws_test_glue_crawler") @mock.patch.object(GlueCrawlerHook, "get_crawler") def test_poke_failed(self, mock_get_crawler): mock_get_crawler.return_value["LastCrawl"]["Status"] = "FAILED" - self.assertFalse(self.sensor.poke({})) + assert self.sensor.poke({}) is False mock_get_crawler.assert_called_once_with("aws_test_glue_crawler") @mock.patch.object(GlueCrawlerHook, "get_crawler") def test_poke_cancelled(self, mock_get_crawler): mock_get_crawler.return_value["LastCrawl"]["Status"] = "CANCELLED" - self.assertFalse(self.sensor.poke({})) + assert self.sensor.poke({}) is False mock_get_crawler.assert_called_once_with("aws_test_glue_crawler") - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/providers/amazon/aws/sensors/test_quicksight.py b/tests/providers/amazon/aws/sensors/test_quicksight.py index 3dbf5f7778..562a986c88 100644 --- a/tests/providers/amazon/aws/sensors/test_quicksight.py +++ b/tests/providers/amazon/aws/sensors/test_quicksight.py @@ -17,21 +17,22 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock +import pytest +from moto import mock_sts +from moto.core import DEFAULT_ACCOUNT_ID + from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.quicksight import QuickSightHook -from airflow.providers.amazon.aws.hooks.sts import StsHook from airflow.providers.amazon.aws.sensors.quicksight import QuickSightSensor -AWS_ACCOUNT_ID = "123456789012" DATA_SET_ID = "DemoDataSet" INGESTION_ID = "DemoDataSet_Ingestion" -class TestQuickSightSensor(unittest.TestCase): - def setUp(self): +class TestQuickSightSensor: + def setup_method(self): self.sensor = QuickSightSensor( task_id="test_quicksight_sensor", aws_conn_id="aws_default", @@ -39,40 +40,32 @@ class TestQuickSightSensor(unittest.TestCase): ingestion_id="DemoDataSet_Ingestion", ) + @mock_sts @mock.patch.object(QuickSightHook, "get_status") - @mock.patch.object(StsHook, "get_conn") - @mock.patch.object(StsHook, "get_account_number") - def test_poke_success(self, mock_get_account_number, sts_conn, mock_get_status): - mock_get_account_number.return_value = AWS_ACCOUNT_ID + def test_poke_success(self, mock_get_status): mock_get_status.return_value = "COMPLETED" - self.assertTrue(self.sensor.poke({})) - mock_get_status.assert_called_once_with(AWS_ACCOUNT_ID, DATA_SET_ID, INGESTION_ID) + assert self.sensor.poke({}) is True + mock_get_status.assert_called_once_with(DEFAULT_ACCOUNT_ID, DATA_SET_ID, INGESTION_ID) + @mock_sts @mock.patch.object(QuickSightHook, "get_status") - @mock.patch.object(StsHook, "get_conn") - @mock.patch.object(StsHook, "get_account_number") - def test_poke_cancelled(self, mock_get_account_number, sts_conn, mock_get_status): - mock_get_account_number.return_value = AWS_ACCOUNT_ID + def test_poke_cancelled(self, mock_get_status): mock_get_status.return_value = "CANCELLED" - with self.assertRaises(AirflowException): + with pytest.raises(AirflowException): self.sensor.poke({}) - mock_get_status.assert_called_once_with(AWS_ACCOUNT_ID, DATA_SET_ID, INGESTION_ID) + mock_get_status.assert_called_once_with(DEFAULT_ACCOUNT_ID, DATA_SET_ID, INGESTION_ID) + @mock_sts @mock.patch.object(QuickSightHook, "get_status") - @mock.patch.object(StsHook, "get_conn") - @mock.patch.object(StsHook, "get_account_number") - def test_poke_failed(self, mock_get_account_number, sts_conn, mock_get_status): - mock_get_account_number.return_value = AWS_ACCOUNT_ID + def test_poke_failed(self, mock_get_status): mock_get_status.return_value = "FAILED" - with self.assertRaises(AirflowException): + with pytest.raises(AirflowException): self.sensor.poke({}) - mock_get_status.assert_called_once_with(AWS_ACCOUNT_ID, DATA_SET_ID, INGESTION_ID) + mock_get_status.assert_called_once_with(DEFAULT_ACCOUNT_ID, DATA_SET_ID, INGESTION_ID) + @mock_sts @mock.patch.object(QuickSightHook, "get_status") - @mock.patch.object(StsHook, "get_conn") - @mock.patch.object(StsHook, "get_account_number") - def test_poke_initialized(self, mock_get_account_number, sts_conn, mock_get_status): - mock_get_account_number.return_value = AWS_ACCOUNT_ID + def test_poke_initialized(self, mock_get_status): mock_get_status.return_value = "INITIALIZED" - self.assertFalse(self.sensor.poke({})) - mock_get_status.assert_called_once_with(AWS_ACCOUNT_ID, DATA_SET_ID, INGESTION_ID) + assert self.sensor.poke({}) is False + mock_get_status.assert_called_once_with(DEFAULT_ACCOUNT_ID, DATA_SET_ID, INGESTION_ID) diff --git a/tests/providers/amazon/aws/sensors/test_s3_key.py b/tests/providers/amazon/aws/sensors/test_s3_key.py index cd3c64da46..8d560e2c82 100644 --- a/tests/providers/amazon/aws/sensors/test_s3_key.py +++ b/tests/providers/amazon/aws/sensors/test_s3_key.py @@ -17,11 +17,9 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock import pytest -from parameterized import parameterized from airflow.exceptions import AirflowException from airflow.models import DAG, DagRun, TaskInstance @@ -30,7 +28,7 @@ from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor from airflow.utils import timezone -class TestS3KeySensor(unittest.TestCase): +class TestS3KeySensor: def test_bucket_name_none_and_bucket_key_as_relative_path(self): """ Test if exception is raised when bucket_name is None @@ -81,14 +79,16 @@ class TestS3KeySensor(unittest.TestCase): with pytest.raises(TypeError): op.poke(None) - @parameterized.expand( + @pytest.mark.parametrize( + "key, bucket, parsed_key, parsed_bucket", [ - ["s3://bucket/key", None, "key", "bucket"], - ["key", "bucket", "key", "bucket"], - ] + pytest.param("s3://bucket/key", None, "key", "bucket", id="key as s3url"), + pytest.param("key", "bucket", "key", "bucket", id="separate bucket and key"), + ], ) @mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook.head_object") - def test_parse_bucket_key(self, key, bucket, parsed_key, parsed_bucket, mock_head_object): + def test_parse_bucket_key(self, mock_head_object, key, bucket, parsed_key, parsed_bucket): + print(key, bucket, parsed_key, parsed_bucket) mock_head_object.return_value = None op = S3KeySensor( diff --git a/tests/providers/amazon/aws/sensors/test_s3_keys_unchanged.py b/tests/providers/amazon/aws/sensors/test_s3_keys_unchanged.py index 0fec724621..251f8d6258 100644 --- a/tests/providers/amazon/aws/sensors/test_s3_keys_unchanged.py +++ b/tests/providers/amazon/aws/sensors/test_s3_keys_unchanged.py @@ -18,11 +18,10 @@ from __future__ import annotations from datetime import datetime -from unittest import TestCase, mock +from unittest import mock import pytest from freezegun import freeze_time -from parameterized import parameterized from airflow.models.dag import DAG, AirflowException from airflow.providers.amazon.aws.sensors.s3 import S3KeysUnchangedSensor @@ -31,8 +30,8 @@ TEST_DAG_ID = "unit_tests_aws_sensor" DEFAULT_DATE = datetime(2015, 1, 1) -class TestS3KeysUnchangedSensor(TestCase): - def setUp(self): +class TestS3KeysUnchangedSensor: + def setup_method(self): self.dag = DAG(f"{TEST_DAG_ID}test_schedule_dag_once", start_date=DEFAULT_DATE, schedule="@once") self.sensor = S3KeysUnchangedSensor( @@ -76,17 +75,28 @@ class TestS3KeysUnchangedSensor(TestCase): with pytest.raises(AirflowException): self.sensor.is_keys_unchanged({"a"}) - @parameterized.expand( + @pytest.mark.parametrize( + "current_objects, expected_returns, inactivity_periods", [ - # Test: resetting inactivity period after key change - (({"a"}, {"a", "b"}, {"a", "b", "c"}), (False, False, False), (0, 0, 0)), - # ..and in case an item was deleted with option `allow_delete=True` - (({"a", "b"}, {"a"}, {"a", "c"}), (False, False, False), (0, 0, 0)), - # Test: passes after inactivity period was exceeded - (({"a"}, {"a"}, {"a"}), (False, False, True), (0, 10, 20)), - # ..and do not pass if empty key is given - ((set(), set(), set()), (False, False, False), (0, 10, 20)), - ] + pytest.param( + ({"a"}, {"a", "b"}, {"a", "b", "c"}), + (False, False, False), + (0, 0, 0), + id="resetting inactivity period after key change", + ), + pytest.param( + ({"a", "b"}, {"a"}, {"a", "c"}), + (False, False, False), + (0, 0, 0), + id="item was deleted with option `allow_delete=True`", + ), + pytest.param( + ({"a"}, {"a"}, {"a"}), (False, False, True), (0, 10, 20), id="inactivity period was exceeded" + ), + pytest.param( + (set(), set(), set()), (False, False, False), (0, 10, 20), id="not pass if empty key is given" + ), + ], ) @freeze_time(DEFAULT_DATE, auto_tick_seconds=10) def test_key_changes(self, current_objects, expected_returns, inactivity_periods): diff --git a/tests/providers/amazon/aws/sensors/test_sagemaker_base.py b/tests/providers/amazon/aws/sensors/test_sagemaker_base.py index 6eaa9c18d9..7cfc5c29c4 100644 --- a/tests/providers/amazon/aws/sensors/test_sagemaker_base.py +++ b/tests/providers/amazon/aws/sensors/test_sagemaker_base.py @@ -17,15 +17,13 @@ # under the License. from __future__ import annotations -import unittest - import pytest from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerBaseSensor -class TestSagemakerBaseSensor(unittest.TestCase): +class TestSagemakerBaseSensor: def test_execute(self): class SageMakerBaseSensorSubclass(SageMakerBaseSensor): def non_terminal_states(self): diff --git a/tests/providers/amazon/aws/sensors/test_sagemaker_endpoint.py b/tests/providers/amazon/aws/sensors/test_sagemaker_endpoint.py index f71183be3a..6f5158c042 100644 --- a/tests/providers/amazon/aws/sensors/test_sagemaker_endpoint.py +++ b/tests/providers/amazon/aws/sensors/test_sagemaker_endpoint.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock import pytest @@ -55,7 +54,7 @@ DESCRIBE_ENDPOINT_UPDATING_RESPONSE = { } -class TestSageMakerEndpointSensor(unittest.TestCase): +class TestSageMakerEndpointSensor: @mock.patch.object(SageMakerHook, "get_conn") @mock.patch.object(SageMakerHook, "describe_endpoint") def test_sensor_with_failure(self, mock_describe, mock_get_conn): diff --git a/tests/providers/amazon/aws/sensors/test_sagemaker_training.py b/tests/providers/amazon/aws/sensors/test_sagemaker_training.py index 3a13384f25..0fc8bb5f52 100644 --- a/tests/providers/amazon/aws/sensors/test_sagemaker_training.py +++ b/tests/providers/amazon/aws/sensors/test_sagemaker_training.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from datetime import datetime from unittest import mock @@ -48,7 +47,7 @@ DESCRIBE_TRAINING_STOPPING_RESPONSE = dict(DESCRIBE_TRAINING_COMPLETED_RESPONSE) DESCRIBE_TRAINING_STOPPING_RESPONSE.update({"TrainingJobStatus": "Stopping"}) -class TestSageMakerTrainingSensor(unittest.TestCase): +class TestSageMakerTrainingSensor: @mock.patch.object(SageMakerHook, "get_conn") @mock.patch.object(SageMakerHook, "__init__") @mock.patch.object(SageMakerHook, "describe_training_job") diff --git a/tests/providers/amazon/aws/sensors/test_sagemaker_transform.py b/tests/providers/amazon/aws/sensors/test_sagemaker_transform.py index c6777165b2..3b4d939e8f 100644 --- a/tests/providers/amazon/aws/sensors/test_sagemaker_transform.py +++ b/tests/providers/amazon/aws/sensors/test_sagemaker_transform.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock import pytest @@ -53,7 +52,7 @@ DESCRIBE_TRANSFORM_STOPPING_RESPONSE = { } -class TestSageMakerTransformSensor(unittest.TestCase): +class TestSageMakerTransformSensor: @mock.patch.object(SageMakerHook, "get_conn") @mock.patch.object(SageMakerHook, "describe_transform_job") def test_sensor_with_failure(self, mock_describe_job, mock_client): diff --git a/tests/providers/amazon/aws/sensors/test_sagemaker_tuning.py b/tests/providers/amazon/aws/sensors/test_sagemaker_tuning.py index d7ff9153e4..b89f1a85d5 100644 --- a/tests/providers/amazon/aws/sensors/test_sagemaker_tuning.py +++ b/tests/providers/amazon/aws/sensors/test_sagemaker_tuning.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock import pytest @@ -56,7 +55,7 @@ DESCRIBE_TUNING_STOPPING_RESPONSE = { } -class TestSageMakerTuningSensor(unittest.TestCase): +class TestSageMakerTuningSensor: @mock.patch.object(SageMakerHook, "get_conn") @mock.patch.object(SageMakerHook, "describe_tuning_job") def test_sensor_with_failure(self, mock_describe_job, mock_client): diff --git a/tests/providers/amazon/aws/sensors/test_sqs.py b/tests/providers/amazon/aws/sensors/test_sqs.py index 7a46cbdbb5..73a3ce8278 100644 --- a/tests/providers/amazon/aws/sensors/test_sqs.py +++ b/tests/providers/amazon/aws/sensors/test_sqs.py @@ -18,7 +18,6 @@ from __future__ import annotations import json -import unittest from unittest import mock import pytest @@ -36,8 +35,8 @@ QUEUE_NAME = "test-queue" QUEUE_URL = f"https://{QUEUE_NAME}" -class TestSqsSensor(unittest.TestCase): - def setUp(self): +class TestSqsSensor: + def setup_method(self): args = {"owner": "airflow", "start_date": DEFAULT_DATE} self.dag = DAG("test_dag_id", default_args=args) diff --git a/tests/providers/amazon/aws/sensors/test_step_function.py b/tests/providers/amazon/aws/sensors/test_step_function.py index c3cf62d6a3..d0452dd0c2 100644 --- a/tests/providers/amazon/aws/sensors/test_step_function.py +++ b/tests/providers/amazon/aws/sensors/test_step_function.py @@ -17,12 +17,10 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock from unittest.mock import MagicMock import pytest -from parameterized import parameterized from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.sensors.step_function import StepFunctionExecutionSensor @@ -36,8 +34,8 @@ AWS_CONN_ID = "aws_non_default" REGION_NAME = "us-west-2" -class TestStepFunctionExecutionSensor(unittest.TestCase): - def setUp(self): +class TestStepFunctionExecutionSensor: + def setup_method(self): self.mock_context = MagicMock() def test_init(self): @@ -50,9 +48,9 @@ class TestStepFunctionExecutionSensor(unittest.TestCase): assert AWS_CONN_ID == sensor.aws_conn_id assert REGION_NAME == sensor.region_name - @parameterized.expand([("FAILED",), ("TIMED_OUT",), ("ABORTED",)]) + @pytest.mark.parametrize("mock_status", ["FAILED", "TIMED_OUT", "ABORTED"]) @mock.patch("airflow.providers.amazon.aws.sensors.step_function.StepFunctionHook") - def test_exceptions(self, mock_status, mock_hook): + def test_exceptions(self, mock_hook, mock_status): hook_response = {"status": mock_status} hook_instance = mock_hook.return_value
