This is an automated email from the ASF dual-hosted git repository.

jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 26c6dc2d907 Add deferrable mode support to AzureBatchOperator (#66815)
26c6dc2d907 is described below

commit 26c6dc2d907b56cbbca45d9a871d30fd67bd0e35
Author: SameerMesiah97 <[email protected]>
AuthorDate: Sun May 31 20:46:58 2026 +0100

    Add deferrable mode support to AzureBatchOperator (#66815)
    
    * Add deferrable mode support to AzureBatchOperator by introducing
    AzureBatchTrigger and delegating Azure Batch job monitoring to the
    trigger when deferrable=True.
    
    Implement execute_complete handling for success, timeout, failure,
    and empty-task trigger events, with cleanup support for terminal
    states.
    
    Add unit tests covering trigger serialization, task state handling,
    timeouts, mixed and empty task lists, exception handling, and
    execute_complete event processing.
    
    * Adjust docs and end_time computation to reflect operator intent.
    
    ---------
    
    Co-authored-by: Sameer Mesiah <[email protected]>
---
 providers/microsoft/azure/docs/operators/batch.rst |   9 +
 providers/microsoft/azure/provider.yaml            |   3 +
 .../providers/microsoft/azure/get_provider_info.py |   4 +
 .../providers/microsoft/azure/operators/batch.py   |  75 +++++-
 .../providers/microsoft/azure/triggers/batch.py    | 176 +++++++++++++
 .../azure/example_azure_batch_operator.py          |  19 ++
 .../unit/microsoft/azure/operators/test_batch.py   | 183 ++++++++++++-
 .../unit/microsoft/azure/triggers/test_batch.py    | 287 +++++++++++++++++++++
 scripts/ci/prek/known_airflow_exceptions.txt       |   2 +-
 9 files changed, 754 insertions(+), 4 deletions(-)

diff --git a/providers/microsoft/azure/docs/operators/batch.rst 
b/providers/microsoft/azure/docs/operators/batch.rst
index 8cc5cc63006..9b7722d4686 100644
--- a/providers/microsoft/azure/docs/operators/batch.rst
+++ b/providers/microsoft/azure/docs/operators/batch.rst
@@ -32,6 +32,15 @@ Below is an example of using this operator to trigger a task 
on Azure Batch
     :start-after: [START howto_azure_batch_operator]
     :end-before: [END howto_azure_batch_operator]
 
+Below is an example of using this operator to trigger a task on Azure Batch 
with the deferrable flag,
+so that polling for job/task completion occurs on the Airflow Triggerer.
+
+  .. exampleinclude:: 
/../tests/system/microsoft/azure/example_azure_batch_operator.py
+      :language: python
+      :dedent: 4
+      :start-after: [START howto_azure_batch_operator_with_deferrable_flag]
+      :end-before: [END howto_azure_batch_operator_with_deferrable_flag]
+
 
 Reference
 ---------
diff --git a/providers/microsoft/azure/provider.yaml 
b/providers/microsoft/azure/provider.yaml
index a5cbd0bc1e8..801f283bb99 100644
--- a/providers/microsoft/azure/provider.yaml
+++ b/providers/microsoft/azure/provider.yaml
@@ -310,6 +310,9 @@ hooks:
       - airflow.providers.microsoft.azure.hooks.powerbi
 
 triggers:
+  - integration-name: Microsoft Azure Batch
+    python-modules:
+      - airflow.providers.microsoft.azure.triggers.batch
   - integration-name: Microsoft Azure Compute
     python-modules:
       - airflow.providers.microsoft.azure.triggers.compute
diff --git 
a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/get_provider_info.py
 
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/get_provider_info.py
index 0b312ec09a9..40364bf1845 100644
--- 
a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/get_provider_info.py
+++ 
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/get_provider_info.py
@@ -284,6 +284,10 @@ def get_provider_info():
             },
         ],
         "triggers": [
+            {
+                "integration-name": "Microsoft Azure Batch",
+                "python-modules": 
["airflow.providers.microsoft.azure.triggers.batch"],
+            },
             {
                 "integration-name": "Microsoft Azure Compute",
                 "python-modules": 
["airflow.providers.microsoft.azure.triggers.compute"],
diff --git 
a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/batch.py
 
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/batch.py
index e5b36841c66..4e9216ec181 100644
--- 
a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/batch.py
+++ 
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/batch.py
@@ -17,14 +17,16 @@
 # under the License.
 from __future__ import annotations
 
+import time
 from collections.abc import Sequence
 from functools import cached_property
 from typing import TYPE_CHECKING, Any
 
 from azure.batch import models as batch_models
 
-from airflow.providers.common.compat.sdk import AirflowException, BaseOperator
+from airflow.providers.common.compat.sdk import AirflowException, 
BaseOperator, conf
 from airflow.providers.microsoft.azure.hooks.batch import AzureBatchHook
+from airflow.providers.microsoft.azure.triggers.batch import AzureBatchTrigger
 
 if TYPE_CHECKING:
     from airflow.sdk import Context
@@ -91,6 +93,10 @@ class AzureBatchOperator(BaseOperator):
     :param timeout: The amount of time to wait for the job to complete in 
minutes. Default is 25
     :param should_delete_job: Whether to delete job after execution. Default 
is False
     :param should_delete_pool: Whether to delete pool after execution of jobs. 
Default is False
+    :param poll_interval: Polling interval in seconds for deferrable mode. 
Default is 30.
+        Determines how frequently the trigger checks task completion status 
when deferrable=True.
+    :param deferrable: Run operator in deferrable mode.
+
     """
 
     template_fields: Sequence[str] = (
@@ -139,6 +145,8 @@ class AzureBatchOperator(BaseOperator):
         timeout: int = 25,
         should_delete_job: bool = False,
         should_delete_pool: bool = False,
+        poll_interval: int = 30,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -176,6 +184,8 @@ class AzureBatchOperator(BaseOperator):
         self.timeout = timeout
         self.should_delete_job = should_delete_job
         self.should_delete_pool = should_delete_pool
+        self.poll_interval = poll_interval
+        self.deferrable = deferrable
 
     @cached_property
     def hook(self) -> AzureBatchHook:
@@ -265,6 +275,7 @@ class AzureBatchOperator(BaseOperator):
             start_task=self.batch_start_task,
         )
         self.hook.create_pool(pool)
+
         # Wait for nodes to reach complete state
         self.hook.wait_for_all_node_state(
             self.batch_pool_id,
@@ -296,6 +307,29 @@ class AzureBatchOperator(BaseOperator):
         )
         # Add task to job
         self.hook.add_single_task_to_job(job_id=self.batch_job_id, task=task)
+
+        if self.deferrable:
+            # Pre-deferral check (node readiness is already enforced by 
wait_for_all_node_state above)
+            pool = self.hook.connection.pool.get(self.batch_pool_id)
+            if pool.resize_errors:
+                raise RuntimeError(f"Pool resize errors: {pool.resize_errors}")
+
+            nodes = 
list(self.hook.connection.compute_node.list(self.batch_pool_id))
+            self.log.debug("Deferral pre-check: %d nodes present in pool %s", 
len(nodes), self.batch_pool_id)
+            end_time = time.time() + (self.timeout * 60)
+
+            self.defer(
+                timeout=self.execution_timeout,
+                trigger=AzureBatchTrigger(
+                    job_id=self.batch_job_id,
+                    azure_batch_conn_id=self.azure_batch_conn_id,
+                    end_time=end_time,
+                    poll_interval=self.poll_interval,
+                ),
+                method_name="execute_complete",
+            )
+            return
+
         # Wait for tasks to complete
         fail_tasks = 
self.hook.wait_for_job_tasks_to_complete(job_id=self.batch_job_id, 
timeout=self.timeout)
         # Clean up
@@ -306,7 +340,44 @@ class AzureBatchOperator(BaseOperator):
             self.clean_up(self.batch_pool_id)
         # raise exception if any task fail
         if fail_tasks:
-            raise AirflowException(f"Job fail. The failed task are: 
{fail_tasks}")
+            raise RuntimeError(f"Job fail. The failed task are: {fail_tasks}")
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None) 
-> None:
+        """
+        Process the trigger's event payload, which carries the final job status.
+
+        Logs on success and raises RuntimeError for timeout, error or unknown status;
+        the job/pool are deleted afterwards when the should_delete_* flags are set.
+        """
+        if event is None:
+            raise RuntimeError("Trigger returned no event.")
+
+        status = event.get("status")
+        message = event.get("message", "No message returned from trigger.")
+        failed_tasks = event.get("failed_tasks")
+
+        try:
+            if status == "success":
+                self.log.info(message)
+                return
+
+            if status == "timeout":
+                raise RuntimeError(message)
+
+            if status == "error":
+                if failed_tasks:
+                    raise RuntimeError(f"{message} Failed tasks: 
{failed_tasks}")
+
+                raise RuntimeError(message)
+
+            raise RuntimeError(f"Unexpected trigger event received: {event}")
+
+        finally:
+            if self.should_delete_job:
+                self.clean_up(job_id=self.batch_job_id)
+
+            if self.should_delete_pool:
+                self.clean_up(pool_id=self.batch_pool_id)
 
     def on_kill(self) -> None:
         response = self.hook.connection.job.terminate(
diff --git 
a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/batch.py
 
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/batch.py
new file mode 100644
index 00000000000..65bf5eb0217
--- /dev/null
+++ 
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/batch.py
@@ -0,0 +1,176 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import asyncio
+import time
+from collections.abc import AsyncIterator
+from typing import Any
+
+from azure.batch import models as batch_models
+
+from airflow.providers.microsoft.azure.hooks.batch import AzureBatchHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class AzureBatchTrigger(BaseTrigger):
+    """
+    Trigger when Azure Batch job tasks reach a terminal state.
+
+    :param job_id: Azure Batch job identifier.
+    :param azure_batch_conn_id: Azure Batch connection id.
+    :param end_time: Absolute timeout deadline as determined using 
``time.time()``.
+    :param poll_interval: Poll interval in seconds.
+    """
+
+    def __init__(
+        self,
+        job_id: str,
+        azure_batch_conn_id: str,
+        end_time: float,
+        poll_interval: int = 30,
+    ):
+        super().__init__()
+
+        self.job_id = job_id
+        self.azure_batch_conn_id = azure_batch_conn_id
+        self.end_time = end_time
+        self.poll_interval = poll_interval
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serialize trigger arguments and classpath."""
+        return (
+            f"{self.__class__.__module__}.{self.__class__.__name__}",
+            {
+                "job_id": self.job_id,
+                "azure_batch_conn_id": self.azure_batch_conn_id,
+                "end_time": self.end_time,
+                "poll_interval": self.poll_interval,
+            },
+        )
+
+    def _get_incomplete_tasks(
+        self,
+        tasks: list[batch_models.CloudTask],
+    ) -> list[batch_models.CloudTask]:
+        """Return tasks that have not yet completed."""
+        return [task for task in tasks if task.state != 
batch_models.TaskState.completed]
+
+    def _build_trigger_event(
+        self,
+        tasks: list[batch_models.CloudTask],
+    ) -> TriggerEvent | None:
+        """
+        Convert Batch task states to TriggerEvent.
+
+        Returns None if tasks are still running.
+        """
+        if not tasks:
+            return TriggerEvent(
+                {
+                    "status": "error",
+                    "message": f"Azure Batch job {self.job_id} contains no 
tasks.",
+                    "job_id": self.job_id,
+                }
+            )
+
+        if self._get_incomplete_tasks(tasks):
+            return None
+
+        failed_tasks = [
+            task.id
+            for task in tasks
+            if task.execution_info and task.execution_info.result == 
batch_models.TaskExecutionResult.failure
+        ]
+
+        if failed_tasks:
+            return TriggerEvent(
+                {
+                    "status": "error",
+                    "message": f"Azure Batch job {self.job_id} failed.",
+                    "job_id": self.job_id,
+                    "failed_tasks": failed_tasks,
+                }
+            )
+
+        return TriggerEvent(
+            {
+                "status": "success",
+                "message": f"Azure Batch job {self.job_id} completed 
successfully.",
+                "job_id": self.job_id,
+            }
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        """Poll Azure Batch job tasks until completion or timeout."""
+        hook = AzureBatchHook(
+            azure_batch_conn_id=self.azure_batch_conn_id,
+        )
+
+        try:
+            while time.time() <= self.end_time:
+                tasks = await asyncio.to_thread(lambda: 
list(hook.connection.task.list(self.job_id)))
+
+                event = self._build_trigger_event(tasks)
+
+                if event:
+                    yield event
+                    return
+
+                incomplete_tasks = self._get_incomplete_tasks(tasks)
+
+                self.log.info(
+                    "Azure Batch job %s still running. Incomplete tasks: %s. 
Sleeping %s seconds.",
+                    self.job_id,
+                    incomplete_tasks,
+                    self.poll_interval,
+                )
+
+                await asyncio.sleep(self.poll_interval)
+
+            # Final check before timeout event in case job completed
+            # during the last sleep interval.
+            tasks = await asyncio.to_thread(lambda: 
list(hook.connection.task.list(self.job_id)))
+
+            event = self._build_trigger_event(tasks)
+
+            if event:
+                yield event
+                return
+
+            yield TriggerEvent(
+                {
+                    "status": "timeout",
+                    "message": f"Timeout waiting for Azure Batch job 
{self.job_id}.",
+                    "job_id": self.job_id,
+                }
+            )
+
+        except Exception as e:
+            self.log.exception(
+                "Azure Batch trigger failed for job %s.",
+                self.job_id,
+            )
+
+            yield TriggerEvent(
+                {
+                    "status": "error",
+                    "message": str(e),
+                    "job_id": self.job_id,
+                }
+            )
diff --git 
a/providers/microsoft/azure/tests/system/microsoft/azure/example_azure_batch_operator.py
 
b/providers/microsoft/azure/tests/system/microsoft/azure/example_azure_batch_operator.py
index f5c563aab77..b6f6dded181 100644
--- 
a/providers/microsoft/azure/tests/system/microsoft/azure/example_azure_batch_operator.py
+++ 
b/providers/microsoft/azure/tests/system/microsoft/azure/example_azure_batch_operator.py
@@ -57,6 +57,25 @@ with DAG(
     )
     # [END howto_azure_batch_operator]
 
+    # [START howto_azure_batch_operator_with_deferrable_flag]
+    azure_batch_operator_deferrable = AzureBatchOperator(
+        task_id="azure_batch_deferrable",
+        batch_pool_id=POOL_ID,
+        batch_pool_vm_size="standard_d2s_v3",
+        batch_job_id="example-job",
+        batch_task_command_line="/bin/bash -c 'set -e; set -o pipefail; echo 
hello world!; wait'",
+        batch_task_id="example-task",
+        vm_node_agent_sku_id="batch.node.ubuntu 22.04",
+        vm_publisher="Canonical",
+        vm_offer="0001-com-ubuntu-server-jammy",
+        vm_sku="22_04-lts-gen2",
+        target_dedicated_nodes=1,
+        deferrable=True,
+    )
+    # [END howto_azure_batch_operator_with_deferrable_flag]
+
+    azure_batch_operator >> azure_batch_operator_deferrable
+
 from tests_common.test_utils.system_tests import get_test_run  # noqa: E402
 
 # Needed to run the example DAG with pytest (see: 
contributing-docs/testing/system_tests.rst)
diff --git 
a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_batch.py 
b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_batch.py
index 160aea3df7e..7c9fba490cc 100644
--- 
a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_batch.py
+++ 
b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_batch.py
@@ -23,9 +23,10 @@ from unittest import mock
 import pytest
 
 from airflow.models import Connection
-from airflow.providers.common.compat.sdk import AirflowException
+from airflow.providers.common.compat.sdk import AirflowException, TaskDeferred
 from airflow.providers.microsoft.azure.hooks.batch import AzureBatchHook
 from airflow.providers.microsoft.azure.operators.batch import 
AzureBatchOperator
+from airflow.providers.microsoft.azure.triggers.batch import AzureBatchTrigger
 
 TASK_ID = "MyDag"
 BATCH_POOL_ID = "MyPool"
@@ -247,3 +248,183 @@ class TestAzureBatchOperator:
         self.operator.clean_up("mypool", "myjob")
         self.batch_client.job.delete.assert_called_with("myjob")
         self.batch_client.pool.delete.assert_called_with("mypool")
+
+
+class TestAzureBatchOperatorDeferrable:
+    @pytest.fixture(autouse=True)
+    def setup_test_cases(self, mocked_batch_service_client, 
create_mock_connections):
+        self.batch_client = mock.MagicMock(name="FakeBatchServiceClient")
+        mocked_batch_service_client.return_value = self.batch_client
+
+        self.test_conn_id = "test_azure_batch"
+        self.test_account_url = "http://test-endpoint:29000"
+
+        create_mock_connections(
+            Connection(
+                conn_id=self.test_conn_id,
+                conn_type="azure_batch",
+                extra=json.dumps({"account_url": self.test_account_url}),
+            ),
+        )
+
+        self.operator = AzureBatchOperator(
+            task_id=TASK_ID,
+            batch_pool_id=BATCH_POOL_ID,
+            batch_pool_vm_size=BATCH_VM_SIZE,
+            batch_job_id=BATCH_JOB_ID,
+            batch_task_id=BATCH_TASK_ID,
+            batch_task_command_line="echo hello",
+            vm_node_agent_sku_id="node-agent",
+            os_family="4",
+            target_dedicated_nodes=1,
+            azure_batch_conn_id=self.test_conn_id,
+            deferrable=True,
+        )
+
+    @mock.patch.object(AzureBatchHook, "wait_for_all_node_state")
+    def test_execute_defers(self, wait_mock):
+
+        wait_mock.return_value = True
+        self.batch_client.pool.get.return_value.resize_errors = None
+
+        with pytest.raises(TaskDeferred) as ctx:
+            self.operator.execute(None)
+
+        trigger = ctx.value.trigger
+
+        assert isinstance(trigger, AzureBatchTrigger)
+
+        assert trigger.job_id == BATCH_JOB_ID
+        assert trigger.azure_batch_conn_id == self.test_conn_id
+
+        self.batch_client.pool.add.assert_called()
+        self.batch_client.job.add.assert_called()
+        self.batch_client.task.add.assert_called()
+
+    def test_execute_complete_success(self):
+        with mock.patch.object(self.operator.log, "info") as mock_log:
+            self.operator.execute_complete(
+                context={},
+                event={
+                    "status": "success",
+                    "message": "success",
+                    "job_id": BATCH_JOB_ID,
+                },
+            )
+
+        mock_log.assert_called_once_with("success")
+
+    def test_execute_complete_error(self):
+        with pytest.raises(RuntimeError, match="error"):
+            self.operator.execute_complete(
+                context={},
+                event={
+                    "status": "error",
+                    "message": "error",
+                    "job_id": BATCH_JOB_ID,
+                },
+            )
+
+    def test_execute_complete_timeout(self):
+        with pytest.raises(RuntimeError, match="timeout"):
+            self.operator.execute_complete(
+                context={},
+                event={
+                    "status": "timeout",
+                    "message": "timeout",
+                    "job_id": BATCH_JOB_ID,
+                },
+            )
+
+    def test_execute_complete_no_event(self):
+        with pytest.raises(RuntimeError, match="no event"):
+            self.operator.execute_complete(
+                context={},
+                event=None,
+            )
+
+    def test_execute_complete_unexpected_event(self):
+        with pytest.raises(RuntimeError, match="Unexpected"):
+            self.operator.execute_complete(
+                context={},
+                event={
+                    "status": "unknown",
+                    "message": "???",
+                },
+            )
+
+    def test_execute_complete_failed_tasks(self):
+        with pytest.raises(RuntimeError, match="task1"):
+            self.operator.execute_complete(
+                context={},
+                event={
+                    "status": "error",
+                    "message": "job failed",
+                    "failed_tasks": ["task1"],
+                },
+            )
+
+    @pytest.mark.parametrize(
+        ("event", "expected_exception"),
+        [
+            (
+                {
+                    "status": "success",
+                    "message": "success",
+                    "job_id": BATCH_JOB_ID,
+                },
+                None,
+            ),
+            (
+                {
+                    "status": "error",
+                    "message": "error",
+                    "job_id": BATCH_JOB_ID,
+                },
+                RuntimeError,
+            ),
+            (
+                {
+                    "status": "timeout",
+                    "message": "timeout",
+                    "job_id": BATCH_JOB_ID,
+                },
+                RuntimeError,
+            ),
+            (
+                {
+                    "status": "unknown",
+                    "message": "???",
+                },
+                RuntimeError,
+            ),
+        ],
+    )
+    @mock.patch.object(AzureBatchOperator, "clean_up")
+    def test_execute_complete_cleanup(
+        self,
+        clean_up_mock,
+        event,
+        expected_exception,
+    ):
+        self.operator.should_delete_job = True
+        self.operator.should_delete_pool = True
+
+        if expected_exception:
+            with pytest.raises(expected_exception):
+                self.operator.execute_complete(
+                    context={},
+                    event=event,
+                )
+        else:
+            self.operator.execute_complete(
+                context={},
+                event=event,
+            )
+
+        clean_up_mock.assert_has_calls(
+            [
+                mock.call(job_id=BATCH_JOB_ID),
+                mock.call(pool_id=BATCH_POOL_ID),
+            ]
+        )
diff --git 
a/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_batch.py 
b/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_batch.py
new file mode 100644
index 00000000000..4c24fb68216
--- /dev/null
+++ 
b/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_batch.py
@@ -0,0 +1,287 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import time
+from unittest import mock
+
+import pytest
+from azure.batch import models as batch_models
+
+from airflow.providers.microsoft.azure.triggers.batch import AzureBatchTrigger
+from airflow.triggers.base import TriggerEvent
+
+AZURE_BATCH_CONN_ID = "azure_batch_default"
+JOB_ID = "test-job"
+POKE_INTERVAL = 5
+BATCH_END_TIME = time.time() + 60 * 60 * 24 * 7
+MODULE = "airflow.providers.microsoft.azure"
+
+
+class TestAzureBatchTrigger:
+    TRIGGER = AzureBatchTrigger(
+        job_id=JOB_ID,
+        azure_batch_conn_id=AZURE_BATCH_CONN_ID,
+        poll_interval=POKE_INTERVAL,
+        end_time=BATCH_END_TIME,
+    )
+
+    def test_batch_trigger_serialization(self):
+        classpath, kwargs = self.TRIGGER.serialize()
+
+        assert classpath == f"{MODULE}.triggers.batch.AzureBatchTrigger"
+
+        assert kwargs == {
+            "job_id": JOB_ID,
+            "azure_batch_conn_id": AZURE_BATCH_CONN_ID,
+            "poll_interval": POKE_INTERVAL,
+            "end_time": BATCH_END_TIME,
+        }
+
+    def test_build_trigger_event_success(self):
+        completed_task = mock.MagicMock()
+        completed_task.id = "task1"
+        completed_task.state = batch_models.TaskState.completed
+        completed_task.execution_info.result = 
batch_models.TaskExecutionResult.success
+
+        event = self.TRIGGER._build_trigger_event([completed_task])
+
+        assert event is not None
+
+        assert event.payload == {
+            "status": "success",
+            "message": f"Azure Batch job {JOB_ID} completed successfully.",
+            "job_id": JOB_ID,
+        }
+
+    def test_build_trigger_event_failure(self):
+        failed_task = mock.MagicMock()
+        failed_task.id = "task1"
+        failed_task.state = batch_models.TaskState.completed
+        failed_task.execution_info.result = 
batch_models.TaskExecutionResult.failure
+
+        event = self.TRIGGER._build_trigger_event([failed_task])
+
+        assert event is not None
+
+        assert event.payload == {
+            "status": "error",
+            "message": f"Azure Batch job {JOB_ID} failed.",
+            "job_id": JOB_ID,
+            "failed_tasks": ["task1"],
+        }
+
+    def test_build_trigger_event_mixed_states(self):
+        completed_task = mock.MagicMock()
+        completed_task.id = "task1"
+        completed_task.state = batch_models.TaskState.completed
+        completed_task.execution_info.result = 
batch_models.TaskExecutionResult.success
+
+        running_task = mock.MagicMock()
+        running_task.id = "task2"
+        running_task.state = batch_models.TaskState.running
+
+        event = self.TRIGGER._build_trigger_event([completed_task, 
running_task])
+
+        assert event is None
+
+    def test_build_trigger_event_empty_tasks(self):
+        event = self.TRIGGER._build_trigger_event([])
+
+        assert event is not None
+
+        assert event.payload == {
+            "status": "error",
+            "message": f"Azure Batch job {JOB_ID} contains no tasks.",
+            "job_id": JOB_ID,
+        }
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{MODULE}.triggers.batch.asyncio.sleep")
+    @mock.patch(f"{MODULE}.triggers.batch.asyncio.to_thread")
+    async def test_trigger_run_non_terminal_sleeps(
+        self,
+        mock_to_thread,
+        mock_sleep,
+    ):
+        running_task = mock.MagicMock()
+        running_task.id = "task1"
+        running_task.state = batch_models.TaskState.running
+
+        completed_task = mock.MagicMock()
+        completed_task.id = "task1"
+        completed_task.state = batch_models.TaskState.completed
+        completed_task.execution_info.result = 
batch_models.TaskExecutionResult.success
+
+        mock_to_thread.side_effect = [
+            [running_task],
+            [completed_task],
+        ]
+
+        events = [event async for event in self.TRIGGER.run()]
+
+        assert events == [
+            TriggerEvent(
+                {
+                    "status": "success",
+                    "message": f"Azure Batch job {JOB_ID} completed 
successfully.",
+                    "job_id": JOB_ID,
+                }
+            )
+        ]
+
+        mock_sleep.assert_awaited_once_with(POKE_INTERVAL)
+
+    def test_build_trigger_event_non_terminal(self):
+        running_task = mock.MagicMock()
+        running_task.id = "task1"
+        running_task.state = batch_models.TaskState.running
+
+        event = self.TRIGGER._build_trigger_event([running_task])
+
+        assert event is None
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{MODULE}.triggers.batch.asyncio.to_thread")
+    async def test_trigger_run_success(self, mock_to_thread):
+        completed_task = mock.MagicMock()
+        completed_task.id = "task1"
+        completed_task.state = batch_models.TaskState.completed
+        completed_task.execution_info.result = 
batch_models.TaskExecutionResult.success
+
+        mock_to_thread.return_value = [completed_task]
+
+        generator = self.TRIGGER.run()
+        actual = await generator.asend(None)
+
+        assert actual == TriggerEvent(
+            {
+                "status": "success",
+                "message": f"Azure Batch job {JOB_ID} completed successfully.",
+                "job_id": JOB_ID,
+            }
+        )
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{MODULE}.triggers.batch.asyncio.to_thread")
+    async def test_trigger_run_failure(self, mock_to_thread):
+        failed_task = mock.MagicMock()
+        failed_task.id = "task1"
+        failed_task.state = batch_models.TaskState.completed
+        failed_task.execution_info.result = 
batch_models.TaskExecutionResult.failure
+
+        mock_to_thread.return_value = [failed_task]
+
+        generator = self.TRIGGER.run()
+        actual = await generator.asend(None)
+
+        assert actual == TriggerEvent(
+            {
+                "status": "error",
+                "message": f"Azure Batch job {JOB_ID} failed.",
+                "job_id": JOB_ID,
+                "failed_tasks": ["task1"],
+            }
+        )
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{MODULE}.triggers.batch.asyncio.to_thread")
+    async def test_trigger_exception(self, mock_to_thread):
+        mock_to_thread.side_effect = Exception("API failure")
+
+        events = [event async for event in self.TRIGGER.run()]
+
+        assert events == [
+            TriggerEvent(
+                {
+                    "status": "error",
+                    "message": "API failure",
+                    "job_id": JOB_ID,
+                }
+            )
+        ]
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{MODULE}.triggers.batch.asyncio.to_thread")
+    async def test_trigger_run_empty_tasks(self, mock_to_thread):
+        mock_to_thread.return_value = []
+
+        events = [event async for event in self.TRIGGER.run()]
+
+        assert events == [
+            TriggerEvent(
+                {
+                    "status": "error",
+                    "message": f"Azure Batch job {JOB_ID} contains no tasks.",
+                    "job_id": JOB_ID,
+                }
+            )
+        ]
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{MODULE}.triggers.batch.time")
+    @mock.patch(f"{MODULE}.triggers.batch.asyncio.to_thread")
+    async def test_trigger_timeout_job_already_succeeded(
+        self,
+        mock_to_thread,
+        mock_time,
+    ):
+        completed_task = mock.MagicMock()
+        completed_task.id = "task1"
+        completed_task.state = batch_models.TaskState.completed
+        completed_task.execution_info.result = 
batch_models.TaskExecutionResult.success
+
+        mock_to_thread.return_value = [completed_task]
+
+        mock_time.time.return_value = BATCH_END_TIME + 60
+
+        events = [event async for event in self.TRIGGER.run()]
+
+        assert events == [
+            TriggerEvent(
+                {
+                    "status": "success",
+                    "message": f"Azure Batch job {JOB_ID} completed 
successfully.",
+                    "job_id": JOB_ID,
+                }
+            )
+        ]
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{MODULE}.triggers.batch.time")
+    @mock.patch(f"{MODULE}.triggers.batch.asyncio.to_thread")
+    async def test_trigger_timeout(self, mock_to_thread, mock_time):
+        running_task = mock.MagicMock()
+        running_task.id = "task1"
+        running_task.state = batch_models.TaskState.running
+
+        mock_to_thread.return_value = [running_task]
+
+        mock_time.time.return_value = BATCH_END_TIME + 60
+
+        events = [event async for event in self.TRIGGER.run()]
+
+        assert events == [
+            TriggerEvent(
+                {
+                    "status": "timeout",
+                    "message": f"Timeout waiting for Azure Batch job 
{JOB_ID}.",
+                    "job_id": JOB_ID,
+                }
+            )
+        ]
diff --git a/scripts/ci/prek/known_airflow_exceptions.txt 
b/scripts/ci/prek/known_airflow_exceptions.txt
index bd4570fc55f..b4cf8603870 100644
--- a/scripts/ci/prek/known_airflow_exceptions.txt
+++ b/scripts/ci/prek/known_airflow_exceptions.txt
@@ -352,7 +352,7 @@ 
providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/data_facto
 
providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/data_lake.py::1
 
providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py::3
 
providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/wasb.py::2
-providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/batch.py::10
+providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/batch.py::9
 
providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/container_instances.py::10
 
providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/data_factory.py::1
 
providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py::1


Reply via email to