Copilot commented on code in PR #62772:
URL: https://github.com/apache/airflow/pull/62772#discussion_r3066478341


##########
providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/container_instance.py:
##########
@@ -0,0 +1,123 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+from collections.abc import AsyncIterator
+from typing import Any
+
+from airflow.providers.microsoft.azure.hooks.container_instance import 
AzureContainerInstanceAsyncHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+TERMINAL_STATES = frozenset({"Terminated", "Succeeded", "Failed", "Unhealthy"})
+SUCCESS_STATES = frozenset({"Terminated", "Succeeded"})
+
+
+class AzureContainerInstanceTrigger(BaseTrigger):
+    """
+    Poll an Azure Container Instance until it reaches a terminal state.
+
+    :param resource_group: the name of the resource group
+    :param name: the name of the container group
+    :param ci_conn_id: connection id of the Azure service principal
+    :param polling_interval: time in seconds between state polls
+    """
+
+    def __init__(
+        self,
+        resource_group: str,
+        name: str,
+        ci_conn_id: str,
+        polling_interval: float = 30.0,
+    ) -> None:
+        super().__init__()
+        self.resource_group = resource_group
+        self.name = name
+        self.ci_conn_id = ci_conn_id
+        self.polling_interval = polling_interval
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serialize trigger arguments and classpath."""
+        return (
+            
"airflow.providers.microsoft.azure.triggers.container_instance.AzureContainerInstanceTrigger",
+            {
+                "resource_group": self.resource_group,
+                "name": self.name,
+                "ci_conn_id": self.ci_conn_id,
+                "polling_interval": self.polling_interval,
+            },
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        """Poll ACI until a terminal state is reached, then yield a 
TriggerEvent."""
+        hook = AzureContainerInstanceAsyncHook(azure_conn_id=self.ci_conn_id)
+        try:
+            async with hook.get_async_conn() as client:
+                while True:
+                    cg_state = await 
client.container_groups.get(self.resource_group, self.name)
+                    instance_view = cg_state.containers[0].instance_view
+
+                    if instance_view is not None:
+                        c_state = instance_view.current_state
+                        state = c_state.state
+                        exit_code = c_state.exit_code
+                        detail_status = c_state.detail_status
+                    else:
+                        state = cg_state.provisioning_state
+                        exit_code = 0
+                        detail_status = "Provisioning"
+
+                    if state in TERMINAL_STATES:
+                        if state in SUCCESS_STATES and exit_code == 0:

Review Comment:
   When `instance_view` is `None`, using `cg_state.provisioning_state` to 
decide terminal success can complete the trigger early (e.g., provisioning 
state may become `Succeeded` before the container has actually finished). 
Consider treating `instance_view is None` as non-terminal except for explicit 
provisioning failure states (e.g., `Failed`/`Unhealthy`), and only using 
container `current_state` for success/exit-code decisions.



##########
providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/container_instance.py:
##########
@@ -172,3 +185,84 @@ def test_connection(self):
             return False, str(e)
 
         return True, "Successfully connected to Azure Container Instance."
+
+
+class AzureContainerInstanceAsyncHook(AzureContainerInstanceHook):
+    """
+    An async hook for communicating with Azure Container Instances.
+
+    :param azure_conn_id: :ref:`Azure connection id<howto/connection:azure>` of
+        a service principal which will be used to start the container instance.
+    """
+
+    def __init__(self, azure_conn_id: str = 
AzureContainerInstanceHook.default_conn_name) -> None:
+        super().__init__(azure_conn_id=azure_conn_id)
+
+    @asynccontextmanager
+    async def get_async_conn(self) -> 
AsyncGenerator[AsyncContainerInstanceManagementClient, None]:
+        """Create an async management client bound to a single credential."""
+        conn = await get_async_connection(self.conn_id)
+        tenant = conn.extra_dejson.get("tenantId")
+        subscription_id = cast("str", conn.extra_dejson.get("subscriptionId"))
+
+        credential: AsyncClientSecretCredential | AsyncDefaultAzureCredential
+        if all([conn.login, conn.password, tenant]):
+            credential = AsyncClientSecretCredential(
+                client_id=cast("str", conn.login),
+                client_secret=cast("str", conn.password),
+                tenant_id=cast("str", tenant),
+            )
+        else:
+            managed_identity_client_id = 
conn.extra_dejson.get("managed_identity_client_id")

Review Comment:
   The sync hook supports `key_path`/`key_json` connection extras, but the 
async hook currently ignores them and always uses either client secret creds or 
default creds. This can make `deferrable=True` behave differently (or fail) for 
existing ACI connections that rely on `key_path`/`key_json`. Either add 
equivalent support in the async hook or raise a clear `AirflowException` when 
those extras are present so users understand the limitation.



##########
providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/container_instances.py:
##########
@@ -381,13 +395,36 @@ def execute(self, context: Context) -> int:
                 identity=self.identity,
             )
 
-            self._ci_hook.create_or_update(self.resource_group, self.name, 
container_group)
+            self.hook.create_or_update(self.resource_group, self.name, 
container_group)
 
             self.log.info("Container group started %s/%s", 
self.resource_group, self.name)
 
+            if self.deferrable:
+                cg_state = self.hook.get_state(self.resource_group, self.name)
+                instance_view = cg_state.containers[0].instance_view
+                current_state = (
+                    instance_view.current_state.state
+                    if instance_view is not None
+                    else cg_state.provisioning_state
+                )
+                terminal_states = {"Terminated", "Succeeded", "Failed", 
"Unhealthy"}
+                if current_state not in terminal_states:

Review Comment:
   In deferrable mode, falling back to `cg_state.provisioning_state` when 
`instance_view` is missing can wrongly prevent deferral. Provisioning may 
already be `Succeeded` (which is in `terminal_states`) while the container is 
still running and `instance_view` simply isn't populated yet. To keep 
deferrable behavior correct and consistent with the monitoring logic, do not 
classify provisioning `Succeeded` as terminal for deferral decisions; defer 
until the container's own execution state is terminal.
   ```suggestion
                       if instance_view is not None and 
instance_view.current_state is not None
                       else None
                   )
                   terminal_states = {"Terminated", "Succeeded", "Failed", 
"Unhealthy"}
                   if current_state is None or current_state not in 
terminal_states:
   ```



##########
providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_container_instances.py:
##########
@@ -642,6 +644,317 @@ def test_execute_with_identity_dict(self, aci_mock):
         # user_assigned_identities should contain the resource id as a key
         assert resource_id in (called_cg.identity.user_assigned_identities or 
{})
 
+    @mock.patch(
+        
"airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook",
+        autospec=True,
+    )
+    def test_execute_deferrable_defers_when_container_running(self, aci_mock):
+        """When deferrable=True and the container is still running, defer() is 
called."""
+        running_cg = make_mock_container(state="Running", exit_code=0, 
detail_status="test")
+        aci_mock.return_value.get_state.return_value = running_cg
+        aci_mock.return_value.exists.return_value = False
+
+        aci = AzureContainerInstancesOperator(
+            ci_conn_id="azure_default",
+            registry_conn_id=None,
+            resource_group="resource-group",
+            name="container-name",
+            image="container-image",
+            region="region",
+            task_id="task",
+            deferrable=True,
+            polling_interval=10.0,
+        )
+        with pytest.raises(TaskDeferred) as exc_info:
+            aci.execute(None)
+
+        assert isinstance(exc_info.value.trigger, 
AzureContainerInstanceTrigger)
+        assert exc_info.value.trigger.resource_group == "resource-group"
+        assert exc_info.value.trigger.name == "container-name"
+        assert exc_info.value.trigger.polling_interval == 10.0
+        assert exc_info.value.method_name == "execute_complete"
+        # Container must NOT be deleted when deferring — it is still running
+        assert aci_mock.return_value.delete.call_count == 0
+
+    @mock.patch(
+        
"airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook",
+        autospec=True,
+    )
+    def 
test_execute_deferrable_completes_synchronously_if_already_terminated(self, 
aci_mock):
+        """When deferrable=True but container is already terminal, no deferral 
— sync completion."""
+        terminated_cg = make_mock_container(state="Terminated", exit_code=0, 
detail_status="test")
+        aci_mock.return_value.get_state.return_value = terminated_cg
+        aci_mock.return_value.exists.return_value = False
+
+        aci = AzureContainerInstancesOperator(
+            ci_conn_id="azure_default",
+            registry_conn_id=None,
+            resource_group="resource-group",
+            name="container-name",
+            image="container-image",
+            region="region",
+            task_id="task",
+            deferrable=True,
+        )
+        result = aci.execute(None)
+        assert result == 0
+        assert aci_mock.return_value.create_or_update.call_count == 1
+
+    @mock.patch(
+        
"airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook",
+        autospec=True,
+    )
+    def test_execute_complete_success(self, aci_mock):
+        """execute_complete succeeds, does not raise, and deletes the 
container group."""
+        aci_mock.return_value.get_logs.return_value = None
+
+        aci = AzureContainerInstancesOperator(
+            ci_conn_id="azure_default",
+            registry_conn_id=None,
+            resource_group="resource-group",
+            name="container-name",
+            image="container-image",
+            region="region",
+            task_id="task",
+            deferrable=True,
+            remove_on_success=True,
+        )
+        result = aci.execute_complete(
+            context=None,
+            event={
+                "status": "success",
+                "exit_code": 0,
+                "detail_status": "Completed",
+                "resource_group": "resource-group",
+                "name": "container-name",
+            },
+        )
+        assert result == 0
+        assert aci_mock.return_value.delete.call_count == 1
+
+    @mock.patch(
+        
"airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook",
+        autospec=True,
+    )
+    def test_execute_complete_success_with_remove_on_success_false(self, 
aci_mock):
+        """execute_complete with remove_on_success=False should NOT delete the 
container."""
+        aci_mock.return_value.get_logs.return_value = None
+
+        aci = AzureContainerInstancesOperator(
+            ci_conn_id="azure_default",
+            registry_conn_id=None,
+            resource_group="resource-group",
+            name="container-name",
+            image="container-image",
+            region="region",
+            task_id="task",
+            deferrable=True,
+            remove_on_success=False,
+        )
+        aci.execute_complete(
+            context=None,
+            event={
+                "status": "success",
+                "exit_code": 0,
+                "detail_status": "Completed",
+                "resource_group": "resource-group",
+                "name": "container-name",
+            },
+        )
+        assert aci_mock.return_value.delete.call_count == 0
+
+    @mock.patch(
+        
"airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook",
+        autospec=True,
+    )
+    def test_execute_complete_error_event_raises(self, aci_mock):
+        """execute_complete raises AirflowException when event has 
status=error."""

Review Comment:
   This test docstring says `execute_complete` raises `AirflowException`, but 
the assertion expects `RuntimeError`. Please update the docstring (or the 
implementation) so the documented behavior matches what the test is actually 
verifying.
   ```suggestion
           """execute_complete raises RuntimeError when event has 
status=error."""
   ```



##########
providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/container_instance.py:
##########
@@ -0,0 +1,123 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+from collections.abc import AsyncIterator
+from typing import Any
+
+from airflow.providers.microsoft.azure.hooks.container_instance import 
AzureContainerInstanceAsyncHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+TERMINAL_STATES = frozenset({"Terminated", "Succeeded", "Failed", "Unhealthy"})
+SUCCESS_STATES = frozenset({"Terminated", "Succeeded"})
+
+
+class AzureContainerInstanceTrigger(BaseTrigger):
+    """
+    Poll an Azure Container Instance until it reaches a terminal state.
+
+    :param resource_group: the name of the resource group
+    :param name: the name of the container group
+    :param ci_conn_id: connection id of the Azure service principal
+    :param polling_interval: time in seconds between state polls
+    """
+
+    def __init__(
+        self,
+        resource_group: str,
+        name: str,
+        ci_conn_id: str,
+        polling_interval: float = 30.0,
+    ) -> None:
+        super().__init__()
+        self.resource_group = resource_group
+        self.name = name
+        self.ci_conn_id = ci_conn_id
+        self.polling_interval = polling_interval
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serialize trigger arguments and classpath."""
+        return (
+            
"airflow.providers.microsoft.azure.triggers.container_instance.AzureContainerInstanceTrigger",
+            {
+                "resource_group": self.resource_group,
+                "name": self.name,
+                "ci_conn_id": self.ci_conn_id,
+                "polling_interval": self.polling_interval,
+            },
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        """Poll ACI until a terminal state is reached, then yield a 
TriggerEvent."""
+        hook = AzureContainerInstanceAsyncHook(azure_conn_id=self.ci_conn_id)
+        try:
+            async with hook.get_async_conn() as client:
+                while True:
+                    cg_state = await 
client.container_groups.get(self.resource_group, self.name)
+                    instance_view = cg_state.containers[0].instance_view
+
+                    if instance_view is not None:
+                        c_state = instance_view.current_state
+                        state = c_state.state
+                        exit_code = c_state.exit_code
+                        detail_status = c_state.detail_status
+                    else:
+                        state = cg_state.provisioning_state
+                        exit_code = 0
+                        detail_status = "Provisioning"
+
+                    if state in TERMINAL_STATES:
+                        if state in SUCCESS_STATES and exit_code == 0:
+                            yield TriggerEvent(
+                                {
+                                    "status": "success",
+                                    "exit_code": exit_code,
+                                    "detail_status": detail_status,
+                                    "resource_group": self.resource_group,
+                                    "name": self.name,
+                                }
+                            )
+                        else:
+                            yield TriggerEvent(
+                                {
+                                    "status": "error",
+                                    "exit_code": exit_code,
+                                    "detail_status": detail_status,
+                                    "resource_group": self.resource_group,
+                                    "name": self.name,
+                                    "message": (
+                                        f"Container group 
{self.resource_group}/{self.name} "
+                                        f"reached state {state!r} with exit 
code {exit_code} "
+                                        f"({detail_status})"
+                                    ),
+                                }
+                            )
+                        return
+
+                    await asyncio.sleep(self.polling_interval)

Review Comment:
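   Catching `Exception` would also catch `asyncio.CancelledError` on Python 
versions where it subclasses `Exception` (Python 3.7 and earlier; since Python 
3.8 it derives from `BaseException`). In that case the handler can prevent clean 
trigger cancellation/shutdown and instead emit an error event. Handle 
`asyncio.CancelledError` explicitly (re-raise it) before the generic exception 
handler, so the intent is explicit and cancellation stays safe regardless of 
version.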
   ```suggestion
                       await asyncio.sleep(self.polling_interval)
           except asyncio.CancelledError:
               raise
   ```



##########
providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/container_instances.py:
##########
@@ -193,6 +200,8 @@ def __init__(
         diagnostics: ContainerGroupDiagnostics | None = None,
         priority: str | None = "Regular",
         identity: ContainerGroupIdentity | dict | None = None,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
+        polling_interval: float = 30.0,

Review Comment:
   The PR description states `polling_interval` defaults to `5.0s`, but the 
operator signature defaults to `30.0`. Please align the code default with the 
PR description (or update the PR description) to avoid confusing users and 
reviewers.



##########
providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/container_instances.py:
##########
@@ -407,24 +444,65 @@ def execute(self, context: Context) -> int:
             raise AirflowException("Could not start container group")
 
         finally:
-            if exit_code == 0 or self.remove_on_error:
-                self.on_kill()
+            if _cleanup:
+                if exit_code == 0 and self.remove_on_success:
+                    self.on_kill()
+                elif exit_code != 0 and self.remove_on_error:
+                    self.on_kill()
 
     def on_kill(self) -> None:
         self.log.info("Deleting container group")
         try:
-            self._ci_hook.delete(self.resource_group, self.name)
+            self.hook.delete(self.resource_group, self.name)
         except Exception:
             self.log.exception("Could not delete container group")
 
+    def execute_complete(self, context: Context, event: dict[str, Any] | None) 
-> int:
+        """
+        Handle the trigger event after deferral.
+
+        Called by the Triggerer when the container reaches a terminal state.
+        Raises on failure; returns the exit code on success.
+        """
+        if event is None:
+            raise ValueError("Trigger error: event is None")
+
+        exit_code: int = event.get("exit_code", 1)
+
+        if event["status"] == "error":
+            if self.remove_on_error:
+                self.on_kill()
+            raise RuntimeError(
+                event.get(
+                    "message",
+                    f"Container group {self.resource_group}/{self.name} failed 
with exit code {exit_code}",
+                )
+            )

Review Comment:
   The non-deferrable path raises `AirflowException` on container failure, but 
the deferrable completion path raises a bare `RuntimeError`. For consistency 
with other operators and easier error handling, prefer raising 
`AirflowException` (preserving the same message) here as well.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to