Ankurdeewan commented on code in PR #59798:
URL: https://github.com/apache/airflow/pull/59798#discussion_r2700714240
##########
providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/batch.py:
##########
@@ -296,24 +313,67 @@ def execute(self, context: Context) -> None:
         )
         # Add task to job
         self.hook.add_single_task_to_job(job_id=self.batch_job_id, task=task)
-        # Wait for tasks to complete
-        fail_tasks = self.hook.wait_for_job_tasks_to_complete(job_id=self.batch_job_id, timeout=self.timeout)
-        # Clean up
-        if self.should_delete_job:
-            # delete job first
-            self.clean_up(job_id=self.batch_job_id)
-        if self.should_delete_pool:
-            self.clean_up(self.batch_pool_id)
-        # raise exception if any task fail
+        if self.deferrable:
+            # Verify pool and nodes are in terminal state before deferral
+            pool = self.hook.connection.pool.get(self.batch_pool_id)
+            nodes = list(self.hook.connection.compute_node.list(self.batch_pool_id))
+            if pool.resize_errors:
+                raise AirflowException(f"Pool resize errors: {pool.resize_errors}")
+            self.log.debug("Deferral pre-check: %d nodes present in pool %s", len(nodes), self.batch_pool_id)
+
+            self.defer(
+                timeout=self.execution_timeout,
+                trigger=AzureBatchJobTrigger(
+                    job_id=self.batch_job_id,
+                    azure_batch_conn_id=self.azure_batch_conn_id,
+                    timeout=self.timeout,
+                    poll_interval=self.poll_interval,
+                ),
+                method_name="execute_complete",
+            )
+            return
+
+        # Wait for tasks to complete (synchronous path)
+        fail_tasks = self.hook.wait_for_job_tasks_to_complete(
+            job_id=self.batch_job_id, timeout=self.timeout
+        )
         if fail_tasks:
             raise AirflowException(f"Job fail. The failed task are: {fail_tasks}")
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        if not event:
+            raise AirflowException("No event received in trigger callback")
+
+        status = event.get("status")
+        fail_task_ids = event.get("fail_task_ids", [])
+
+        if status == "timeout":
+            raise AirflowException(event.get("message", "Timed out waiting for tasks to complete"))
+        if status == "error":
+            raise AirflowException(event.get("message", "Unknown error while waiting for tasks"))
+        if status == "failure" or fail_task_ids:
+            raise AirflowException(f"Job failed. Failed tasks: {fail_task_ids}")
+        if status != "success":
+            raise AirflowException(f"Unexpected event status: {event}")
+
+        return self.batch_job_id
+
     def on_kill(self) -> None:
         response = self.hook.connection.job.terminate(
             job_id=self.batch_job_id, terminate_reason="Job killed by user"
         )
         self.log.info("Azure Batch job (%s) terminated: %s", self.batch_job_id, response)
+    def post_execute(self, context: Context, result: Any | None = None) -> None:  # type: ignore[override]
+        """Perform cleanup after task completion in both deferrable and non-deferrable modes."""
+        if getattr(self, "_cleanup_done", False):
Review Comment:
Yeah, this is mainly because _cleanup_done might not exist on some paths, especially after deferral or when the task gets deserialized on resume. Using getattr(..., False) just keeps it safe and still guarantees cleanup only runs once.
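
To make that concrete, here is a minimal, runnable sketch of the guard pattern. It is not the PR's actual implementation: the hunk above is cut off right after the getattr check, so the toy class, its attribute values, and the cleanup body below are assumptions for illustration only.

```python
from __future__ import annotations


class FakeBatchOperator:
    """Toy stand-in for the operator; only the _cleanup_done guard matters here."""

    should_delete_job = True
    should_delete_pool = False
    batch_job_id = "batch-job-1"
    batch_pool_id = "batch-pool-1"

    def clean_up(self, pool_id: str | None = None, job_id: str | None = None) -> None:
        print(f"cleaning up pool={pool_id!r} job={job_id!r}")

    def post_execute(self, context: dict | None = None, result: object | None = None) -> None:
        # The attribute may simply not exist if the operator was rebuilt from its
        # serialized form after deferral, so treat "absent" the same as "not done yet".
        if getattr(self, "_cleanup_done", False):
            return
        if self.should_delete_job:
            self.clean_up(job_id=self.batch_job_id)
        if self.should_delete_pool:
            self.clean_up(self.batch_pool_id)
        # Flag it so any later call becomes a no-op.
        self._cleanup_done = True


op = FakeBatchOperator()   # simulates an instance resumed after deferral
op.post_execute()          # _cleanup_done doesn't exist yet -> cleanup runs
op.post_execute()          # guard short-circuits, nothing runs twice
```

A hasattr check would behave the same way; getattr(..., False) just folds "attribute missing" and "cleanup not done" into a single condition.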