This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9d5327806f Bump azure-mgmt-containerinstance>=7.0.0,<9.0.0 (#33696)
9d5327806f is described below

commit 9d5327806fac61cd62abd30a6339b0cb26ad1ebf
Author: Pankaj Singh <[email protected]>
AuthorDate: Mon Aug 28 16:39:11 2023 +0530

    Bump azure-mgmt-containerinstance>=7.0.0,<9.0.0 (#33696)
---
 .../microsoft/azure/hooks/container_instance.py    |  68 +++++++++++--
 .../azure/operators/container_instances.py         |   3 +-
 airflow/providers/microsoft/azure/provider.yaml    |   2 +-
 generated/provider_dependencies.json               |   2 +-
 .../azure/hooks/test_azure_container_instance.py   |   8 +-
 .../operators/test_azure_container_instances.py    | 106 +++++++++++++--------
 6 files changed, 136 insertions(+), 53 deletions(-)

diff --git a/airflow/providers/microsoft/azure/hooks/container_instance.py b/airflow/providers/microsoft/azure/hooks/container_instance.py
index 77e24ab944..cdebe7a241 100644
--- a/airflow/providers/microsoft/azure/hooks/container_instance.py
+++ b/airflow/providers/microsoft/azure/hooks/container_instance.py
@@ -19,11 +19,14 @@ from __future__ import annotations
 
 import warnings
 from functools import cached_property
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
+from azure.common.client_factory import get_client_from_auth_file, get_client_from_json_dict
+from azure.common.credentials import ServicePrincipalCredentials
+from azure.identity import DefaultAzureCredential
 from azure.mgmt.containerinstance import ContainerInstanceManagementClient
 
-from airflow.exceptions import AirflowProviderDeprecationWarning
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.microsoft.azure.hooks.base_azure import AzureBaseHook
 
 if TYPE_CHECKING:
@@ -56,6 +59,59 @@ class AzureContainerInstanceHook(AzureBaseHook):
     def connection(self):
         return self.get_conn()
 
+    def get_conn(self) -> Any:
+        """
+        Authenticates the resource using the connection id passed during init.
+
+        :return: the authenticated client.
+        """
+        conn = self.get_connection(self.conn_id)
+        tenant = conn.extra_dejson.get("tenantId")
+        if not tenant and conn.extra_dejson.get("extra__azure__tenantId"):
+            warnings.warn(
+                "`extra__azure__tenantId` is deprecated in azure connection extra, "
+                "please use `tenantId` instead",
+                AirflowProviderDeprecationWarning,
+                stacklevel=2,
+            )
+            tenant = conn.extra_dejson.get("extra__azure__tenantId")
+        subscription_id = conn.extra_dejson.get("subscriptionId")
+        if not subscription_id and conn.extra_dejson.get("extra__azure__subscriptionId"):
+            warnings.warn(
+                "`extra__azure__subscriptionId` is deprecated in azure connection extra, "
+                "please use `subscriptionId` instead",
+                AirflowProviderDeprecationWarning,
+                stacklevel=2,
+            )
+            subscription_id = conn.extra_dejson.get("extra__azure__subscriptionId")
+
+        key_path = conn.extra_dejson.get("key_path")
+        if key_path:
+            if not key_path.endswith(".json"):
+                raise AirflowException("Unrecognised extension for key file.")
+            self.log.info("Getting connection using a JSON key file.")
+            return get_client_from_auth_file(client_class=self.sdk_client, auth_path=key_path)
+
+        key_json = conn.extra_dejson.get("key_json")
+        if key_json:
+            self.log.info("Getting connection using a JSON config.")
+            return get_client_from_json_dict(client_class=self.sdk_client, config_dict=key_json)
+
+        credential: ServicePrincipalCredentials | DefaultAzureCredential
+        if all([conn.login, conn.password, tenant]):
+            self.log.info("Getting connection using specific credentials and subscription_id.")
+            credential = ServicePrincipalCredentials(
+                client_id=conn.login, secret=conn.password, tenant=tenant
+            )
+        else:
+            self.log.info("Using DefaultAzureCredential as credential")
+            credential = DefaultAzureCredential()
+
+        return ContainerInstanceManagementClient(
+            credential=credential,
+            subscription_id=subscription_id,
+        )
+
     def create_or_update(self, resource_group: str, name: str, container_group: ContainerGroup) -> None:
         """
         Create a new container group.
@@ -64,7 +120,7 @@ class AzureContainerInstanceHook(AzureBaseHook):
         :param name: the name of the container group
         :param container_group: the properties of the container group
         """
-        self.connection.container_groups.create_or_update(resource_group, name, container_group)
+        self.connection.container_groups.begin_create_or_update(resource_group, name, container_group)
 
     def get_state_exitcode_details(self, resource_group: str, name: str) -> tuple:
         """
@@ -109,7 +165,7 @@ class AzureContainerInstanceHook(AzureBaseHook):
         :param name: the name of the container group
         :return: ContainerGroup
         """
-        return self.connection.container_groups.get(resource_group, name, raw=False)
+        return self.connection.container_groups.get(resource_group, name)
 
     def get_logs(self, resource_group: str, name: str, tail: int = 1000) -> list:
         """
@@ -120,7 +176,7 @@ class AzureContainerInstanceHook(AzureBaseHook):
         :param tail: the size of the tail
         :return: A list of log messages
         """
-        logs = self.connection.container.list_logs(resource_group, name, name, tail=tail)
+        logs = self.connection.containers.list_logs(resource_group, name, name, tail=tail)
         return logs.content.splitlines(True)
 
     def delete(self, resource_group: str, name: str) -> None:
@@ -130,7 +186,7 @@ class AzureContainerInstanceHook(AzureBaseHook):
         :param resource_group: the name of the resource group
         :param name: the name of the container group
         """
-        self.connection.container_groups.delete(resource_group, name)
+        self.connection.container_groups.begin_delete(resource_group, name)
 
     def exists(self, resource_group: str, name: str) -> bool:
         """
diff --git a/airflow/providers/microsoft/azure/operators/container_instances.py b/airflow/providers/microsoft/azure/operators/container_instances.py
index 8b5fba26cd..efbe8259b1 100644
--- a/airflow/providers/microsoft/azure/operators/container_instances.py
+++ b/airflow/providers/microsoft/azure/operators/container_instances.py
@@ -126,11 +126,11 @@ class AzureContainerInstancesOperator(BaseOperator):
         self,
         *,
         ci_conn_id: str,
-        registry_conn_id: str | None,
         resource_group: str,
         name: str,
         image: str,
         region: str,
+        registry_conn_id: str | None = None,
         environment_variables: dict | None = None,
         secured_variables: str | None = None,
         volumes: list | None = None,
@@ -295,7 +295,6 @@ class AzureContainerInstancesOperator(BaseOperator):
             try:
                 cg_state = self._ci_hook.get_state(resource_group, name)
                 instance_view = cg_state.containers[0].instance_view
-
                 # If there is no instance view, we show the provisioning state
                 if instance_view is not None:
                     c_state = instance_view.current_state
diff --git a/airflow/providers/microsoft/azure/provider.yaml b/airflow/providers/microsoft/azure/provider.yaml
index e3ef3db6f8..2c9868bd84 100644
--- a/airflow/providers/microsoft/azure/provider.yaml
+++ b/airflow/providers/microsoft/azure/provider.yaml
@@ -83,7 +83,7 @@ dependencies:
   - azure-kusto-data>=4.1.0
   # TODO: upgrade to newer versions of all the below libraries.
   #   See issue https://github.com/apache/airflow/issues/30199
-  - azure-mgmt-containerinstance>=1.5.0,<2.0
+  - azure-mgmt-containerinstance>=7.0.0,<9.0.0
   - azure-mgmt-datafactory>=1.0.0,<2.0
 
 integrations:
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index bfe118e289..b82c592b4f 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -559,7 +559,7 @@
       "azure-identity>=1.3.1",
       "azure-keyvault-secrets>=4.1.0",
       "azure-kusto-data>=4.1.0",
-      "azure-mgmt-containerinstance>=1.5.0,<2.0",
+      "azure-mgmt-containerinstance>=7.0.0,<9.0.0",
       "azure-mgmt-cosmosdb",
       "azure-mgmt-datafactory>=1.0.0,<2.0",
       "azure-mgmt-datalake-store>=0.5.0",
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py b/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py
index 786df4eb16..1e7e19a138 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py
@@ -52,7 +52,7 @@ class TestAzureContainerInstanceHook:
             yield
 
     @patch("azure.mgmt.containerinstance.models.ContainerGroup")
-    @patch("azure.mgmt.containerinstance.operations.ContainerGroupsOperations.create_or_update")
+    @patch("azure.mgmt.containerinstance.operations.ContainerGroupsOperations.begin_create_or_update")
     def test_create_or_update(self, create_or_update_mock, container_group_mock):
         self.hook.create_or_update("resource_group", "aci-test", container_group_mock)
         create_or_update_mock.assert_called_once_with("resource_group", "aci-test", container_group_mock)
@@ -60,9 +60,9 @@ class TestAzureContainerInstanceHook:
     @patch("azure.mgmt.containerinstance.operations.ContainerGroupsOperations.get")
     def test_get_state(self, get_state_mock):
         self.hook.get_state("resource_group", "aci-test")
-        get_state_mock.assert_called_once_with("resource_group", "aci-test", raw=False)
+        get_state_mock.assert_called_once_with("resource_group", "aci-test")
 
-    @patch("azure.mgmt.containerinstance.operations.ContainerOperations.list_logs")
+    @patch("azure.mgmt.containerinstance.operations.ContainersOperations.list_logs")
     def test_get_logs(self, list_logs_mock):
         expected_messages = ["log line 1\n", "log line 2\n", "log line 3\n"]
         logs = Logs(content="".join(expected_messages))
@@ -72,7 +72,7 @@ class TestAzureContainerInstanceHook:
 
         assert logs == expected_messages
 
-    @patch("azure.mgmt.containerinstance.operations.ContainerGroupsOperations.delete")
+    @patch("azure.mgmt.containerinstance.operations.ContainerGroupsOperations.begin_delete")
     def test_delete(self, delete_mock):
         self.hook.delete("resource_group", "aci-test")
         delete_mock.assert_called_once_with("resource_group", "aci-test")
diff --git a/tests/providers/microsoft/azure/operators/test_azure_container_instances.py b/tests/providers/microsoft/azure/operators/test_azure_container_instances.py
index 6c7cd1dbb5..2984fe60ee 100644
--- a/tests/providers/microsoft/azure/operators/test_azure_container_instances.py
+++ b/tests/providers/microsoft/azure/operators/test_azure_container_instances.py
@@ -22,7 +22,13 @@ from unittest import mock
 from unittest.mock import MagicMock
 
 import pytest
-from azure.mgmt.containerinstance.models import ContainerState, Event
+from azure.mgmt.containerinstance.models import (
+    Container,
+    ContainerGroup,
+    ContainerPropertiesInstanceView,
+    ContainerState,
+    Event,
+)
 
 from airflow.exceptions import AirflowException
 from airflow.providers.microsoft.azure.operators.container_instances import AzureContainerInstancesOperator
@@ -35,10 +41,12 @@ def make_mock_cg(container_state, events=None):
     """
     events = events or []
     instance_view_dict = {"current_state": container_state, "events": events}
-    instance_view = namedtuple("InstanceView", instance_view_dict.keys())(*instance_view_dict.values())
+    instance_view = namedtuple("ContainerPropertiesInstanceView", instance_view_dict.keys())(
+        *instance_view_dict.values()
+    )
 
     container_dict = {"instance_view": instance_view}
-    container = namedtuple("Container", container_dict.keys())(*container_dict.values())
+    container = namedtuple("Containers", container_dict.keys())(*container_dict.values())
 
     container_g_dict = {"containers": [container]}
     container_g = namedtuple("ContainerGroup", container_g_dict.keys())(*container_g_dict.values())
@@ -53,23 +61,42 @@ def make_mock_cg_with_missing_events(container_state):
     This can happen, when the container group is provisioned, but not started.
     """
     instance_view_dict = {"current_state": container_state, "events": None}
-    instance_view = namedtuple("InstanceView", instance_view_dict.keys())(*instance_view_dict.values())
+    instance_view = namedtuple("ContainerPropertiesInstanceView", instance_view_dict.keys())(
+        *instance_view_dict.values()
+    )
 
     container_dict = {"instance_view": instance_view}
-    container = namedtuple("Container", container_dict.keys())(*container_dict.values())
+    container = namedtuple("Containers", container_dict.keys())(*container_dict.values())
 
     container_g_dict = {"containers": [container]}
     container_g = namedtuple("ContainerGroup", container_g_dict.keys())(*container_g_dict.values())
     return container_g
 
 
+def make_mock_container(state: str, exit_code: int, detail_status: str, events: Event | None = None):
+    container = Container(name="hello_world", image="test", resources="test")
+    container_prop = ContainerPropertiesInstanceView()
+    container_state = ContainerState()
+    container_state.state = state
+    container_state.exit_code = exit_code
+    container_state.detail_status = detail_status
+    container_prop.current_state = container_state
+    if events:
+        container_prop.events = events
+    container.instance_view = container_prop
+
+    cg = ContainerGroup(containers=[container], os_type="Linux")
+
+    return cg
+
+
 class TestACIOperator:
     @mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook")
     def test_execute(self, aci_mock):
-        expected_c_state = ContainerState(state="Terminated", exit_code=0, detail_status="test")
-        expected_cg = make_mock_cg(expected_c_state)
+        expected_cg = make_mock_container(state="Terminated", exit_code=0, detail_status="test")
 
         aci_mock.return_value.get_state.return_value = expected_cg
+
         aci_mock.return_value.exists.return_value = False
 
         aci = AzureContainerInstancesOperator(
@@ -102,10 +129,10 @@ class TestACIOperator:
 
     @mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook")
     def test_execute_with_failures(self, aci_mock):
-        expected_c_state = ContainerState(state="Terminated", exit_code=1, detail_status="test")
-        expected_cg = make_mock_cg(expected_c_state)
 
+        expected_cg = make_mock_container(state="Terminated", exit_code=1, detail_status="test")
         aci_mock.return_value.get_state.return_value = expected_cg
+
         aci_mock.return_value.exists.return_value = False
 
         aci = AzureContainerInstancesOperator(
@@ -124,11 +151,11 @@ class TestACIOperator:
 
     @mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook")
     def test_execute_with_tags(self, aci_mock):
-        expected_c_state = ContainerState(state="Terminated", exit_code=0, detail_status="test")
-        expected_cg = make_mock_cg(expected_c_state)
-        tags = {"testKey": "testValue"}
 
+        expected_cg = make_mock_container(state="Terminated", exit_code=0, detail_status="test")
         aci_mock.return_value.get_state.return_value = expected_cg
+        tags = {"testKey": "testValue"}
+
         aci_mock.return_value.exists.return_value = False
 
         aci = AzureContainerInstancesOperator(
@@ -163,13 +190,18 @@ class TestACIOperator:
 
     @mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook")
     def test_execute_with_messages_logs(self, aci_mock):
-        events = [Event(message="test"), Event(message="messages")]
-        expected_c_state1 = ContainerState(state="Succeeded", exit_code=0, detail_status="test")
-        expected_cg1 = make_mock_cg(expected_c_state1, events)
-        expected_c_state2 = ContainerState(state="Running", exit_code=0, detail_status="test")
-        expected_cg2 = make_mock_cg(expected_c_state2, events)
-        expected_c_state3 = ContainerState(state="Terminated", exit_code=0, detail_status="test")
-        expected_cg3 = make_mock_cg(expected_c_state3, events)
+        event1 = Event()
+        event1.message = "test"
+        event2 = Event()
+        event2.message = "messages"
+        events = [event1, event2]
+        expected_cg1 = make_mock_container(
+            state="Succeeded", exit_code=0, detail_status="test", events=events
+        )
+        expected_cg2 = make_mock_container(state="Running", exit_code=0, detail_status="test", events=events)
+        expected_cg3 = make_mock_container(
+            state="Terminated", exit_code=0, detail_status="test", events=events
+        )
 
         aci_mock.return_value.get_state.side_effect = [expected_cg1, expected_cg2, expected_cg3]
         aci_mock.return_value.get_logs.return_value = ["test", "logs"]
@@ -211,11 +243,11 @@ class TestACIOperator:
 
     @mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook")
     def test_execute_with_ipaddress(self, aci_mock):
-        expected_c_state = ContainerState(state="Terminated", exit_code=0, detail_status="test")
-        expected_cg = make_mock_cg(expected_c_state)
         ipaddress = MagicMock()
 
-        aci_mock.return_value.get_state.return_value = expected_cg
+        aci_mock.return_value.get_state.return_value = make_mock_container(
+            state="Terminated", exit_code=0, detail_status="test"
+        )
         aci_mock.return_value.exists.return_value = False
 
         aci = AzureContainerInstancesOperator(
@@ -236,10 +268,10 @@ class TestACIOperator:
 
     @mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook")
     def test_execute_with_windows_os_and_diff_restart_policy(self, aci_mock):
-        expected_c_state = ContainerState(state="Terminated", exit_code=0, detail_status="test")
-        expected_cg = make_mock_cg(expected_c_state)
 
-        aci_mock.return_value.get_state.return_value = expected_cg
+        aci_mock.return_value.get_state.return_value = make_mock_container(
+            state="Terminated", exit_code=0, detail_status="test"
+        )
         aci_mock.return_value.exists.return_value = False
 
         aci = AzureContainerInstancesOperator(
@@ -262,10 +294,10 @@ class TestACIOperator:
 
     @mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook")
     def test_execute_fails_with_incorrect_os_type(self, aci_mock):
-        expected_c_state = ContainerState(state="Terminated", exit_code=0, detail_status="test")
-        expected_cg = make_mock_cg(expected_c_state)
 
-        aci_mock.return_value.get_state.return_value = expected_cg
+        aci_mock.return_value.get_state.return_value = make_mock_container(
+            state="Terminated", exit_code=0, detail_status="test"
+        )
         aci_mock.return_value.exists.return_value = False
 
         with pytest.raises(AirflowException) as ctx:
@@ -288,10 +320,10 @@ class TestACIOperator:
 
     @mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook")
     def test_execute_fails_with_incorrect_restart_policy(self, aci_mock):
-        expected_c_state = ContainerState(state="Terminated", exit_code=0, detail_status="test")
-        expected_cg = make_mock_cg(expected_c_state)
 
-        aci_mock.return_value.get_state.return_value = expected_cg
+        aci_mock.return_value.get_state.return_value = make_mock_container(
+            state="Terminated", exit_code=0, detail_status="test"
+        )
         aci_mock.return_value.exists.return_value = False
 
         with pytest.raises(AirflowException) as ctx:
@@ -315,10 +347,8 @@ class TestACIOperator:
     @mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook")
     @mock.patch("airflow.providers.microsoft.azure.operators.container_instances.sleep")
     def test_execute_correct_sleep_cycle(self, sleep_mock, aci_mock):
-        expected_c_state1 = ContainerState(state="Running", exit_code=0, detail_status="test")
-        expected_cg1 = make_mock_cg(expected_c_state1)
-        expected_c_state2 = ContainerState(state="Terminated", exit_code=0, detail_status="test")
-        expected_cg2 = make_mock_cg(expected_c_state2)
+        expected_cg1 = make_mock_container(state="Running", exit_code=0, detail_status="test")
+        expected_cg2 = make_mock_container(state="Terminated", exit_code=0, detail_status="test")
 
         aci_mock.return_value.get_state.side_effect = [expected_cg1, expected_cg1, expected_cg2]
         aci_mock.return_value.exists.return_value = False
@@ -340,10 +370,8 @@ class TestACIOperator:
     @mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook")
     @mock.patch("logging.Logger.exception")
     def test_execute_with_missing_events(self, log_mock, aci_mock):
-        expected_c_state1 = ContainerState(state="Running", exit_code=0, detail_status="test")
-        expected_cg1 = make_mock_cg_with_missing_events(expected_c_state1)
-        expected_c_state2 = ContainerState(state="Terminated", exit_code=0, detail_status="test")
-        expected_cg2 = make_mock_cg(expected_c_state2)
+        expected_cg1 = make_mock_container(state="Running", exit_code=0, detail_status="test")
+        expected_cg2 = make_mock_container(state="Terminated", exit_code=0, detail_status="test")
 
         aci_mock.return_value.get_state.side_effect = [expected_cg1, expected_cg2]
         aci_mock.return_value.exists.return_value = False

Reply via email to