dabla commented on code in PR #59798:
URL: https://github.com/apache/airflow/pull/59798#discussion_r2664744862


##########
providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_batch.py:
##########
@@ -0,0 +1,267 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import json
+from unittest import mock
+
+import pytest
+from azure.batch import models as batch_models
+
+from airflow.providers.microsoft.azure.triggers.batch import 
AzureBatchJobTrigger
+from airflow.triggers.base import TriggerEvent
+
+
+def test_azure_batch_job_trigger_serialize():

Review Comment:
   Maybe put all tests under a class named `TestAzureBatchJobTrigger`.



##########
providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/batch.py:
##########
@@ -296,24 +313,67 @@ def execute(self, context: Context) -> None:
         )
         # Add task to job
         self.hook.add_single_task_to_job(job_id=self.batch_job_id, task=task)
-        # Wait for tasks to complete
-        fail_tasks = 
self.hook.wait_for_job_tasks_to_complete(job_id=self.batch_job_id, 
timeout=self.timeout)
-        # Clean up
-        if self.should_delete_job:
-            # delete job first
-            self.clean_up(job_id=self.batch_job_id)
-        if self.should_delete_pool:
-            self.clean_up(self.batch_pool_id)
-        # raise exception if any task fail
+        if self.deferrable:
+            # Verify pool and nodes are in terminal state before deferral
+            pool = self.hook.connection.pool.get(self.batch_pool_id)
+            nodes = 
list(self.hook.connection.compute_node.list(self.batch_pool_id))
+            if pool.resize_errors:
+                raise AirflowException(f"Pool resize errors: 
{pool.resize_errors}")
+            self.log.debug("Deferral pre-check: %d nodes present in pool %s", 
len(nodes), self.batch_pool_id)
+
+            self.defer(
+                timeout=self.execution_timeout,
+                trigger=AzureBatchJobTrigger(
+                    job_id=self.batch_job_id,
+                    azure_batch_conn_id=self.azure_batch_conn_id,
+                    timeout=self.timeout,
+                    poll_interval=self.poll_interval,
+                ),
+                method_name="execute_complete",
+            )
+            return
+
+        # Wait for tasks to complete (synchronous path)
+        fail_tasks = self.hook.wait_for_job_tasks_to_complete(
+            job_id=self.batch_job_id, timeout=self.timeout
+        )
         if fail_tasks:
             raise AirflowException(f"Job fail. The failed task are: 
{fail_tasks}")
 
+    def execute_complete(self, context: Context, event: dict[str, Any] | None 
= None) -> str:
+        if not event:
+            raise AirflowException("No event received in trigger callback")
+
+        status = event.get("status")
+        fail_task_ids = event.get("fail_task_ids", [])
+
+        if status == "timeout":
+            raise AirflowException(event.get("message", "Timed out waiting for 
tasks to complete"))
+        if status == "error":
+            raise AirflowException(event.get("message", "Unknown error while 
waiting for tasks"))
+        if status == "failure" or fail_task_ids:
+            raise AirflowException(f"Job failed. Failed tasks: 
{fail_task_ids}")
+        if status != "success":
+            raise AirflowException(f"Unexpected event status: {event}")
+
+        return self.batch_job_id
+
     def on_kill(self) -> None:
         response = self.hook.connection.job.terminate(
             job_id=self.batch_job_id, terminate_reason="Job killed by user"
         )
         self.log.info("Azure Batch job (%s) terminated: %s", 
self.batch_job_id, response)
 
+    def post_execute(self, context: Context, result: Any | None = None) -> 
None:  # type: ignore[override]
+        """Perform cleanup after task completion in both deferrable and 
non-deferrable modes."""
+        if getattr(self, "_cleanup_done", False):

Review Comment:
   Why do we need `getattr` here for `_cleanup_done`? If it is always initialized in `__init__`, a direct attribute access would be clearer.



##########
providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/batch.py:
##########
@@ -296,24 +313,67 @@ def execute(self, context: Context) -> None:
         )
         # Add task to job
         self.hook.add_single_task_to_job(job_id=self.batch_job_id, task=task)
-        # Wait for tasks to complete
-        fail_tasks = 
self.hook.wait_for_job_tasks_to_complete(job_id=self.batch_job_id, 
timeout=self.timeout)
-        # Clean up
-        if self.should_delete_job:
-            # delete job first
-            self.clean_up(job_id=self.batch_job_id)
-        if self.should_delete_pool:
-            self.clean_up(self.batch_pool_id)
-        # raise exception if any task fail
+        if self.deferrable:
+            # Verify pool and nodes are in terminal state before deferral
+            pool = self.hook.connection.pool.get(self.batch_pool_id)
+            nodes = 
list(self.hook.connection.compute_node.list(self.batch_pool_id))
+            if pool.resize_errors:

Review Comment:
   This check could be done before computing the node list, no? That would avoid an unnecessary API call when the pool has resize errors.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to