This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a1576d39d2 Fix outdated test name and description in BatchSensor 
(#33407)
a1576d39d2 is described below

commit a1576d39d2293b51f64987ba928fea9b5b180e51
Author: Wei Lee <[email protected]>
AuthorDate: Tue Aug 15 23:06:08 2023 +0800

    Fix outdated test name and description in BatchSensor (#33407)
---
 tests/providers/amazon/aws/sensors/test_batch.py | 171 +++++++++++++----------
 1 file changed, 95 insertions(+), 76 deletions(-)

diff --git a/tests/providers/amazon/aws/sensors/test_batch.py 
b/tests/providers/amazon/aws/sensors/test_batch.py
index 74b348381e..353db3b812 100644
--- a/tests/providers/amazon/aws/sensors/test_batch.py
+++ b/tests/providers/amazon/aws/sensors/test_batch.py
@@ -32,78 +32,112 @@ from airflow.providers.amazon.aws.triggers.batch import 
BatchJobTrigger
 TASK_ID = "batch_job_sensor"
 JOB_ID = "8222a1c2-b246-4e19-b1b8-0039bb4407c0"
 AWS_REGION = "eu-west-1"
+ENVIRONMENT_NAME = "environment_name"
+JOB_QUEUE = "job_queue"
 
 
-class TestBatchSensor:
-    def setup_method(self):
-        self.batch_sensor = BatchSensor(
-            task_id="batch_job_sensor",
-            job_id=JOB_ID,
-        )
+@pytest.fixture(scope="module")
+def batch_sensor() -> BatchSensor:
+    return BatchSensor(
+        task_id="batch_job_sensor",
+        job_id=JOB_ID,
+    )
+
+
+@pytest.fixture(scope="module")
+def deferrable_batch_sensor() -> BatchSensor:
+    return BatchSensor(task_id="task", job_id=JOB_ID, region_name=AWS_REGION, 
deferrable=True)
+
 
+class TestBatchSensor:
     @mock.patch.object(BatchClientHook, "get_job_description")
-    def test_poke_on_success_state(self, mock_get_job_description):
+    def test_poke_on_success_state(self, mock_get_job_description, 
batch_sensor: BatchSensor):
         mock_get_job_description.return_value = {"status": "SUCCEEDED"}
-        assert self.batch_sensor.poke({}) is True
+        assert batch_sensor.poke({}) is True
         mock_get_job_description.assert_called_once_with(JOB_ID)
 
     @mock.patch.object(BatchClientHook, "get_job_description")
-    def test_poke_on_failure_state(self, mock_get_job_description):
+    def test_poke_on_failure_state(self, mock_get_job_description, 
batch_sensor: BatchSensor):
         mock_get_job_description.return_value = {"status": "FAILED"}
         with pytest.raises(AirflowException, match="Batch sensor failed. AWS 
Batch job status: FAILED"):
-            self.batch_sensor.poke({})
+            batch_sensor.poke({})
 
         mock_get_job_description.assert_called_once_with(JOB_ID)
 
     @mock.patch.object(BatchClientHook, "get_job_description")
-    def test_poke_on_invalid_state(self, mock_get_job_description):
+    def test_poke_on_invalid_state(self, mock_get_job_description, 
batch_sensor: BatchSensor):
         mock_get_job_description.return_value = {"status": "INVALID"}
         with pytest.raises(
             AirflowException, match="Batch sensor failed. Unknown AWS Batch 
job status: INVALID"
         ):
-            self.batch_sensor.poke({})
+            batch_sensor.poke({})
 
         mock_get_job_description.assert_called_once_with(JOB_ID)
 
     @pytest.mark.parametrize("job_status", ["SUBMITTED", "PENDING", 
"RUNNABLE", "STARTING", "RUNNING"])
     @mock.patch.object(BatchClientHook, "get_job_description")
-    def test_poke_on_intermediate_state(self, mock_get_job_description, 
job_status):
+    def test_poke_on_intermediate_state(
+        self, mock_get_job_description, job_status, batch_sensor: BatchSensor
+    ):
         print(job_status)
         mock_get_job_description.return_value = {"status": job_status}
-        assert self.batch_sensor.poke({}) is False
+        assert batch_sensor.poke({}) is False
         mock_get_job_description.assert_called_once_with(JOB_ID)
 
+    def test_execute_in_deferrable_mode(self, deferrable_batch_sensor: 
BatchSensor):
+        """
+        Asserts that a task is deferred and a BatchSensorTrigger will be fired
+        when the BatchSensor is executed in deferrable mode.
+        """
 
-class TestBatchComputeEnvironmentSensor:
-    def setup_method(self):
-        self.environment_name = "environment_name"
-        self.sensor = BatchComputeEnvironmentSensor(
-            task_id="test_batch_compute_environment_sensor",
-            compute_environment=self.environment_name,
-        )
+        with pytest.raises(TaskDeferred) as exc:
+            deferrable_batch_sensor.execute({})
+        assert isinstance(exc.value.trigger, BatchJobTrigger), "Trigger is not 
a BatchJobTrigger"
+
+    def test_execute_failure_in_deferrable_mode(self, deferrable_batch_sensor: 
BatchSensor):
+        """Tests that an AirflowException is raised in case of error event"""
+
+        with pytest.raises(AirflowException):
+            deferrable_batch_sensor.execute_complete(context={}, 
event={"status": "failure"})
+
+
+@pytest.fixture(scope="module")
+def batch_compute_environment_sensor() -> BatchComputeEnvironmentSensor:
+    return BatchComputeEnvironmentSensor(
+        task_id="test_batch_compute_environment_sensor",
+        compute_environment=ENVIRONMENT_NAME,
+    )
 
+
+class TestBatchComputeEnvironmentSensor:
     @mock.patch.object(BatchClientHook, "client")
-    def test_poke_no_environment(self, mock_batch_client):
+    def test_poke_no_environment(
+        self, mock_batch_client, batch_compute_environment_sensor: 
BatchComputeEnvironmentSensor
+    ):
         mock_batch_client.describe_compute_environments.return_value = 
{"computeEnvironments": []}
         with pytest.raises(AirflowException) as ctx:
-            self.sensor.poke({})
+            batch_compute_environment_sensor.poke({})
         
mock_batch_client.describe_compute_environments.assert_called_once_with(
-            computeEnvironments=[self.environment_name],
+            computeEnvironments=[ENVIRONMENT_NAME],
         )
         assert "not found" in str(ctx.value)
 
     @mock.patch.object(BatchClientHook, "client")
-    def test_poke_valid(self, mock_batch_client):
+    def test_poke_valid(
+        self, mock_batch_client, batch_compute_environment_sensor: 
BatchComputeEnvironmentSensor
+    ):
         mock_batch_client.describe_compute_environments.return_value = {
             "computeEnvironments": [{"status": "VALID"}]
         }
-        assert self.sensor.poke({}) is True
+        assert batch_compute_environment_sensor.poke({}) is True
         
mock_batch_client.describe_compute_environments.assert_called_once_with(
-            computeEnvironments=[self.environment_name],
+            computeEnvironments=[ENVIRONMENT_NAME],
         )
 
     @mock.patch.object(BatchClientHook, "client")
-    def test_poke_running(self, mock_batch_client):
+    def test_poke_running(
+        self, mock_batch_client, batch_compute_environment_sensor: 
BatchComputeEnvironmentSensor
+    ):
         mock_batch_client.describe_compute_environments.return_value = {
             "computeEnvironments": [
                 {
@@ -111,13 +145,15 @@ class TestBatchComputeEnvironmentSensor:
                 }
             ]
         }
-        assert self.sensor.poke({}) is False
+        assert batch_compute_environment_sensor.poke({}) is False
         
mock_batch_client.describe_compute_environments.assert_called_once_with(
-            computeEnvironments=[self.environment_name],
+            computeEnvironments=[ENVIRONMENT_NAME],
         )
 
     @mock.patch.object(BatchClientHook, "client")
-    def test_poke_invalid(self, mock_batch_client):
+    def test_poke_invalid(
+        self, mock_batch_client, batch_compute_environment_sensor: 
BatchComputeEnvironmentSensor
+    ):
         mock_batch_client.describe_compute_environments.return_value = {
             "computeEnvironments": [
                 {
@@ -126,50 +162,53 @@ class TestBatchComputeEnvironmentSensor:
             ]
         }
         with pytest.raises(AirflowException) as ctx:
-            self.sensor.poke({})
+            batch_compute_environment_sensor.poke({})
         
mock_batch_client.describe_compute_environments.assert_called_once_with(
-            computeEnvironments=[self.environment_name],
+            computeEnvironments=[ENVIRONMENT_NAME],
         )
         assert "AWS Batch compute environment failed" in str(ctx.value)
 
 
-class TestBatchJobQueueSensor:
-    def setup_method(self):
-        self.job_queue = "job_queue"
-        self.sensor = BatchJobQueueSensor(
-            task_id="test_batch_job_queue_sensor",
-            job_queue=self.job_queue,
-        )
+@pytest.fixture(scope="module")
+def batch_job_queue_sensor() -> BatchJobQueueSensor:
+    return BatchJobQueueSensor(
+        task_id="test_batch_job_queue_sensor",
+        job_queue=JOB_QUEUE,
+    )
+
 
+class TestBatchJobQueueSensor:
     @mock.patch.object(BatchClientHook, "client")
-    def test_poke_no_queue(self, mock_batch_client):
+    def test_poke_no_queue(self, mock_batch_client, batch_job_queue_sensor: 
BatchJobQueueSensor):
         mock_batch_client.describe_job_queues.return_value = {"jobQueues": []}
         with pytest.raises(AirflowException) as ctx:
-            self.sensor.poke({})
+            batch_job_queue_sensor.poke({})
         mock_batch_client.describe_job_queues.assert_called_once_with(
-            jobQueues=[self.job_queue],
+            jobQueues=[JOB_QUEUE],
         )
         assert "not found" in str(ctx.value)
 
     @mock.patch.object(BatchClientHook, "client")
-    def test_poke_no_queue_with_treat_non_existing_as_deleted(self, 
mock_batch_client):
-        self.sensor.treat_non_existing_as_deleted = True
+    def test_poke_no_queue_with_treat_non_existing_as_deleted(
+        self, mock_batch_client, batch_job_queue_sensor: BatchJobQueueSensor
+    ):
+        batch_job_queue_sensor.treat_non_existing_as_deleted = True
         mock_batch_client.describe_job_queues.return_value = {"jobQueues": []}
-        assert self.sensor.poke({}) is True
+        assert batch_job_queue_sensor.poke({}) is True
         mock_batch_client.describe_job_queues.assert_called_once_with(
-            jobQueues=[self.job_queue],
+            jobQueues=[JOB_QUEUE],
         )
 
     @mock.patch.object(BatchClientHook, "client")
-    def test_poke_valid(self, mock_batch_client):
+    def test_poke_valid(self, mock_batch_client, batch_job_queue_sensor: 
BatchJobQueueSensor):
         mock_batch_client.describe_job_queues.return_value = {"jobQueues": 
[{"status": "VALID"}]}
-        assert self.sensor.poke({}) is True
+        assert batch_job_queue_sensor.poke({}) is True
         mock_batch_client.describe_job_queues.assert_called_once_with(
-            jobQueues=[self.job_queue],
+            jobQueues=[JOB_QUEUE],
         )
 
     @mock.patch.object(BatchClientHook, "client")
-    def test_poke_running(self, mock_batch_client):
+    def test_poke_running(self, mock_batch_client, batch_job_queue_sensor: 
BatchJobQueueSensor):
         mock_batch_client.describe_job_queues.return_value = {
             "jobQueues": [
                 {
@@ -177,13 +216,13 @@ class TestBatchJobQueueSensor:
                 }
             ]
         }
-        assert self.sensor.poke({}) is False
+        assert batch_job_queue_sensor.poke({}) is False
         mock_batch_client.describe_job_queues.assert_called_once_with(
-            jobQueues=[self.job_queue],
+            jobQueues=[JOB_QUEUE],
         )
 
     @mock.patch.object(BatchClientHook, "client")
-    def test_poke_invalid(self, mock_batch_client):
+    def test_poke_invalid(self, mock_batch_client, batch_job_queue_sensor: 
BatchJobQueueSensor):
         mock_batch_client.describe_job_queues.return_value = {
             "jobQueues": [
                 {
@@ -192,28 +231,8 @@ class TestBatchJobQueueSensor:
             ]
         }
         with pytest.raises(AirflowException) as ctx:
-            self.sensor.poke({})
+            batch_job_queue_sensor.poke({})
         mock_batch_client.describe_job_queues.assert_called_once_with(
-            jobQueues=[self.job_queue],
+            jobQueues=[JOB_QUEUE],
         )
         assert "AWS Batch job queue failed" in str(ctx.value)
-
-
-class TestBatchAsyncSensor:
-    TASK = BatchSensor(task_id="task", job_id=JOB_ID, region_name=AWS_REGION, 
deferrable=True)
-
-    def test_batch_sensor_async(self):
-        """
-        Asserts that a task is deferred and a BatchSensorTrigger will be fired
-        when the BatchSensorAsync is executed.
-        """
-
-        with pytest.raises(TaskDeferred) as exc:
-            self.TASK.execute({})
-        assert isinstance(exc.value.trigger, BatchJobTrigger), "Trigger is not 
a BatchJobTrigger"
-
-    def test_batch_sensor_async_execute_failure(self):
-        """Tests that an AirflowException is raised in case of error event"""
-
-        with pytest.raises(AirflowException):
-            self.TASK.execute_complete(context={}, event={"status": "failure"})

Reply via email to