dabla commented on code in PR #62391:
URL: https://github.com/apache/airflow/pull/62391#discussion_r2848983340


##########
providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/compute.py:
##########
@@ -0,0 +1,158 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from collections.abc import Sequence
+from typing import TYPE_CHECKING
+
+from airflow.providers.common.compat.sdk import BaseOperator
+from airflow.providers.microsoft.azure.hooks.compute import AzureComputeHook
+
+if TYPE_CHECKING:
+    from airflow.sdk import Context
+
+
+class AzureVirtualMachineStartOperator(BaseOperator):

Review Comment:
   I would add a BaseAzureVirtualMachineOperator base class that has the following `hook` property:
   
   ```
   @cached_property
   def hook(self) -> AzureComputeHook:
      return AzureComputeHook(azure_conn_id=self.azure_conn_id)
   ```
   
   That way, all concrete operators inherit from BaseAzureVirtualMachineOperator and can
simply call `self.hook`, which is a DRYer approach.



##########
providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/compute.py:
##########
@@ -0,0 +1,90 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+from collections.abc import AsyncIterator
+from typing import Any
+
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class AzureVirtualMachineStateTrigger(BaseTrigger):
+    """
+    Poll the Azure VM power state and yield a TriggerEvent once it matches the 
target.
+
+    Uses the native async Azure SDK client (``azure.mgmt.compute.aio``) so that
+    the triggerer event loop is never blocked.
+
+    :param resource_group_name: Name of the Azure resource group.
+    :param vm_name: Name of the virtual machine.
+    :param target_state: Desired power state, e.g. ``running``, 
``deallocated``.
+    :param azure_conn_id: Azure connection id.
+    :param poke_interval: Polling interval in seconds.
+    """
+
+    def __init__(
+        self,
+        resource_group_name: str,
+        vm_name: str,
+        target_state: str,
+        azure_conn_id: str = "azure_default",
+        poke_interval: float = 30.0,
+    ) -> None:
+        super().__init__()
+        self.resource_group_name = resource_group_name
+        self.vm_name = vm_name
+        self.target_state = target_state
+        self.azure_conn_id = azure_conn_id
+        self.poke_interval = poke_interval
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serialize AzureVirtualMachineStateTrigger arguments and 
classpath."""
+        return (
+            
"airflow.providers.microsoft.azure.triggers.compute.AzureVirtualMachineStateTrigger",

Review Comment:
   I would do it like this, which is safer under refactoring (the classpath stays correct if the class is moved or renamed):
   
   f"{self.__class__.__module__}.{self.__class__.__name__}",
   
   Instead of:
   
   
"airflow.providers.microsoft.azure.triggers.compute.AzureVirtualMachineStateTrigger",



##########
providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/compute.py:
##########
@@ -0,0 +1,90 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+from collections.abc import AsyncIterator
+from typing import Any
+
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class AzureVirtualMachineStateTrigger(BaseTrigger):
+    """
+    Poll the Azure VM power state and yield a TriggerEvent once it matches the 
target.
+
+    Uses the native async Azure SDK client (``azure.mgmt.compute.aio``) so that
+    the triggerer event loop is never blocked.
+
+    :param resource_group_name: Name of the Azure resource group.
+    :param vm_name: Name of the virtual machine.
+    :param target_state: Desired power state, e.g. ``running``, 
``deallocated``.
+    :param azure_conn_id: Azure connection id.
+    :param poke_interval: Polling interval in seconds.
+    """
+
+    def __init__(
+        self,
+        resource_group_name: str,
+        vm_name: str,
+        target_state: str,
+        azure_conn_id: str = "azure_default",
+        poke_interval: float = 30.0,
+    ) -> None:
+        super().__init__()
+        self.resource_group_name = resource_group_name
+        self.vm_name = vm_name
+        self.target_state = target_state
+        self.azure_conn_id = azure_conn_id
+        self.poke_interval = poke_interval
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serialize AzureVirtualMachineStateTrigger arguments and 
classpath."""
+        return (
+            
"airflow.providers.microsoft.azure.triggers.compute.AzureVirtualMachineStateTrigger",
+            {
+                "resource_group_name": self.resource_group_name,
+                "vm_name": self.vm_name,
+                "target_state": self.target_state,
+                "azure_conn_id": self.azure_conn_id,
+                "poke_interval": self.poke_interval,
+            },
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        """Poll VM power state asynchronously until it matches the target 
state."""
+        from airflow.providers.microsoft.azure.hooks.compute import 
AzureComputeHook
+
+        try:
+            async with AzureComputeHook(azure_conn_id=self.azure_conn_id) as 
hook:
+                while True:
+                    power_state = await 
hook.async_get_power_state(self.resource_group_name, self.vm_name)
+                    if power_state == self.target_state:
+                        message = f"VM {self.vm_name} reached state 
'{self.target_state}'."
+                        yield TriggerEvent({"status": "success", "message": 
message})
+                        return
+                    self.log.info(
+                        "VM %s power state: %s. Waiting for %s. Sleeping for 
%s seconds.",
+                        self.vm_name,
+                        power_state,
+                        self.target_state,
+                        self.poke_interval,
+                    )
+                    await asyncio.sleep(self.poke_interval)
+        except Exception as e:
+            yield TriggerEvent({"status": "error", "message": str(e)})

Review Comment:
   Probably best to also add a `return` statement here, for consistency with the success branch.



##########
providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_compute.py:
##########
@@ -0,0 +1,121 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.providers.microsoft.azure.triggers.compute import 
AzureVirtualMachineStateTrigger
+from airflow.triggers.base import TriggerEvent
+
+RESOURCE_GROUP = "test-rg"
+VM_NAME = "test-vm"
+TARGET_STATE = "running"
+CONN_ID = "azure_default"
+POKE_INTERVAL = 10.0
+
+
+class TestAzureVirtualMachineStateTrigger:
+    def test_serialize(self):
+        trigger = AzureVirtualMachineStateTrigger(
+            resource_group_name=RESOURCE_GROUP,
+            vm_name=VM_NAME,
+            target_state=TARGET_STATE,
+            azure_conn_id=CONN_ID,
+            poke_interval=POKE_INTERVAL,
+        )
+
+        class_path, args = trigger.serialize()

Review Comment:
   I would change the asserts here to:
   
   ```
   actual = trigger.serialize()
   
   assert isinstance(actual, tuple)
   assert actual[0] == f"{AzureVirtualMachineStateTrigger.__module__}.{AzureVirtualMachineStateTrigger.__name__}"
   assert actual[1] == {
       "resource_group_name": RESOURCE_GROUP,
       "vm_name": VM_NAME,
       "target_state": TARGET_STATE,
       "azure_conn_id": CONN_ID,
       "poke_interval": POKE_INTERVAL,
   }
   ```



##########
providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/compute.py:
##########
@@ -0,0 +1,237 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from functools import cached_property
+from typing import Any, cast
+
+from azure.common.client_factory import get_client_from_auth_file, 
get_client_from_json_dict
+from azure.identity import ClientSecretCredential, DefaultAzureCredential
+from azure.identity.aio import (
+    ClientSecretCredential as AsyncClientSecretCredential,
+    DefaultAzureCredential as AsyncDefaultAzureCredential,
+)
+from azure.mgmt.compute import ComputeManagementClient
+from azure.mgmt.compute.aio import ComputeManagementClient as 
AsyncComputeManagementClient
+
+from airflow.providers.common.compat.connection import get_async_connection
+from airflow.providers.microsoft.azure.hooks.base_azure import AzureBaseHook
+from airflow.providers.microsoft.azure.utils import (
+    get_async_default_azure_credential,
+    get_sync_default_azure_credential,
+)
+
+
+class AzureComputeHook(AzureBaseHook):
+    """
+    A hook to interact with Azure Compute to manage Virtual Machines.
+
+    :param azure_conn_id: :ref:`Azure connection id<howto/connection:azure>` of
+        a service principal which will be used to manage virtual machines.
+    """
+
+    conn_name_attr = "azure_conn_id"
+    default_conn_name = "azure_default"
+    conn_type = "azure_compute"
+    hook_name = "Azure Compute"
+
+    def __init__(self, azure_conn_id: str = default_conn_name) -> None:
+        super().__init__(sdk_client=ComputeManagementClient, 
conn_id=azure_conn_id)
+        self._async_conn: AsyncComputeManagementClient | None = None
+
+    @cached_property
+    def connection(self) -> ComputeManagementClient:
+        return self.get_conn()
+
+    def get_conn(self) -> Any:

Review Comment:
   I would use `ComputeManagementClient` as the return type here instead of `Any`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to