This is an automated email from the ASF dual-hosted git repository.
phanikumv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 88c9596f4a check job_status before BatchOperator execute in deferrable mode (#36523)
88c9596f4a is described below
commit 88c9596f4aaff492dda8b0b87fa60ee16444e9b6
Author: Wei Lee <[email protected]>
AuthorDate: Wed Jan 10 20:38:28 2024 +0800
check job_status before BatchOperator execute in deferrable mode (#36523)
---
airflow/providers/amazon/aws/operators/batch.py | 48 +++++++++++------
tests/providers/amazon/aws/operators/test_batch.py | 63 +++++++++++++++++++++-
2 files changed, 94 insertions(+), 17 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/batch.py
b/airflow/providers/amazon/aws/operators/batch.py
index fe6f9dadb6..8a124b4027 100644
--- a/airflow/providers/amazon/aws/operators/batch.py
+++ b/airflow/providers/amazon/aws/operators/batch.py
@@ -230,7 +230,7 @@ class BatchOperator(BaseOperator):
region_name=self.region_name,
)
- def execute(self, context: Context):
+ def execute(self, context: Context) -> str | None:
"""Submit and monitor an AWS Batch job.
:raises: AirflowException
@@ -238,28 +238,46 @@ class BatchOperator(BaseOperator):
self.submit_job(context)
if self.deferrable:
- self.defer(
- timeout=self.execution_timeout,
- trigger=BatchJobTrigger(
- job_id=self.job_id,
- waiter_max_attempts=self.max_retries,
- aws_conn_id=self.aws_conn_id,
- region_name=self.region_name,
- waiter_delay=self.poll_interval,
- ),
- method_name="execute_complete",
- )
+ if not self.job_id:
+ raise AirflowException("AWS Batch job - job_id was not found")
+
+ job = self.hook.get_job_description(self.job_id)
+ job_status = job.get("status")
+ if job_status == self.hook.SUCCESS_STATE:
+ self.log.info("Job completed.")
+ return self.job_id
+ elif job_status == self.hook.FAILURE_STATE:
+ raise AirflowException(f"Error while running job:
{self.job_id} is in {job_status} state")
+ elif job_status in self.hook.INTERMEDIATE_STATES:
+ self.defer(
+ timeout=self.execution_timeout,
+ trigger=BatchJobTrigger(
+ job_id=self.job_id,
+ waiter_max_attempts=self.max_retries,
+ aws_conn_id=self.aws_conn_id,
+ region_name=self.region_name,
+ waiter_delay=self.poll_interval,
+ ),
+ method_name="execute_complete",
+ )
+
+ raise AirflowException(f"Unexpected status: {job_status}")
if self.wait_for_completion:
self.monitor_job(context)
return self.job_id
- def execute_complete(self, context, event=None):
+ def execute_complete(self, context: Context, event: dict[str, Any] | None
= None) -> str:
+ if event is None:
+ err_msg = "Trigger error: event is None"
+ self.log.info(err_msg)
+ raise AirflowException(err_msg)
+
if event["status"] != "success":
raise AirflowException(f"Error while running job: {event}")
- else:
- self.log.info("Job completed.")
+
+ self.log.info("Job completed.")
return event["job_id"]
def on_kill(self):
diff --git a/tests/providers/amazon/aws/operators/test_batch.py
b/tests/providers/amazon/aws/operators/test_batch.py
index 020f071786..313d721b3a 100644
--- a/tests/providers/amazon/aws/operators/test_batch.py
+++ b/tests/providers/amazon/aws/operators/test_batch.py
@@ -268,8 +268,11 @@ class TestBatchOperator:
container_overrides={"a": "b"},
)
+ @mock.patch.object(BatchClientHook, "get_job_description")
@mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type")
- def test_defer_if_deferrable_param_set(self, mock_client):
+ def test_defer_if_deferrable_param_set(self, mock_client,
mock_get_job_description):
+ mock_get_job_description.return_value = {"status": "SUBMITTED"}
+
batch = BatchOperator(
task_id="task",
job_name=JOB_NAME,
@@ -280,9 +283,65 @@ class TestBatchOperator:
)
with pytest.raises(TaskDeferred) as exc:
- batch.execute(context=None)
+ batch.execute(self.mock_context)
assert isinstance(exc.value.trigger, BatchJobTrigger)
+
@mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type")
+ def test_defer_but_failed_due_to_job_id_not_found(self, mock_client):
+ """Test that an AirflowException is raised if job_id is not set before
deferral."""
+ mock_client.return_value.submit_job.return_value = {
+ "jobName": JOB_NAME,
+ "jobId": None,
+ }
+
+ batch = BatchOperator(
+ task_id="task",
+ job_name=JOB_NAME,
+ job_queue="queue",
+ job_definition="hello-world",
+ do_xcom_push=False,
+ deferrable=True,
+ )
+ with pytest.raises(AirflowException) as exc:
+ batch.execute(self.mock_context)
+ assert "AWS Batch job - job_id was not found" in str(exc.value)
+
+ @mock.patch.object(BatchClientHook, "get_job_description")
+
@mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type")
+ def test_defer_but_success_before_deferred(self, mock_client,
mock_get_job_description):
+ """Test that an AirflowException is raised if job_id is not set before
deferral."""
+ mock_client.return_value.submit_job.return_value =
RESPONSE_WITHOUT_FAILURES
+ mock_get_job_description.return_value = {"status": "SUCCEEDED"}
+
+ batch = BatchOperator(
+ task_id="task",
+ job_name=JOB_NAME,
+ job_queue="queue",
+ job_definition="hello-world",
+ do_xcom_push=False,
+ deferrable=True,
+ )
+ assert batch.execute(self.mock_context) == JOB_ID
+
+ @mock.patch.object(BatchClientHook, "get_job_description")
+
@mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type")
+ def test_defer_but_fail_before_deferred(self, mock_client,
mock_get_job_description):
+ """Test that an AirflowException is raised if job_id is not set before
deferral."""
+ mock_client.return_value.submit_job.return_value =
RESPONSE_WITHOUT_FAILURES
+ mock_get_job_description.return_value = {"status": "FAILED"}
+
+ batch = BatchOperator(
+ task_id="task",
+ job_name=JOB_NAME,
+ job_queue="queue",
+ job_definition="hello-world",
+ do_xcom_push=False,
+ deferrable=True,
+ )
+ with pytest.raises(AirflowException) as exc:
+ batch.execute(self.mock_context)
+ assert f"Error while running job: {JOB_ID} is in FAILED state" in
str(exc.value)
+
@mock.patch.object(BatchClientHook, "get_job_description")
@mock.patch.object(BatchClientHook, "wait_for_job")
@mock.patch.object(BatchClientHook, "check_job_success")