This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a1576d39d2 Fix outdated test name and description in BatchSensor (#33407)
a1576d39d2 is described below
commit a1576d39d2293b51f64987ba928fea9b5b180e51
Author: Wei Lee <[email protected]>
AuthorDate: Tue Aug 15 23:06:08 2023 +0800
Fix outdated test name and description in BatchSensor (#33407)
---
tests/providers/amazon/aws/sensors/test_batch.py | 171 +++++++++++++----------
1 file changed, 95 insertions(+), 76 deletions(-)
diff --git a/tests/providers/amazon/aws/sensors/test_batch.py b/tests/providers/amazon/aws/sensors/test_batch.py
index 74b348381e..353db3b812 100644
--- a/tests/providers/amazon/aws/sensors/test_batch.py
+++ b/tests/providers/amazon/aws/sensors/test_batch.py
@@ -32,78 +32,112 @@ from airflow.providers.amazon.aws.triggers.batch import BatchJobTrigger
TASK_ID = "batch_job_sensor"
JOB_ID = "8222a1c2-b246-4e19-b1b8-0039bb4407c0"
AWS_REGION = "eu-west-1"
+ENVIRONMENT_NAME = "environment_name"
+JOB_QUEUE = "job_queue"
-class TestBatchSensor:
- def setup_method(self):
- self.batch_sensor = BatchSensor(
- task_id="batch_job_sensor",
- job_id=JOB_ID,
- )
+@pytest.fixture(scope="module")
+def batch_sensor() -> BatchSensor:
+ return BatchSensor(
+ task_id="batch_job_sensor",
+ job_id=JOB_ID,
+ )
+
+
+@pytest.fixture(scope="module")
+def deferrable_batch_sensor() -> BatchSensor:
+ return BatchSensor(task_id="task", job_id=JOB_ID, region_name=AWS_REGION, deferrable=True)
+
+class TestBatchSensor:
@mock.patch.object(BatchClientHook, "get_job_description")
- def test_poke_on_success_state(self, mock_get_job_description):
+ def test_poke_on_success_state(self, mock_get_job_description, batch_sensor: BatchSensor):
mock_get_job_description.return_value = {"status": "SUCCEEDED"}
- assert self.batch_sensor.poke({}) is True
+ assert batch_sensor.poke({}) is True
mock_get_job_description.assert_called_once_with(JOB_ID)
@mock.patch.object(BatchClientHook, "get_job_description")
- def test_poke_on_failure_state(self, mock_get_job_description):
+ def test_poke_on_failure_state(self, mock_get_job_description, batch_sensor: BatchSensor):
mock_get_job_description.return_value = {"status": "FAILED"}
with pytest.raises(AirflowException, match="Batch sensor failed. AWS Batch job status: FAILED"):
- self.batch_sensor.poke({})
+ batch_sensor.poke({})
mock_get_job_description.assert_called_once_with(JOB_ID)
@mock.patch.object(BatchClientHook, "get_job_description")
- def test_poke_on_invalid_state(self, mock_get_job_description):
+ def test_poke_on_invalid_state(self, mock_get_job_description, batch_sensor: BatchSensor):
mock_get_job_description.return_value = {"status": "INVALID"}
with pytest.raises(
AirflowException, match="Batch sensor failed. Unknown AWS Batch job status: INVALID"
):
- self.batch_sensor.poke({})
+ batch_sensor.poke({})
mock_get_job_description.assert_called_once_with(JOB_ID)
@pytest.mark.parametrize("job_status", ["SUBMITTED", "PENDING", "RUNNABLE", "STARTING", "RUNNING"])
@mock.patch.object(BatchClientHook, "get_job_description")
- def test_poke_on_intermediate_state(self, mock_get_job_description, job_status):
+ def test_poke_on_intermediate_state(
+ self, mock_get_job_description, job_status, batch_sensor: BatchSensor
+ ):
print(job_status)
mock_get_job_description.return_value = {"status": job_status}
- assert self.batch_sensor.poke({}) is False
+ assert batch_sensor.poke({}) is False
mock_get_job_description.assert_called_once_with(JOB_ID)
+ def test_execute_in_deferrable_mode(self, deferrable_batch_sensor: BatchSensor):
+ """
+ Asserts that a task is deferred and a BatchSensorTrigger will be fired
+ when the BatchSensor is executed in deferrable mode.
+ """
-class TestBatchComputeEnvironmentSensor:
- def setup_method(self):
- self.environment_name = "environment_name"
- self.sensor = BatchComputeEnvironmentSensor(
- task_id="test_batch_compute_environment_sensor",
- compute_environment=self.environment_name,
- )
+ with pytest.raises(TaskDeferred) as exc:
+ deferrable_batch_sensor.execute({})
+ assert isinstance(exc.value.trigger, BatchJobTrigger), "Trigger is not a BatchJobTrigger"
+
+ def test_execute_failure_in_deferrable_mode(self, deferrable_batch_sensor: BatchSensor):
+ """Tests that an AirflowException is raised in case of error event"""
+
+ with pytest.raises(AirflowException):
+ deferrable_batch_sensor.execute_complete(context={}, event={"status": "failure"})
+
+
+@pytest.fixture(scope="module")
+def batch_compute_environment_sensor() -> BatchComputeEnvironmentSensor:
+ return BatchComputeEnvironmentSensor(
+ task_id="test_batch_compute_environment_sensor",
+ compute_environment=ENVIRONMENT_NAME,
+ )
+
+class TestBatchComputeEnvironmentSensor:
@mock.patch.object(BatchClientHook, "client")
- def test_poke_no_environment(self, mock_batch_client):
+ def test_poke_no_environment(
+ self, mock_batch_client, batch_compute_environment_sensor: BatchComputeEnvironmentSensor
+ ):
mock_batch_client.describe_compute_environments.return_value = {"computeEnvironments": []}
with pytest.raises(AirflowException) as ctx:
- self.sensor.poke({})
+ batch_compute_environment_sensor.poke({})
mock_batch_client.describe_compute_environments.assert_called_once_with(
- computeEnvironments=[self.environment_name],
+ computeEnvironments=[ENVIRONMENT_NAME],
)
assert "not found" in str(ctx.value)
@mock.patch.object(BatchClientHook, "client")
- def test_poke_valid(self, mock_batch_client):
+ def test_poke_valid(
+ self, mock_batch_client, batch_compute_environment_sensor: BatchComputeEnvironmentSensor
+ ):
mock_batch_client.describe_compute_environments.return_value = {
"computeEnvironments": [{"status": "VALID"}]
}
- assert self.sensor.poke({}) is True
+ assert batch_compute_environment_sensor.poke({}) is True
mock_batch_client.describe_compute_environments.assert_called_once_with(
- computeEnvironments=[self.environment_name],
+ computeEnvironments=[ENVIRONMENT_NAME],
)
@mock.patch.object(BatchClientHook, "client")
- def test_poke_running(self, mock_batch_client):
+ def test_poke_running(
+ self, mock_batch_client, batch_compute_environment_sensor: BatchComputeEnvironmentSensor
+ ):
mock_batch_client.describe_compute_environments.return_value = {
"computeEnvironments": [
{
@@ -111,13 +145,15 @@ class TestBatchComputeEnvironmentSensor:
}
]
}
- assert self.sensor.poke({}) is False
+ assert batch_compute_environment_sensor.poke({}) is False
mock_batch_client.describe_compute_environments.assert_called_once_with(
- computeEnvironments=[self.environment_name],
+ computeEnvironments=[ENVIRONMENT_NAME],
)
@mock.patch.object(BatchClientHook, "client")
- def test_poke_invalid(self, mock_batch_client):
+ def test_poke_invalid(
+ self, mock_batch_client, batch_compute_environment_sensor: BatchComputeEnvironmentSensor
+ ):
mock_batch_client.describe_compute_environments.return_value = {
"computeEnvironments": [
{
@@ -126,50 +162,53 @@ class TestBatchComputeEnvironmentSensor:
]
}
with pytest.raises(AirflowException) as ctx:
- self.sensor.poke({})
+ batch_compute_environment_sensor.poke({})
mock_batch_client.describe_compute_environments.assert_called_once_with(
- computeEnvironments=[self.environment_name],
+ computeEnvironments=[ENVIRONMENT_NAME],
)
assert "AWS Batch compute environment failed" in str(ctx.value)
-class TestBatchJobQueueSensor:
- def setup_method(self):
- self.job_queue = "job_queue"
- self.sensor = BatchJobQueueSensor(
- task_id="test_batch_job_queue_sensor",
- job_queue=self.job_queue,
- )
+@pytest.fixture(scope="module")
+def batch_job_queue_sensor() -> BatchJobQueueSensor:
+ return BatchJobQueueSensor(
+ task_id="test_batch_job_queue_sensor",
+ job_queue=JOB_QUEUE,
+ )
+
+class TestBatchJobQueueSensor:
@mock.patch.object(BatchClientHook, "client")
- def test_poke_no_queue(self, mock_batch_client):
+ def test_poke_no_queue(self, mock_batch_client, batch_job_queue_sensor: BatchJobQueueSensor):
mock_batch_client.describe_job_queues.return_value = {"jobQueues": []}
with pytest.raises(AirflowException) as ctx:
- self.sensor.poke({})
+ batch_job_queue_sensor.poke({})
mock_batch_client.describe_job_queues.assert_called_once_with(
- jobQueues=[self.job_queue],
+ jobQueues=[JOB_QUEUE],
)
assert "not found" in str(ctx.value)
@mock.patch.object(BatchClientHook, "client")
- def test_poke_no_queue_with_treat_non_existing_as_deleted(self, mock_batch_client):
- self.sensor.treat_non_existing_as_deleted = True
+ def test_poke_no_queue_with_treat_non_existing_as_deleted(
+ self, mock_batch_client, batch_job_queue_sensor: BatchJobQueueSensor
+ ):
+ batch_job_queue_sensor.treat_non_existing_as_deleted = True
mock_batch_client.describe_job_queues.return_value = {"jobQueues": []}
- assert self.sensor.poke({}) is True
+ assert batch_job_queue_sensor.poke({}) is True
mock_batch_client.describe_job_queues.assert_called_once_with(
- jobQueues=[self.job_queue],
+ jobQueues=[JOB_QUEUE],
)
@mock.patch.object(BatchClientHook, "client")
- def test_poke_valid(self, mock_batch_client):
+ def test_poke_valid(self, mock_batch_client, batch_job_queue_sensor: BatchJobQueueSensor):
mock_batch_client.describe_job_queues.return_value = {"jobQueues": [{"status": "VALID"}]}
- assert self.sensor.poke({}) is True
+ assert batch_job_queue_sensor.poke({}) is True
mock_batch_client.describe_job_queues.assert_called_once_with(
- jobQueues=[self.job_queue],
+ jobQueues=[JOB_QUEUE],
)
@mock.patch.object(BatchClientHook, "client")
- def test_poke_running(self, mock_batch_client):
+ def test_poke_running(self, mock_batch_client, batch_job_queue_sensor: BatchJobQueueSensor):
mock_batch_client.describe_job_queues.return_value = {
"jobQueues": [
{
@@ -177,13 +216,13 @@ class TestBatchJobQueueSensor:
}
]
}
- assert self.sensor.poke({}) is False
+ assert batch_job_queue_sensor.poke({}) is False
mock_batch_client.describe_job_queues.assert_called_once_with(
- jobQueues=[self.job_queue],
+ jobQueues=[JOB_QUEUE],
)
@mock.patch.object(BatchClientHook, "client")
- def test_poke_invalid(self, mock_batch_client):
+ def test_poke_invalid(self, mock_batch_client, batch_job_queue_sensor: BatchJobQueueSensor):
mock_batch_client.describe_job_queues.return_value = {
"jobQueues": [
{
@@ -192,28 +231,8 @@ class TestBatchJobQueueSensor:
]
}
with pytest.raises(AirflowException) as ctx:
- self.sensor.poke({})
+ batch_job_queue_sensor.poke({})
mock_batch_client.describe_job_queues.assert_called_once_with(
- jobQueues=[self.job_queue],
+ jobQueues=[JOB_QUEUE],
)
assert "AWS Batch job queue failed" in str(ctx.value)
-
-
-class TestBatchAsyncSensor:
- TASK = BatchSensor(task_id="task", job_id=JOB_ID, region_name=AWS_REGION, deferrable=True)
-
- def test_batch_sensor_async(self):
- """
- Asserts that a task is deferred and a BatchSensorTrigger will be fired
- when the BatchSensorAsync is executed.
- """
-
- with pytest.raises(TaskDeferred) as exc:
- self.TASK.execute({})
- assert isinstance(exc.value.trigger, BatchJobTrigger), "Trigger is not a BatchJobTrigger"
-
- def test_batch_sensor_async_execute_failure(self):
- """Tests that an AirflowException is raised in case of error event"""
-
- with pytest.raises(AirflowException):
- self.TASK.execute_complete(context={}, event={"status": "failure"})