Copilot commented on code in PR #62772: URL: https://github.com/apache/airflow/pull/62772#discussion_r3066478341
########## providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/container_instance.py: ########## @@ -0,0 +1,123 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio +from collections.abc import AsyncIterator +from typing import Any + +from airflow.providers.microsoft.azure.hooks.container_instance import AzureContainerInstanceAsyncHook +from airflow.triggers.base import BaseTrigger, TriggerEvent + +TERMINAL_STATES = frozenset({"Terminated", "Succeeded", "Failed", "Unhealthy"}) +SUCCESS_STATES = frozenset({"Terminated", "Succeeded"}) + + +class AzureContainerInstanceTrigger(BaseTrigger): + """ + Poll an Azure Container Instance until it reaches a terminal state. + + :param resource_group: the name of the resource group + :param name: the name of the container group + :param ci_conn_id: connection id of the Azure service principal + :param polling_interval: time in seconds between state polls + """ + + def __init__( + self, + resource_group: str, + name: str, + ci_conn_id: str, + polling_interval: float = 30.0, + ) -> None: + super().__init__() + self.resource_group = resource_group + self.name = name + self.ci_conn_id = ci_conn_id + self.polling_interval = polling_interval + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize trigger arguments and classpath.""" + return ( + "airflow.providers.microsoft.azure.triggers.container_instance.AzureContainerInstanceTrigger", + { + "resource_group": self.resource_group, + "name": self.name, + "ci_conn_id": self.ci_conn_id, + "polling_interval": self.polling_interval, + }, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + """Poll ACI until a terminal state is reached, then yield a TriggerEvent.""" + hook = AzureContainerInstanceAsyncHook(azure_conn_id=self.ci_conn_id) + try: + async with hook.get_async_conn() as client: + while True: + cg_state = await client.container_groups.get(self.resource_group, self.name) + instance_view = cg_state.containers[0].instance_view + + if instance_view is not None: + c_state = instance_view.current_state + state = c_state.state + exit_code = c_state.exit_code + detail_status = c_state.detail_status + else: + state = cg_state.provisioning_state + exit_code = 0 + detail_status = "Provisioning" + + if state in TERMINAL_STATES: + if state in SUCCESS_STATES and exit_code == 0: Review Comment: When `instance_view` is `None`, using `cg_state.provisioning_state` to decide terminal success can complete the trigger early (e.g., provisioning state may become `Succeeded` before the container has actually finished). Consider treating `instance_view is None` as non-terminal except for explicit provisioning failure states (e.g., `Failed`/`Unhealthy`), and only using container `current_state` for success/exit-code decisions. ########## providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/container_instance.py: ########## @@ -172,3 +185,84 @@ def test_connection(self): return False, str(e) return True, "Successfully connected to Azure Container Instance." + + +class AzureContainerInstanceAsyncHook(AzureContainerInstanceHook): + """ + An async hook for communicating with Azure Container Instances. + + :param azure_conn_id: :ref:`Azure connection id<howto/connection:azure>` of + a service principal which will be used to start the container instance. + """ + + def __init__(self, azure_conn_id: str = AzureContainerInstanceHook.default_conn_name) -> None: + super().__init__(azure_conn_id=azure_conn_id) + + @asynccontextmanager + async def get_async_conn(self) -> AsyncGenerator[AsyncContainerInstanceManagementClient, None]: + """Create an async management client bound to a single credential.""" + conn = await get_async_connection(self.conn_id) + tenant = conn.extra_dejson.get("tenantId") + subscription_id = cast("str", conn.extra_dejson.get("subscriptionId")) + + credential: AsyncClientSecretCredential | AsyncDefaultAzureCredential + if all([conn.login, conn.password, tenant]): + credential = AsyncClientSecretCredential( + client_id=cast("str", conn.login), + client_secret=cast("str", conn.password), + tenant_id=cast("str", tenant), + ) + else: + managed_identity_client_id = conn.extra_dejson.get("managed_identity_client_id") Review Comment: The sync hook supports `key_path`/`key_json` connection extras, but the async hook currently ignores them and always uses either client secret creds or default creds. This can make `deferrable=True` behave differently (or fail) for existing ACI connections that rely on `key_path`/`key_json`. Either add equivalent support in the async hook or raise a clear `AirflowException` when those extras are present so users understand the limitation. ########## providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/container_instances.py: ########## @@ -381,13 +395,36 @@ def execute(self, context: Context) -> int: identity=self.identity, ) - self._ci_hook.create_or_update(self.resource_group, self.name, container_group) + self.hook.create_or_update(self.resource_group, self.name, container_group) self.log.info("Container group started %s/%s", self.resource_group, self.name) + if self.deferrable: + cg_state = self.hook.get_state(self.resource_group, self.name) + instance_view = cg_state.containers[0].instance_view + current_state = ( + instance_view.current_state.state + if instance_view is not None + else cg_state.provisioning_state + ) + terminal_states = {"Terminated", "Succeeded", "Failed", "Unhealthy"} + if current_state not in terminal_states: Review Comment: In deferrable mode, treating `cg_state.provisioning_state` as a terminal state can prevent deferral (e.g., provisioning may become `Succeeded` while the container is still running but `instance_view` isn't populated yet). To keep deferrable behavior correct and consistent with monitoring logic, consider not classifying provisioning `Succeeded` as terminal for deferral decisions; defer until container execution state is terminal. ```suggestion if instance_view is not None and instance_view.current_state is not None else None ) terminal_states = {"Terminated", "Succeeded", "Failed", "Unhealthy"} if current_state is None or current_state not in terminal_states: ``` ########## providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_container_instances.py: ########## @@ -642,6 +644,317 @@ def test_execute_with_identity_dict(self, aci_mock): # user_assigned_identities should contain the resource id as a key assert resource_id in (called_cg.identity.user_assigned_identities or {}) + @mock.patch( + "airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook", + autospec=True, + ) + def test_execute_deferrable_defers_when_container_running(self, aci_mock): + """When deferrable=True and the container is still running, defer() is called.""" + running_cg = make_mock_container(state="Running", exit_code=0, detail_status="test") + aci_mock.return_value.get_state.return_value = running_cg + aci_mock.return_value.exists.return_value = False + + aci = AzureContainerInstancesOperator( + ci_conn_id="azure_default", + registry_conn_id=None, + resource_group="resource-group", + name="container-name", + image="container-image", + region="region", + task_id="task", + deferrable=True, + polling_interval=10.0, + ) + with pytest.raises(TaskDeferred) as exc_info: + aci.execute(None) + + assert isinstance(exc_info.value.trigger, AzureContainerInstanceTrigger) + assert exc_info.value.trigger.resource_group == "resource-group" + assert exc_info.value.trigger.name == "container-name" + assert exc_info.value.trigger.polling_interval == 10.0 + assert exc_info.value.method_name == "execute_complete" + # Container must NOT be deleted when deferring — it is still running + assert aci_mock.return_value.delete.call_count == 0 + + @mock.patch( + "airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook", + autospec=True, + ) + def test_execute_deferrable_completes_synchronously_if_already_terminated(self, aci_mock): + """When deferrable=True but container is already terminal, no deferral — sync completion.""" + terminated_cg = make_mock_container(state="Terminated", exit_code=0, detail_status="test") + aci_mock.return_value.get_state.return_value = terminated_cg + aci_mock.return_value.exists.return_value = False + + aci = AzureContainerInstancesOperator( + ci_conn_id="azure_default", + registry_conn_id=None, + resource_group="resource-group", + name="container-name", + image="container-image", + region="region", + task_id="task", + deferrable=True, + ) + result = aci.execute(None) + assert result == 0 + assert aci_mock.return_value.create_or_update.call_count == 1 + + @mock.patch( + "airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook", + autospec=True, + ) + def test_execute_complete_success(self, aci_mock): + """execute_complete succeeds, does not raise, and deletes the container group.""" + aci_mock.return_value.get_logs.return_value = None + + aci = AzureContainerInstancesOperator( + ci_conn_id="azure_default", + registry_conn_id=None, + resource_group="resource-group", + name="container-name", + image="container-image", + region="region", + task_id="task", + deferrable=True, + remove_on_success=True, + ) + result = aci.execute_complete( + context=None, + event={ + "status": "success", + "exit_code": 0, + "detail_status": "Completed", + "resource_group": "resource-group", + "name": "container-name", + }, + ) + assert result == 0 + assert aci_mock.return_value.delete.call_count == 1 + + @mock.patch( + "airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook", + autospec=True, + ) + def test_execute_complete_success_with_remove_on_success_false(self, aci_mock): + """execute_complete with remove_on_success=False should NOT delete the container.""" + aci_mock.return_value.get_logs.return_value = None + + aci = AzureContainerInstancesOperator( + ci_conn_id="azure_default", + registry_conn_id=None, + resource_group="resource-group", + name="container-name", + image="container-image", + region="region", + task_id="task", + deferrable=True, + remove_on_success=False, + ) + aci.execute_complete( + context=None, + event={ + "status": "success", + "exit_code": 0, + "detail_status": "Completed", + "resource_group": "resource-group", + "name": "container-name", + }, + ) + assert aci_mock.return_value.delete.call_count == 0 + + @mock.patch( + "airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook", + autospec=True, + ) + def test_execute_complete_error_event_raises(self, aci_mock): + """execute_complete raises AirflowException when event has status=error.""" Review Comment: This test docstring says `execute_complete` raises `AirflowException`, but the assertion expects `RuntimeError`. Please update the docstring (or the implementation) so the documented behavior matches what the test is actually verifying. ```suggestion """execute_complete raises RuntimeError when event has status=error.""" ``` ########## providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/container_instance.py: ########## @@ -0,0 +1,123 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio +from collections.abc import AsyncIterator +from typing import Any + +from airflow.providers.microsoft.azure.hooks.container_instance import AzureContainerInstanceAsyncHook +from airflow.triggers.base import BaseTrigger, TriggerEvent + +TERMINAL_STATES = frozenset({"Terminated", "Succeeded", "Failed", "Unhealthy"}) +SUCCESS_STATES = frozenset({"Terminated", "Succeeded"}) + + +class AzureContainerInstanceTrigger(BaseTrigger): + """ + Poll an Azure Container Instance until it reaches a terminal state. + + :param resource_group: the name of the resource group + :param name: the name of the container group + :param ci_conn_id: connection id of the Azure service principal + :param polling_interval: time in seconds between state polls + """ + + def __init__( + self, + resource_group: str, + name: str, + ci_conn_id: str, + polling_interval: float = 30.0, + ) -> None: + super().__init__() + self.resource_group = resource_group + self.name = name + self.ci_conn_id = ci_conn_id + self.polling_interval = polling_interval + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize trigger arguments and classpath.""" + return ( + "airflow.providers.microsoft.azure.triggers.container_instance.AzureContainerInstanceTrigger", + { + "resource_group": self.resource_group, + "name": self.name, + "ci_conn_id": self.ci_conn_id, + "polling_interval": self.polling_interval, + }, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + """Poll ACI until a terminal state is reached, then yield a TriggerEvent.""" + hook = AzureContainerInstanceAsyncHook(azure_conn_id=self.ci_conn_id) + try: + async with hook.get_async_conn() as client: + while True: + cg_state = await client.container_groups.get(self.resource_group, self.name) + instance_view = cg_state.containers[0].instance_view + + if instance_view is not None: + c_state = instance_view.current_state + state = c_state.state + exit_code = c_state.exit_code + detail_status = c_state.detail_status + else: + state = cg_state.provisioning_state + exit_code = 0 + detail_status = "Provisioning" + + if state in TERMINAL_STATES: + if state in SUCCESS_STATES and exit_code == 0: + yield TriggerEvent( + { + "status": "success", + "exit_code": exit_code, + "detail_status": detail_status, + "resource_group": self.resource_group, + "name": self.name, + } + ) + else: + yield TriggerEvent( + { + "status": "error", + "exit_code": exit_code, + "detail_status": detail_status, + "resource_group": self.resource_group, + "name": self.name, + "message": ( + f"Container group {self.resource_group}/{self.name} " + f"reached state {state!r} with exit code {exit_code} " + f"({detail_status})" + ), + } + ) + return + + await asyncio.sleep(self.polling_interval) Review Comment: Catching `Exception` will also catch `asyncio.CancelledError` on Python versions where it subclasses `Exception`, which can prevent clean trigger cancellation/shutdown and instead emit an error event. Handle `asyncio.CancelledError` explicitly (re-raise) before the generic exception handler. ```suggestion await asyncio.sleep(self.polling_interval) except asyncio.CancelledError: raise ``` ########## providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/container_instances.py: ########## @@ -193,6 +200,8 @@ def __init__( diagnostics: ContainerGroupDiagnostics | None = None, priority: str | None = "Regular", identity: ContainerGroupIdentity | dict | None = None, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + polling_interval: float = 30.0, Review Comment: The PR description states `polling_interval` defaults to `5.0s`, but the operator signature defaults to `30.0`. Please align the code default with the PR description (or update the PR description) to avoid confusing users and reviewers. ########## providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/container_instances.py: ########## @@ -407,24 +444,65 @@ def execute(self, context: Context) -> int: raise AirflowException("Could not start container group") finally: - if exit_code == 0 or self.remove_on_error: - self.on_kill() + if _cleanup: + if exit_code == 0 and self.remove_on_success: + self.on_kill() + elif exit_code != 0 and self.remove_on_error: + self.on_kill() def on_kill(self) -> None: self.log.info("Deleting container group") try: - self._ci_hook.delete(self.resource_group, self.name) + self.hook.delete(self.resource_group, self.name) except Exception: self.log.exception("Could not delete container group") + def execute_complete(self, context: Context, event: dict[str, Any] | None) -> int: + """ + Handle the trigger event after deferral. + + Called by the Triggerer when the container reaches a terminal state. + Raises on failure; returns the exit code on success. + """ + if event is None: + raise ValueError("Trigger error: event is None") + + exit_code: int = event.get("exit_code", 1) + + if event["status"] == "error": + if self.remove_on_error: + self.on_kill() + raise RuntimeError( + event.get( + "message", + f"Container group {self.resource_group}/{self.name} failed with exit code {exit_code}", + ) + ) Review Comment: The non-deferrable path raises `AirflowException` on container failure, but the deferrable completion path raises a bare `RuntimeError`. For consistency with other operators and easier error handling, prefer raising `AirflowException` (preserving the same message) here as well. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
