This is an automated email from the ASF dual-hosted git repository.
jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 26c6dc2d907 Add deferrable mode support to AzureBatchOperator (#66815)
26c6dc2d907 is described below
commit 26c6dc2d907b56cbbca45d9a871d30fd67bd0e35
Author: SameerMesiah97 <[email protected]>
AuthorDate: Sun May 31 20:46:58 2026 +0100
Add deferrable mode support to AzureBatchOperator (#66815)
* Add deferrable mode support to AzureBatchOperator by introducing
AzureBatchTrigger and delegating Azure Batch job monitoring to the
trigger when deferrable=True.
Implement execute_complete handling for success, timeout, failure,
and empty-task trigger events, with cleanup support for terminal
states.
Add unit tests covering trigger serialization, task state handling,
timeouts, mixed and empty task lists, exception handling, and
execute_complete event processing.
* Adjusted docs and end_time computation to reflect operator intent.
---------
Co-authored-by: Sameer Mesiah <[email protected]>
---
providers/microsoft/azure/docs/operators/batch.rst | 9 +
providers/microsoft/azure/provider.yaml | 3 +
.../providers/microsoft/azure/get_provider_info.py | 4 +
.../providers/microsoft/azure/operators/batch.py | 75 +++++-
.../providers/microsoft/azure/triggers/batch.py | 176 +++++++++++++
.../azure/example_azure_batch_operator.py | 19 ++
.../unit/microsoft/azure/operators/test_batch.py | 183 ++++++++++++-
.../unit/microsoft/azure/triggers/test_batch.py | 287 +++++++++++++++++++++
scripts/ci/prek/known_airflow_exceptions.txt | 2 +-
9 files changed, 754 insertions(+), 4 deletions(-)
diff --git a/providers/microsoft/azure/docs/operators/batch.rst
b/providers/microsoft/azure/docs/operators/batch.rst
index 8cc5cc63006..9b7722d4686 100644
--- a/providers/microsoft/azure/docs/operators/batch.rst
+++ b/providers/microsoft/azure/docs/operators/batch.rst
@@ -32,6 +32,15 @@ Below is an example of using this operator to trigger a task
on Azure Batch
:start-after: [START howto_azure_batch_operator]
:end-before: [END howto_azure_batch_operator]
+Below is an example of using this operator in deferrable mode (``deferrable=True``) to trigger a task on Azure Batch,
+so that polling for job/task completion happens on the Airflow triggerer rather than on a worker.
+
+.. exampleinclude:: /../tests/system/microsoft/azure/example_azure_batch_operator.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_azure_batch_operator_with_deferrable_flag]
+    :end-before: [END howto_azure_batch_operator_with_deferrable_flag]
+
Reference
---------
diff --git a/providers/microsoft/azure/provider.yaml
b/providers/microsoft/azure/provider.yaml
index a5cbd0bc1e8..801f283bb99 100644
--- a/providers/microsoft/azure/provider.yaml
+++ b/providers/microsoft/azure/provider.yaml
@@ -310,6 +310,9 @@ hooks:
- airflow.providers.microsoft.azure.hooks.powerbi
triggers:
+ - integration-name: Microsoft Azure Batch
+ python-modules:
+ - airflow.providers.microsoft.azure.triggers.batch
- integration-name: Microsoft Azure Compute
python-modules:
- airflow.providers.microsoft.azure.triggers.compute
diff --git
a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/get_provider_info.py
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/get_provider_info.py
index 0b312ec09a9..40364bf1845 100644
---
a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/get_provider_info.py
+++
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/get_provider_info.py
@@ -284,6 +284,10 @@ def get_provider_info():
},
],
"triggers": [
+ {
+ "integration-name": "Microsoft Azure Batch",
+ "python-modules":
["airflow.providers.microsoft.azure.triggers.batch"],
+ },
{
"integration-name": "Microsoft Azure Compute",
"python-modules":
["airflow.providers.microsoft.azure.triggers.compute"],
diff --git
a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/batch.py
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/batch.py
index e5b36841c66..4e9216ec181 100644
---
a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/batch.py
+++
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/batch.py
@@ -17,14 +17,16 @@
# under the License.
from __future__ import annotations
+import time
from collections.abc import Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any
from azure.batch import models as batch_models
-from airflow.providers.common.compat.sdk import AirflowException, BaseOperator
+from airflow.providers.common.compat.sdk import AirflowException,
BaseOperator, conf
from airflow.providers.microsoft.azure.hooks.batch import AzureBatchHook
+from airflow.providers.microsoft.azure.triggers.batch import AzureBatchTrigger
if TYPE_CHECKING:
from airflow.sdk import Context
@@ -91,6 +93,10 @@ class AzureBatchOperator(BaseOperator):
:param timeout: The amount of time to wait for the job to complete in
minutes. Default is 25
:param should_delete_job: Whether to delete job after execution. Default
is False
:param should_delete_pool: Whether to delete pool after execution of jobs.
Default is False
+ :param poll_interval: Polling interval in seconds for deferrable mode.
Default is 30.
+ Determines how frequently the trigger checks task completion status
when deferrable=True.
+ :param deferrable: Run operator in deferrable mode.
+
"""
template_fields: Sequence[str] = (
@@ -139,6 +145,8 @@ class AzureBatchOperator(BaseOperator):
timeout: int = 25,
should_delete_job: bool = False,
should_delete_pool: bool = False,
+ poll_interval: int = 30,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -176,6 +184,8 @@ class AzureBatchOperator(BaseOperator):
self.timeout = timeout
self.should_delete_job = should_delete_job
self.should_delete_pool = should_delete_pool
+ self.poll_interval = poll_interval
+ self.deferrable = deferrable
@cached_property
def hook(self) -> AzureBatchHook:
@@ -265,6 +275,7 @@ class AzureBatchOperator(BaseOperator):
start_task=self.batch_start_task,
)
self.hook.create_pool(pool)
+
# Wait for nodes to reach complete state
self.hook.wait_for_all_node_state(
self.batch_pool_id,
@@ -296,6 +307,29 @@ class AzureBatchOperator(BaseOperator):
)
# Add task to job
self.hook.add_single_task_to_job(job_id=self.batch_job_id, task=task)
+
+ if self.deferrable:
+ # Pre-deferral check (node readiness is already enforced by
wait_for_all_node_state above)
+ pool = self.hook.connection.pool.get(self.batch_pool_id)
+ if pool.resize_errors:
+ raise RuntimeError(f"Pool resize errors: {pool.resize_errors}")
+
+ nodes =
list(self.hook.connection.compute_node.list(self.batch_pool_id))
+ self.log.debug("Deferral pre-check: %d nodes present in pool %s",
len(nodes), self.batch_pool_id)
+ end_time = time.time() + (self.timeout * 60)
+
+ self.defer(
+ timeout=self.execution_timeout,
+ trigger=AzureBatchTrigger(
+ job_id=self.batch_job_id,
+ azure_batch_conn_id=self.azure_batch_conn_id,
+ end_time=end_time,
+ poll_interval=self.poll_interval,
+ ),
+ method_name="execute_complete",
+ )
+ return
+
# Wait for tasks to complete
fail_tasks =
self.hook.wait_for_job_tasks_to_complete(job_id=self.batch_job_id,
timeout=self.timeout)
# Clean up
@@ -306,7 +340,44 @@ class AzureBatchOperator(BaseOperator):
self.clean_up(self.batch_pool_id)
# raise exception if any task fail
if fail_tasks:
- raise AirflowException(f"Job fail. The failed task are:
{fail_tasks}")
+ raise RuntimeError(f"Job fail. The failed task are: {fail_tasks}")
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None) -> None:
+        """
+        Handle the terminal event emitted by ``AzureBatchTrigger``.
+
+        The event payload carries the terminal state of the Azure Batch job under
+        ``status`` (``success``, ``timeout`` or ``error``), a human-readable
+        ``message`` and, for failed jobs, the ids of the failed tasks under
+        ``failed_tasks``. Only ``success`` returns normally; every other status
+        raises.
+
+        Once an event has been received, the job and/or pool are deleted when
+        ``should_delete_job`` / ``should_delete_pool`` are set, whatever the
+        outcome, mirroring the cleanup done by the non-deferrable code path.
+
+        :param context: Airflow task context.
+        :param event: Payload of the ``TriggerEvent`` yielded by the trigger.
+        :raises RuntimeError: if no event was received, the job timed out or
+            failed, or the event status is not recognised.
+        """
+        if event is None:
+            raise RuntimeError("Trigger returned no event.")
+
+        status = event.get("status")
+        message = event.get("message", "No message returned from trigger.")
+        failed_tasks = event.get("failed_tasks")
+
+        try:
+            if status == "success":
+                self.log.info(message)
+                return
+
+            if status == "timeout":
+                raise RuntimeError(message)
+
+            if status == "error":
+                if failed_tasks:
+                    raise RuntimeError(f"{message} Failed tasks: {failed_tasks}")
+                raise RuntimeError(message)
+
+            raise RuntimeError(f"Unexpected trigger event received: {event}")
+
+        finally:
+            # Runs for every terminal outcome (success, timeout, error, unexpected)
+            # so resources are released exactly as the operator flags request.
+            if self.should_delete_job:
+                self.clean_up(job_id=self.batch_job_id)
+
+            if self.should_delete_pool:
+                self.clean_up(pool_id=self.batch_pool_id)
def on_kill(self) -> None:
response = self.hook.connection.job.terminate(
diff --git
a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/batch.py
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/batch.py
new file mode 100644
index 00000000000..65bf5eb0217
--- /dev/null
+++
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/batch.py
@@ -0,0 +1,176 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import asyncio
+import time
+from collections.abc import AsyncIterator
+from typing import Any
+
+from azure.batch import models as batch_models
+
+from airflow.providers.microsoft.azure.hooks.batch import AzureBatchHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class AzureBatchTrigger(BaseTrigger):
+    """
+    Trigger when Azure Batch job tasks reach a terminal state.
+
+    Polls the tasks of ``job_id`` every ``poll_interval`` seconds until every task has
+    completed or ``end_time`` has passed, then yields a single ``TriggerEvent`` whose
+    payload carries ``status`` (``success``, ``error`` or ``timeout``), ``message``,
+    ``job_id`` and, when tasks failed, ``failed_tasks``.
+
+    :param job_id: Azure Batch job identifier.
+    :param azure_batch_conn_id: Azure Batch connection id.
+    :param end_time: Absolute timeout deadline as determined using ``time.time()``.
+    :param poll_interval: Poll interval in seconds.
+    """
+
+    def __init__(
+        self,
+        job_id: str,
+        azure_batch_conn_id: str,
+        end_time: float,
+        poll_interval: int = 30,
+    ):
+        super().__init__()
+
+        self.job_id = job_id
+        self.azure_batch_conn_id = azure_batch_conn_id
+        self.end_time = end_time
+        self.poll_interval = poll_interval
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serialize trigger arguments and classpath."""
+        return (
+            f"{self.__class__.__module__}.{self.__class__.__name__}",
+            {
+                "job_id": self.job_id,
+                "azure_batch_conn_id": self.azure_batch_conn_id,
+                "end_time": self.end_time,
+                "poll_interval": self.poll_interval,
+            },
+        )
+
+    def _get_incomplete_tasks(
+        self,
+        tasks: list[batch_models.CloudTask],
+    ) -> list[batch_models.CloudTask]:
+        """Return tasks that have not yet completed."""
+        return [task for task in tasks if task.state != batch_models.TaskState.completed]
+
+    async def _list_tasks(self, hook: AzureBatchHook) -> list[batch_models.CloudTask]:
+        """
+        List all tasks of the job.
+
+        The listing call is blocking, so it runs via ``asyncio.to_thread`` to keep
+        the triggerer event loop responsive.
+        """
+        return await asyncio.to_thread(lambda: list(hook.connection.task.list(self.job_id)))
+
+    def _build_trigger_event(
+        self,
+        tasks: list[batch_models.CloudTask],
+    ) -> TriggerEvent | None:
+        """
+        Convert Batch task states to TriggerEvent.
+
+        Returns None if tasks are still running.
+        """
+        if not tasks:
+            return TriggerEvent(
+                {
+                    "status": "error",
+                    "message": f"Azure Batch job {self.job_id} contains no tasks.",
+                    "job_id": self.job_id,
+                }
+            )
+
+        if self._get_incomplete_tasks(tasks):
+            return None
+
+        failed_tasks = [
+            task.id
+            for task in tasks
+            if task.execution_info and task.execution_info.result == batch_models.TaskExecutionResult.failure
+        ]
+
+        if failed_tasks:
+            return TriggerEvent(
+                {
+                    "status": "error",
+                    "message": f"Azure Batch job {self.job_id} failed.",
+                    "job_id": self.job_id,
+                    "failed_tasks": failed_tasks,
+                }
+            )
+
+        return TriggerEvent(
+            {
+                "status": "success",
+                "message": f"Azure Batch job {self.job_id} completed successfully.",
+                "job_id": self.job_id,
+            }
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        """Poll Azure Batch job tasks until completion or timeout."""
+        hook = AzureBatchHook(
+            azure_batch_conn_id=self.azure_batch_conn_id,
+        )
+
+        try:
+            while time.time() <= self.end_time:
+                tasks = await self._list_tasks(hook)
+
+                event = self._build_trigger_event(tasks)
+                if event:
+                    yield event
+                    return
+
+                # Log ids only: the CloudTask model objects themselves are not
+                # readable in task logs.
+                incomplete_task_ids = [task.id for task in self._get_incomplete_tasks(tasks)]
+                self.log.info(
+                    "Azure Batch job %s still running. Incomplete tasks: %s. Sleeping %s seconds.",
+                    self.job_id,
+                    incomplete_task_ids,
+                    self.poll_interval,
+                )
+
+                await asyncio.sleep(self.poll_interval)
+
+            # Final check before the timeout event in case the job completed
+            # during the last sleep interval.
+            tasks = await self._list_tasks(hook)
+            event = self._build_trigger_event(tasks)
+            if event:
+                yield event
+                return
+
+            yield TriggerEvent(
+                {
+                    "status": "timeout",
+                    "message": f"Timeout waiting for Azure Batch job {self.job_id}.",
+                    "job_id": self.job_id,
+                }
+            )
+
+        except Exception as e:
+            self.log.exception(
+                "Azure Batch trigger failed for job %s.",
+                self.job_id,
+            )
+
+            yield TriggerEvent(
+                {
+                    "status": "error",
+                    "message": str(e),
+                    "job_id": self.job_id,
+                }
+            )
diff --git
a/providers/microsoft/azure/tests/system/microsoft/azure/example_azure_batch_operator.py
b/providers/microsoft/azure/tests/system/microsoft/azure/example_azure_batch_operator.py
index f5c563aab77..b6f6dded181 100644
---
a/providers/microsoft/azure/tests/system/microsoft/azure/example_azure_batch_operator.py
+++
b/providers/microsoft/azure/tests/system/microsoft/azure/example_azure_batch_operator.py
@@ -57,6 +57,25 @@ with DAG(
)
# [END howto_azure_batch_operator]
+ # [START howto_azure_batch_operator_with_deferrable_flag]
+ azure_batch_operator_deferrable = AzureBatchOperator(
+ task_id="azure_batch_deferrable",
+ batch_pool_id=POOL_ID,
+ batch_pool_vm_size="standard_d2s_v3",
+ batch_job_id="example-job",
+ batch_task_command_line="/bin/bash -c 'set -e; set -o pipefail; echo
hello world!; wait'",
+ batch_task_id="example-task",
+ vm_node_agent_sku_id="batch.node.ubuntu 22.04",
+ vm_publisher="Canonical",
+ vm_offer="0001-com-ubuntu-server-jammy",
+ vm_sku="22_04-lts-gen2",
+ target_dedicated_nodes=1,
+ deferrable=True,
+ )
+ # [END howto_azure_batch_operator_with_deferrable_flag]
+
+ azure_batch_operator >> azure_batch_operator_deferrable
+
from tests_common.test_utils.system_tests import get_test_run # noqa: E402
# Needed to run the example DAG with pytest (see:
contributing-docs/testing/system_tests.rst)
diff --git
a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_batch.py
b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_batch.py
index 160aea3df7e..7c9fba490cc 100644
---
a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_batch.py
+++
b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_batch.py
@@ -23,9 +23,10 @@ from unittest import mock
import pytest
from airflow.models import Connection
-from airflow.providers.common.compat.sdk import AirflowException
+from airflow.providers.common.compat.sdk import AirflowException, TaskDeferred
from airflow.providers.microsoft.azure.hooks.batch import AzureBatchHook
from airflow.providers.microsoft.azure.operators.batch import
AzureBatchOperator
+from airflow.providers.microsoft.azure.triggers.batch import AzureBatchTrigger
TASK_ID = "MyDag"
BATCH_POOL_ID = "MyPool"
@@ -247,3 +248,183 @@ class TestAzureBatchOperator:
self.operator.clean_up("mypool", "myjob")
self.batch_client.job.delete.assert_called_with("myjob")
self.batch_client.pool.delete.assert_called_with("mypool")
+
+
+class TestAzureBatchOperatorDeferrable:
+    @pytest.fixture(autouse=True)
+    def setup_test_cases(self, mocked_batch_service_client, create_mock_connections):
+        self.batch_client = mock.MagicMock(name="FakeBatchServiceClient")
+        mocked_batch_service_client.return_value = self.batch_client
+
+        self.test_conn_id = "test_azure_batch"
+        self.test_account_url = "http://test-endpoint:29000"
+
+        create_mock_connections(
+            Connection(
+                conn_id=self.test_conn_id,
+                conn_type="azure_batch",
+                extra=json.dumps({"account_url": self.test_account_url}),
+            ),
+        )
+
+        self.operator = AzureBatchOperator(
+            task_id=TASK_ID,
+            batch_pool_id=BATCH_POOL_ID,
+            batch_pool_vm_size=BATCH_VM_SIZE,
+            batch_job_id=BATCH_JOB_ID,
+            batch_task_id=BATCH_TASK_ID,
+            batch_task_command_line="echo hello",
+            vm_node_agent_sku_id="node-agent",
+            os_family="4",
+            target_dedicated_nodes=1,
+            azure_batch_conn_id=self.test_conn_id,
+            deferrable=True,
+        )
+
+    @mock.patch.object(AzureBatchHook, "wait_for_all_node_state")
+    def test_execute_defers(self, wait_mock):
+        wait_mock.return_value = True
+        self.batch_client.pool.get.return_value.resize_errors = None
+
+        with pytest.raises(TaskDeferred) as ctx:
+            self.operator.execute(None)
+
+        trigger = ctx.value.trigger
+
+        assert isinstance(trigger, AzureBatchTrigger)
+        assert trigger.job_id == BATCH_JOB_ID
+        assert trigger.azure_batch_conn_id == self.test_conn_id
+        assert trigger.poll_interval == self.operator.poll_interval
+        assert ctx.value.method_name == "execute_complete"
+
+        self.batch_client.pool.add.assert_called()
+        self.batch_client.job.add.assert_called()
+        self.batch_client.task.add.assert_called()
+
+    @mock.patch.object(AzureBatchHook, "wait_for_all_node_state")
+    def test_execute_raises_on_pool_resize_errors(self, wait_mock):
+        # The pre-deferral check must fail fast instead of deferring on a broken pool.
+        wait_mock.return_value = True
+        self.batch_client.pool.get.return_value.resize_errors = ["ResizeError"]
+
+        with pytest.raises(RuntimeError, match="Pool resize errors"):
+            self.operator.execute(None)
+
+    def test_execute_complete_success(self):
+        with mock.patch.object(self.operator.log, "info") as mock_log:
+            self.operator.execute_complete(
+                context={},
+                event={
+                    "status": "success",
+                    "message": "success",
+                    "job_id": BATCH_JOB_ID,
+                },
+            )
+
+        mock_log.assert_called_once_with("success")
+
+    def test_execute_complete_error(self):
+        with pytest.raises(RuntimeError, match="error"):
+            self.operator.execute_complete(
+                context={},
+                event={
+                    "status": "error",
+                    "message": "error",
+                    "job_id": BATCH_JOB_ID,
+                },
+            )
+
+    def test_execute_complete_timeout(self):
+        with pytest.raises(RuntimeError, match="timeout"):
+            self.operator.execute_complete(
+                context={},
+                event={
+                    "status": "timeout",
+                    "message": "timeout",
+                    "job_id": BATCH_JOB_ID,
+                },
+            )
+
+    def test_execute_complete_no_event(self):
+        with pytest.raises(RuntimeError, match="no event"):
+            self.operator.execute_complete(
+                context={},
+                event=None,
+            )
+
+    def test_execute_complete_unexpected_event(self):
+        with pytest.raises(RuntimeError, match="Unexpected"):
+            self.operator.execute_complete(
+                context={},
+                event={
+                    "status": "unknown",
+                    "message": "???",
+                },
+            )
+
+    def test_execute_complete_failed_tasks(self):
+        with pytest.raises(RuntimeError, match="task1"):
+            self.operator.execute_complete(
+                context={},
+                event={
+                    "status": "error",
+                    "message": "job failed",
+                    "failed_tasks": ["task1"],
+                },
+            )
+
+    @pytest.mark.parametrize(
+        ("event", "expected_exception"),
+        [
+            (
+                {
+                    "status": "success",
+                    "message": "success",
+                    "job_id": BATCH_JOB_ID,
+                },
+                None,
+            ),
+            (
+                {
+                    "status": "error",
+                    "message": "error",
+                    "job_id": BATCH_JOB_ID,
+                },
+                RuntimeError,
+            ),
+            (
+                {
+                    "status": "timeout",
+                    "message": "timeout",
+                    "job_id": BATCH_JOB_ID,
+                },
+                RuntimeError,
+            ),
+            (
+                {
+                    "status": "unknown",
+                    "message": "???",
+                },
+                RuntimeError,
+            ),
+        ],
+    )
+    @mock.patch.object(AzureBatchOperator, "clean_up")
+    def test_execute_complete_cleanup(
+        self,
+        clean_up_mock,
+        event,
+        expected_exception,
+    ):
+        self.operator.should_delete_job = True
+        self.operator.should_delete_pool = True
+
+        if expected_exception:
+            with pytest.raises(expected_exception):
+                self.operator.execute_complete(
+                    context={},
+                    event=event,
+                )
+        else:
+            self.operator.execute_complete(
+                context={},
+                event=event,
+            )
+
+        clean_up_mock.assert_has_calls(
+            [
+                mock.call(job_id=BATCH_JOB_ID),
+                mock.call(pool_id=BATCH_POOL_ID),
+            ]
+        )
diff --git
a/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_batch.py
b/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_batch.py
new file mode 100644
index 00000000000..4c24fb68216
--- /dev/null
+++
b/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_batch.py
@@ -0,0 +1,287 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import time
+from unittest import mock
+
+import pytest
+from azure.batch import models as batch_models
+
+from airflow.providers.microsoft.azure.triggers.batch import AzureBatchTrigger
+from airflow.triggers.base import TriggerEvent
+
+AZURE_BATCH_CONN_ID = "azure_batch_default"
+JOB_ID = "test-job"
+POKE_INTERVAL = 5
+BATCH_END_TIME = time.time() + 60 * 60 * 24 * 7
+MODULE = "airflow.providers.microsoft.azure"
+
+
+class TestAzureBatchTrigger:
+    TRIGGER = AzureBatchTrigger(
+        job_id=JOB_ID,
+        azure_batch_conn_id=AZURE_BATCH_CONN_ID,
+        poll_interval=POKE_INTERVAL,
+        end_time=BATCH_END_TIME,
+    )
+
+    def test_batch_trigger_serialization(self):
+        classpath, kwargs = self.TRIGGER.serialize()
+
+        assert classpath == f"{MODULE}.triggers.batch.AzureBatchTrigger"
+
+        assert kwargs == {
+            "job_id": JOB_ID,
+            "azure_batch_conn_id": AZURE_BATCH_CONN_ID,
+            "poll_interval": POKE_INTERVAL,
+            "end_time": BATCH_END_TIME,
+        }
+
+    def test_build_trigger_event_success(self):
+        completed_task = mock.MagicMock()
+        completed_task.id = "task1"
+        completed_task.state = batch_models.TaskState.completed
+        completed_task.execution_info.result = batch_models.TaskExecutionResult.success
+
+        event = self.TRIGGER._build_trigger_event([completed_task])
+
+        assert event is not None
+
+        assert event.payload == {
+            "status": "success",
+            "message": f"Azure Batch job {JOB_ID} completed successfully.",
+            "job_id": JOB_ID,
+        }
+
+    def test_build_trigger_event_failure(self):
+        failed_task = mock.MagicMock()
+        failed_task.id = "task1"
+        failed_task.state = batch_models.TaskState.completed
+        failed_task.execution_info.result = batch_models.TaskExecutionResult.failure
+
+        event = self.TRIGGER._build_trigger_event([failed_task])
+
+        assert event is not None
+
+        assert event.payload == {
+            "status": "error",
+            "message": f"Azure Batch job {JOB_ID} failed.",
+            "job_id": JOB_ID,
+            "failed_tasks": ["task1"],
+        }
+
+    def test_build_trigger_event_mixed_states(self):
+        completed_task = mock.MagicMock()
+        completed_task.id = "task1"
+        completed_task.state = batch_models.TaskState.completed
+        completed_task.execution_info.result = batch_models.TaskExecutionResult.success
+
+        running_task = mock.MagicMock()
+        running_task.id = "task2"
+        running_task.state = batch_models.TaskState.running
+
+        event = self.TRIGGER._build_trigger_event([completed_task, running_task])
+
+        assert event is None
+
+    def test_build_trigger_event_empty_tasks(self):
+        event = self.TRIGGER._build_trigger_event([])
+
+        assert event is not None
+
+        assert event.payload == {
+            "status": "error",
+            "message": f"Azure Batch job {JOB_ID} contains no tasks.",
+            "job_id": JOB_ID,
+        }
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{MODULE}.triggers.batch.asyncio.sleep")
+    @mock.patch(f"{MODULE}.triggers.batch.asyncio.to_thread")
+    async def test_trigger_run_non_terminal_sleeps(
+        self,
+        mock_to_thread,
+        mock_sleep,
+    ):
+        running_task = mock.MagicMock()
+        running_task.id = "task1"
+        running_task.state = batch_models.TaskState.running
+
+        completed_task = mock.MagicMock()
+        completed_task.id = "task1"
+        completed_task.state = batch_models.TaskState.completed
+        completed_task.execution_info.result = batch_models.TaskExecutionResult.success
+
+        mock_to_thread.side_effect = [
+            [running_task],
+            [completed_task],
+        ]
+
+        events = [event async for event in self.TRIGGER.run()]
+
+        assert events == [
+            TriggerEvent(
+                {
+                    "status": "success",
+                    "message": f"Azure Batch job {JOB_ID} completed successfully.",
+                    "job_id": JOB_ID,
+                }
+            )
+        ]
+
+        # One poll while running, one poll that observes completion.
+        assert mock_to_thread.await_count == 2
+        mock_sleep.assert_awaited_once_with(POKE_INTERVAL)
+
+    def test_build_trigger_event_non_terminal(self):
+        running_task = mock.MagicMock()
+        running_task.id = "task1"
+        running_task.state = batch_models.TaskState.running
+
+        event = self.TRIGGER._build_trigger_event([running_task])
+
+        assert event is None
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{MODULE}.triggers.batch.asyncio.to_thread")
+    async def test_trigger_run_success(self, mock_to_thread):
+        completed_task = mock.MagicMock()
+        completed_task.id = "task1"
+        completed_task.state = batch_models.TaskState.completed
+        completed_task.execution_info.result = batch_models.TaskExecutionResult.success
+
+        mock_to_thread.return_value = [completed_task]
+
+        # Drain the generator so the test also proves the trigger stops after
+        # its terminal event instead of yielding more.
+        events = [event async for event in self.TRIGGER.run()]
+
+        assert events == [
+            TriggerEvent(
+                {
+                    "status": "success",
+                    "message": f"Azure Batch job {JOB_ID} completed successfully.",
+                    "job_id": JOB_ID,
+                }
+            )
+        ]
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{MODULE}.triggers.batch.asyncio.to_thread")
+    async def test_trigger_run_failure(self, mock_to_thread):
+        failed_task = mock.MagicMock()
+        failed_task.id = "task1"
+        failed_task.state = batch_models.TaskState.completed
+        failed_task.execution_info.result = batch_models.TaskExecutionResult.failure
+
+        mock_to_thread.return_value = [failed_task]
+
+        events = [event async for event in self.TRIGGER.run()]
+
+        assert events == [
+            TriggerEvent(
+                {
+                    "status": "error",
+                    "message": f"Azure Batch job {JOB_ID} failed.",
+                    "job_id": JOB_ID,
+                    "failed_tasks": ["task1"],
+                }
+            )
+        ]
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{MODULE}.triggers.batch.asyncio.to_thread")
+    async def test_trigger_exception(self, mock_to_thread):
+        mock_to_thread.side_effect = Exception("API failure")
+
+        events = [event async for event in self.TRIGGER.run()]
+
+        assert events == [
+            TriggerEvent(
+                {
+                    "status": "error",
+                    "message": "API failure",
+                    "job_id": JOB_ID,
+                }
+            )
+        ]
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{MODULE}.triggers.batch.asyncio.to_thread")
+    async def test_trigger_run_empty_tasks(self, mock_to_thread):
+        mock_to_thread.return_value = []
+
+        events = [event async for event in self.TRIGGER.run()]
+
+        assert events == [
+            TriggerEvent(
+                {
+                    "status": "error",
+                    "message": f"Azure Batch job {JOB_ID} contains no tasks.",
+                    "job_id": JOB_ID,
+                }
+            )
+        ]
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{MODULE}.triggers.batch.time")
+    @mock.patch(f"{MODULE}.triggers.batch.asyncio.to_thread")
+    async def test_trigger_timeout_job_already_succeeded(
+        self,
+        mock_to_thread,
+        mock_time,
+    ):
+        completed_task = mock.MagicMock()
+        completed_task.id = "task1"
+        completed_task.state = batch_models.TaskState.completed
+        completed_task.execution_info.result = batch_models.TaskExecutionResult.success
+
+        mock_to_thread.return_value = [completed_task]
+
+        mock_time.time.return_value = BATCH_END_TIME + 60
+
+        events = [event async for event in self.TRIGGER.run()]
+
+        assert events == [
+            TriggerEvent(
+                {
+                    "status": "success",
+                    "message": f"Azure Batch job {JOB_ID} completed successfully.",
+                    "job_id": JOB_ID,
+                }
+            )
+        ]
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{MODULE}.triggers.batch.time")
+    @mock.patch(f"{MODULE}.triggers.batch.asyncio.to_thread")
+    async def test_trigger_timeout(self, mock_to_thread, mock_time):
+        running_task = mock.MagicMock()
+        running_task.id = "task1"
+        running_task.state = batch_models.TaskState.running
+
+        mock_to_thread.return_value = [running_task]
+
+        mock_time.time.return_value = BATCH_END_TIME + 60
+
+        events = [event async for event in self.TRIGGER.run()]
+
+        assert events == [
+            TriggerEvent(
+                {
+                    "status": "timeout",
+                    "message": f"Timeout waiting for Azure Batch job {JOB_ID}.",
+                    "job_id": JOB_ID,
+                }
+            )
+        ]
diff --git a/scripts/ci/prek/known_airflow_exceptions.txt
b/scripts/ci/prek/known_airflow_exceptions.txt
index bd4570fc55f..b4cf8603870 100644
--- a/scripts/ci/prek/known_airflow_exceptions.txt
+++ b/scripts/ci/prek/known_airflow_exceptions.txt
@@ -352,7 +352,7 @@
providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/data_facto
providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/data_lake.py::1
providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py::3
providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/wasb.py::2
-providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/batch.py::10
+providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/batch.py::9
providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/container_instances.py::10
providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/data_factory.py::1
providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py::1