Copilot commented on code in PR #59787:
URL: https://github.com/apache/airflow/pull/59787#discussion_r2645436495
##########
providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/batch.py:
##########
@@ -296,24 +302,77 @@ def execute(self, context: Context) -> None:
)
# Add task to job
self.hook.add_single_task_to_job(job_id=self.batch_job_id, task=task)
- # Wait for tasks to complete
- fail_tasks =
self.hook.wait_for_job_tasks_to_complete(job_id=self.batch_job_id,
timeout=self.timeout)
- # Clean up
- if self.should_delete_job:
- # delete job first
- self.clean_up(job_id=self.batch_job_id)
- if self.should_delete_pool:
- self.clean_up(self.batch_pool_id)
- # raise exception if any task fail
- if fail_tasks:
- raise AirflowException(f"Job fail. The failed task are:
{fail_tasks}")
+ if self.deferrable:
+ # Verify pool and nodes are in terminal state before deferral
+ pool = self.hook.connection.pool.get(self.batch_pool_id)
+ nodes =
list(self.hook.connection.compute_node.list(self.batch_pool_id))
+ if pool.resize_errors:
+ raise AirflowException(f"Pool resize errors:
{pool.resize_errors}")
+ self.log.debug("Deferral pre-check: %d nodes present in pool %s",
len(nodes), self.batch_pool_id)
+
+ self.defer(
+ timeout=self.execution_timeout,
+ trigger=AzureBatchJobTrigger(
+ job_id=self.batch_job_id,
+ azure_batch_conn_id=self.azure_batch_conn_id,
+ timeout=self.timeout,
+ ),
+ method_name="execute_complete",
+ )
+ return
+
+ # Wait for tasks to complete (synchronous path) with guaranteed
cleanup on failure
+ sync_failed = False
+ try:
+ fail_tasks = self.hook.wait_for_job_tasks_to_complete(
+ job_id=self.batch_job_id, timeout=self.timeout
+ )
+ if fail_tasks:
+ sync_failed = True
+ raise AirflowException(f"Job fail. The failed task are:
{fail_tasks}")
+ finally:
+ if sync_failed:
+ # Ensure cleanup runs before exception propagates (historical
behavior)
+ if self.should_delete_job:
+ self.clean_up(job_id=self.batch_job_id)
+ if self.should_delete_pool:
+ self.clean_up(self.batch_pool_id)
+ self._cleanup_done = True
+
+ def execute_complete(self, context: Context, event: dict[str, Any] | None
= None) -> str:
Review Comment:
The return type is annotated as `str`, but `execute_complete()` does not return a
value on every code path: when the status is 'failure', 'timeout', or 'error', the
method raises an exception instead of returning. Either update the return type
annotation to `str | None` or ensure that every code path returns a string value.
##########
providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/batch.py:
##########
@@ -0,0 +1,95 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+from collections.abc import AsyncIterator
+from datetime import timedelta
+from typing import Any
+
+from azure.batch import models as batch_models
+
+from airflow.providers.microsoft.azure.hooks.batch import AzureBatchHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+from airflow.utils import timezone
+
+
+class AzureBatchJobTrigger(BaseTrigger):
+ """
+ Poll Azure Batch for task completion for a given job.
+
+ :param job_id: The Azure Batch job identifier to poll.
+ :param azure_batch_conn_id: Connection id for Azure Batch.
+ :param timeout: Maximum wait time in minutes.
+ :param poll_interval: Seconds to sleep between polls.
+ """
+
+ def __init__(
+ self,
+ job_id: str,
+ azure_batch_conn_id: str = "azure_batch_default",
+ timeout: int = 25,
+ poll_interval: int = 15,
+ ) -> None:
+ super().__init__()
+ self.job_id = job_id
+ self.azure_batch_conn_id = azure_batch_conn_id
+ self.timeout = timeout
+ self.poll_interval = poll_interval
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ """Serialize trigger configuration."""
+ return (
+
"airflow.providers.microsoft.azure.triggers.batch.AzureBatchJobTrigger",
+ {
+ "job_id": self.job_id,
+ "azure_batch_conn_id": self.azure_batch_conn_id,
+ "timeout": self.timeout,
+ "poll_interval": self.poll_interval,
+ },
+ )
+
+ async def run(self) -> AsyncIterator[TriggerEvent]:
+ hook = AzureBatchHook(self.azure_batch_conn_id)
+ timeout_time = timezone.utcnow() + timedelta(minutes=self.timeout)
+
+ try:
+ while timezone.utcnow() < timeout_time:
+ tasks = await asyncio.to_thread(lambda:
list(hook.connection.task.list(self.job_id)))
Review Comment:
The lambda wrapper is unnecessary and adds function-call overhead. Replace it with
`await asyncio.to_thread(hook.connection.task.list, self.job_id)` and convert the
result to a list if needed, or wrap the entire `list()` call: `await
asyncio.to_thread(list, hook.connection.task.list(self.job_id))`.
```suggestion
tasks = await asyncio.to_thread(list,
hook.connection.task.list(self.job_id))
```
##########
providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/batch.py:
##########
@@ -296,24 +302,77 @@ def execute(self, context: Context) -> None:
)
# Add task to job
self.hook.add_single_task_to_job(job_id=self.batch_job_id, task=task)
- # Wait for tasks to complete
- fail_tasks =
self.hook.wait_for_job_tasks_to_complete(job_id=self.batch_job_id,
timeout=self.timeout)
- # Clean up
- if self.should_delete_job:
- # delete job first
- self.clean_up(job_id=self.batch_job_id)
- if self.should_delete_pool:
- self.clean_up(self.batch_pool_id)
- # raise exception if any task fail
- if fail_tasks:
- raise AirflowException(f"Job fail. The failed task are:
{fail_tasks}")
+ if self.deferrable:
+ # Verify pool and nodes are in terminal state before deferral
+ pool = self.hook.connection.pool.get(self.batch_pool_id)
+ nodes =
list(self.hook.connection.compute_node.list(self.batch_pool_id))
+ if pool.resize_errors:
+ raise AirflowException(f"Pool resize errors:
{pool.resize_errors}")
+ self.log.debug("Deferral pre-check: %d nodes present in pool %s",
len(nodes), self.batch_pool_id)
+
+ self.defer(
+ timeout=self.execution_timeout,
+ trigger=AzureBatchJobTrigger(
+ job_id=self.batch_job_id,
+ azure_batch_conn_id=self.azure_batch_conn_id,
+ timeout=self.timeout,
+ ),
+ method_name="execute_complete",
+ )
+ return
+
+ # Wait for tasks to complete (synchronous path) with guaranteed
cleanup on failure
+ sync_failed = False
+ try:
+ fail_tasks = self.hook.wait_for_job_tasks_to_complete(
+ job_id=self.batch_job_id, timeout=self.timeout
+ )
+ if fail_tasks:
+ sync_failed = True
+ raise AirflowException(f"Job fail. The failed task are:
{fail_tasks}")
+ finally:
+ if sync_failed:
+ # Ensure cleanup runs before exception propagates (historical
behavior)
+ if self.should_delete_job:
+ self.clean_up(job_id=self.batch_job_id)
+ if self.should_delete_pool:
+ self.clean_up(self.batch_pool_id)
+ self._cleanup_done = True
+
+ def execute_complete(self, context: Context, event: dict[str, Any] | None
= None) -> str:
+ if not event:
+ raise AirflowException("No event received in trigger callback")
+
+ status = event.get("status")
+ fail_task_ids = event.get("fail_task_ids", [])
+
+ if status == "timeout":
+ raise TimeoutError(event.get("message", "Timed out waiting for
tasks to complete"))
+ if status == "error":
+ raise AirflowException(event.get("message", "Unknown error while
waiting for tasks"))
+ if status == "failure" or fail_task_ids:
+ raise AirflowException(f"Job failed. Failed tasks:
{fail_task_ids}")
+ if status != "success":
+ raise AirflowException(f"Unexpected event status: {event}")
+
+ return self.batch_job_id
def on_kill(self) -> None:
response = self.hook.connection.job.terminate(
job_id=self.batch_job_id, terminate_reason="Job killed by user"
)
self.log.info("Azure Batch job (%s) terminated: %s",
self.batch_job_id, response)
+ def post_execute(self, context: Context, result: Any | None = None) ->
None: # type: ignore[override]
Review Comment:
The `# type: ignore[override]` comment suggests a type signature mismatch
with the parent class. Document why this override is necessary or adjust the
signature to match the parent's expected interface.
```suggestion
def post_execute(self, context: Context, result: Any = None) -> None:
```
##########
providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/batch.py:
##########
@@ -296,24 +302,77 @@ def execute(self, context: Context) -> None:
)
# Add task to job
self.hook.add_single_task_to_job(job_id=self.batch_job_id, task=task)
- # Wait for tasks to complete
- fail_tasks =
self.hook.wait_for_job_tasks_to_complete(job_id=self.batch_job_id,
timeout=self.timeout)
- # Clean up
- if self.should_delete_job:
- # delete job first
- self.clean_up(job_id=self.batch_job_id)
- if self.should_delete_pool:
- self.clean_up(self.batch_pool_id)
- # raise exception if any task fail
- if fail_tasks:
- raise AirflowException(f"Job fail. The failed task are:
{fail_tasks}")
+ if self.deferrable:
+ # Verify pool and nodes are in terminal state before deferral
+ pool = self.hook.connection.pool.get(self.batch_pool_id)
+ nodes =
list(self.hook.connection.compute_node.list(self.batch_pool_id))
+ if pool.resize_errors:
+ raise AirflowException(f"Pool resize errors:
{pool.resize_errors}")
+ self.log.debug("Deferral pre-check: %d nodes present in pool %s",
len(nodes), self.batch_pool_id)
+
+ self.defer(
+ timeout=self.execution_timeout,
+ trigger=AzureBatchJobTrigger(
+ job_id=self.batch_job_id,
+ azure_batch_conn_id=self.azure_batch_conn_id,
+ timeout=self.timeout,
+ ),
+ method_name="execute_complete",
+ )
+ return
+
+ # Wait for tasks to complete (synchronous path) with guaranteed
cleanup on failure
+ sync_failed = False
+ try:
+ fail_tasks = self.hook.wait_for_job_tasks_to_complete(
+ job_id=self.batch_job_id, timeout=self.timeout
+ )
+ if fail_tasks:
+ sync_failed = True
+ raise AirflowException(f"Job fail. The failed task are:
{fail_tasks}")
+ finally:
+ if sync_failed:
+ # Ensure cleanup runs before exception propagates (historical
behavior)
+ if self.should_delete_job:
+ self.clean_up(job_id=self.batch_job_id)
+ if self.should_delete_pool:
+ self.clean_up(self.batch_pool_id)
+ self._cleanup_done = True
Review Comment:
The `sync_failed` flag is redundant. The cleanup in the `finally` block runs only
when `sync_failed=True`, and the flag is set immediately before the exception is
raised. Consider restructuring to perform the cleanup in an `except` block, or
simplifying the control flow, since the flag adds no value over ordinary
exception handling.
```suggestion
fail_tasks = self.hook.wait_for_job_tasks_to_complete(
job_id=self.batch_job_id, timeout=self.timeout
)
if fail_tasks:
# Ensure cleanup runs before exception propagates (historical
behavior)
if self.should_delete_job:
self.clean_up(job_id=self.batch_job_id)
if self.should_delete_pool:
self.clean_up(self.batch_pool_id)
self._cleanup_done = True
raise AirflowException(f"Job fail. The failed task are:
{fail_tasks}")
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]