This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 9d5327806f Bump azure-mgmt-containerinstance>=7.0.0,<9.0.0 (#33696)
9d5327806f is described below
commit 9d5327806fac61cd62abd30a6339b0cb26ad1ebf
Author: Pankaj Singh <[email protected]>
AuthorDate: Mon Aug 28 16:39:11 2023 +0530
Bump azure-mgmt-containerinstance>=7.0.0,<9.0.0 (#33696)
---
.../microsoft/azure/hooks/container_instance.py | 68 +++++++++++--
.../azure/operators/container_instances.py | 3 +-
airflow/providers/microsoft/azure/provider.yaml | 2 +-
generated/provider_dependencies.json | 2 +-
.../azure/hooks/test_azure_container_instance.py | 8 +-
.../operators/test_azure_container_instances.py | 106 +++++++++++++--------
6 files changed, 136 insertions(+), 53 deletions(-)
diff --git a/airflow/providers/microsoft/azure/hooks/container_instance.py b/airflow/providers/microsoft/azure/hooks/container_instance.py
index 77e24ab944..cdebe7a241 100644
--- a/airflow/providers/microsoft/azure/hooks/container_instance.py
+++ b/airflow/providers/microsoft/azure/hooks/container_instance.py
@@ -19,11 +19,14 @@ from __future__ import annotations
import warnings
from functools import cached_property
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
+from azure.common.client_factory import get_client_from_auth_file, get_client_from_json_dict
+from azure.common.credentials import ServicePrincipalCredentials
+from azure.identity import DefaultAzureCredential
from azure.mgmt.containerinstance import ContainerInstanceManagementClient
-from airflow.exceptions import AirflowProviderDeprecationWarning
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.microsoft.azure.hooks.base_azure import AzureBaseHook
if TYPE_CHECKING:
@@ -56,6 +59,59 @@ class AzureContainerInstanceHook(AzureBaseHook):
def connection(self):
return self.get_conn()
+ def get_conn(self) -> Any:
+ """
+ Authenticates the resource using the connection id passed during init.
+
+ :return: the authenticated client.
+ """
+ conn = self.get_connection(self.conn_id)
+ tenant = conn.extra_dejson.get("tenantId")
+ if not tenant and conn.extra_dejson.get("extra__azure__tenantId"):
+ warnings.warn(
+ "`extra__azure__tenantId` is deprecated in azure connection extra, "
+ "please use `tenantId` instead",
+ AirflowProviderDeprecationWarning,
+ stacklevel=2,
+ )
+ tenant = conn.extra_dejson.get("extra__azure__tenantId")
+ subscription_id = conn.extra_dejson.get("subscriptionId")
+ if not subscription_id and conn.extra_dejson.get("extra__azure__subscriptionId"):
+ warnings.warn(
+ "`extra__azure__subscriptionId` is deprecated in azure connection extra, "
+ "please use `subscriptionId` instead",
+ AirflowProviderDeprecationWarning,
+ stacklevel=2,
+ )
+ subscription_id = conn.extra_dejson.get("extra__azure__subscriptionId")
+
+ key_path = conn.extra_dejson.get("key_path")
+ if key_path:
+ if not key_path.endswith(".json"):
+ raise AirflowException("Unrecognised extension for key file.")
+ self.log.info("Getting connection using a JSON key file.")
+ return get_client_from_auth_file(client_class=self.sdk_client, auth_path=key_path)
+
+ key_json = conn.extra_dejson.get("key_json")
+ if key_json:
+ self.log.info("Getting connection using a JSON config.")
+ return get_client_from_json_dict(client_class=self.sdk_client, config_dict=key_json)
+
+ credential: ServicePrincipalCredentials | DefaultAzureCredential
+ if all([conn.login, conn.password, tenant]):
+ self.log.info("Getting connection using specific credentials and subscription_id.")
+ credential = ServicePrincipalCredentials(
+ client_id=conn.login, secret=conn.password, tenant=tenant
+ )
+ else:
+ self.log.info("Using DefaultAzureCredential as credential")
+ credential = DefaultAzureCredential()
+
+ return ContainerInstanceManagementClient(
+ credential=credential,
+ subscription_id=subscription_id,
+ )
+
def create_or_update(self, resource_group: str, name: str, container_group: ContainerGroup) -> None:
"""
Create a new container group.
@@ -64,7 +120,7 @@ class AzureContainerInstanceHook(AzureBaseHook):
:param name: the name of the container group
:param container_group: the properties of the container group
"""
- self.connection.container_groups.create_or_update(resource_group, name, container_group)
+ self.connection.container_groups.begin_create_or_update(resource_group, name, container_group)
container_group)
def get_state_exitcode_details(self, resource_group: str, name: str) -> tuple:
"""
@@ -109,7 +165,7 @@ class AzureContainerInstanceHook(AzureBaseHook):
:param name: the name of the container group
:return: ContainerGroup
"""
- return self.connection.container_groups.get(resource_group, name, raw=False)
+ return self.connection.container_groups.get(resource_group, name)
def get_logs(self, resource_group: str, name: str, tail: int = 1000) -> list:
"""
@@ -120,7 +176,7 @@ class AzureContainerInstanceHook(AzureBaseHook):
:param tail: the size of the tail
:return: A list of log messages
"""
- logs = self.connection.container.list_logs(resource_group, name, name, tail=tail)
+ logs = self.connection.containers.list_logs(resource_group, name, name, tail=tail)
return logs.content.splitlines(True)
def delete(self, resource_group: str, name: str) -> None:
@@ -130,7 +186,7 @@ class AzureContainerInstanceHook(AzureBaseHook):
:param resource_group: the name of the resource group
:param name: the name of the container group
"""
- self.connection.container_groups.delete(resource_group, name)
+ self.connection.container_groups.begin_delete(resource_group, name)
def exists(self, resource_group: str, name: str) -> bool:
"""
diff --git a/airflow/providers/microsoft/azure/operators/container_instances.py b/airflow/providers/microsoft/azure/operators/container_instances.py
index 8b5fba26cd..efbe8259b1 100644
--- a/airflow/providers/microsoft/azure/operators/container_instances.py
+++ b/airflow/providers/microsoft/azure/operators/container_instances.py
@@ -126,11 +126,11 @@ class AzureContainerInstancesOperator(BaseOperator):
self,
*,
ci_conn_id: str,
- registry_conn_id: str | None,
resource_group: str,
name: str,
image: str,
region: str,
+ registry_conn_id: str | None = None,
environment_variables: dict | None = None,
secured_variables: str | None = None,
volumes: list | None = None,
@@ -295,7 +295,6 @@ class AzureContainerInstancesOperator(BaseOperator):
try:
cg_state = self._ci_hook.get_state(resource_group, name)
instance_view = cg_state.containers[0].instance_view
-
# If there is no instance view, we show the provisioning state
if instance_view is not None:
c_state = instance_view.current_state
diff --git a/airflow/providers/microsoft/azure/provider.yaml b/airflow/providers/microsoft/azure/provider.yaml
index e3ef3db6f8..2c9868bd84 100644
--- a/airflow/providers/microsoft/azure/provider.yaml
+++ b/airflow/providers/microsoft/azure/provider.yaml
@@ -83,7 +83,7 @@ dependencies:
- azure-kusto-data>=4.1.0
# TODO: upgrade to newer versions of all the below libraries.
# See issue https://github.com/apache/airflow/issues/30199
- - azure-mgmt-containerinstance>=1.5.0,<2.0
+ - azure-mgmt-containerinstance>=7.0.0,<9.0.0
- azure-mgmt-datafactory>=1.0.0,<2.0
integrations:
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index bfe118e289..b82c592b4f 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -559,7 +559,7 @@
"azure-identity>=1.3.1",
"azure-keyvault-secrets>=4.1.0",
"azure-kusto-data>=4.1.0",
- "azure-mgmt-containerinstance>=1.5.0,<2.0",
+ "azure-mgmt-containerinstance>=7.0.0,<9.0.0",
"azure-mgmt-cosmosdb",
"azure-mgmt-datafactory>=1.0.0,<2.0",
"azure-mgmt-datalake-store>=0.5.0",
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py b/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py
index 786df4eb16..1e7e19a138 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py
@@ -52,7 +52,7 @@ class TestAzureContainerInstanceHook:
yield
@patch("azure.mgmt.containerinstance.models.ContainerGroup")
- @patch("azure.mgmt.containerinstance.operations.ContainerGroupsOperations.create_or_update")
+ @patch("azure.mgmt.containerinstance.operations.ContainerGroupsOperations.begin_create_or_update")
def test_create_or_update(self, create_or_update_mock, container_group_mock):
self.hook.create_or_update("resource_group", "aci-test", container_group_mock)
create_or_update_mock.assert_called_once_with("resource_group", "aci-test", container_group_mock)
@@ -60,9 +60,9 @@ class TestAzureContainerInstanceHook:
@patch("azure.mgmt.containerinstance.operations.ContainerGroupsOperations.get")
def test_get_state(self, get_state_mock):
self.hook.get_state("resource_group", "aci-test")
- get_state_mock.assert_called_once_with("resource_group", "aci-test", raw=False)
+ get_state_mock.assert_called_once_with("resource_group", "aci-test")
- @patch("azure.mgmt.containerinstance.operations.ContainerOperations.list_logs")
+ @patch("azure.mgmt.containerinstance.operations.ContainersOperations.list_logs")
def test_get_logs(self, list_logs_mock):
expected_messages = ["log line 1\n", "log line 2\n", "log line 3\n"]
logs = Logs(content="".join(expected_messages))
@@ -72,7 +72,7 @@ class TestAzureContainerInstanceHook:
assert logs == expected_messages
- @patch("azure.mgmt.containerinstance.operations.ContainerGroupsOperations.delete")
+ @patch("azure.mgmt.containerinstance.operations.ContainerGroupsOperations.begin_delete")
def test_delete(self, delete_mock):
self.hook.delete("resource_group", "aci-test")
delete_mock.assert_called_once_with("resource_group", "aci-test")
diff --git a/tests/providers/microsoft/azure/operators/test_azure_container_instances.py b/tests/providers/microsoft/azure/operators/test_azure_container_instances.py
index 6c7cd1dbb5..2984fe60ee 100644
--- a/tests/providers/microsoft/azure/operators/test_azure_container_instances.py
+++ b/tests/providers/microsoft/azure/operators/test_azure_container_instances.py
@@ -22,7 +22,13 @@ from unittest import mock
from unittest.mock import MagicMock
import pytest
-from azure.mgmt.containerinstance.models import ContainerState, Event
+from azure.mgmt.containerinstance.models import (
+ Container,
+ ContainerGroup,
+ ContainerPropertiesInstanceView,
+ ContainerState,
+ Event,
+)
from airflow.exceptions import AirflowException
from airflow.providers.microsoft.azure.operators.container_instances import AzureContainerInstancesOperator
@@ -35,10 +41,12 @@ def make_mock_cg(container_state, events=None):
"""
events = events or []
instance_view_dict = {"current_state": container_state, "events": events}
- instance_view = namedtuple("InstanceView", instance_view_dict.keys())(*instance_view_dict.values())
+ instance_view = namedtuple("ContainerPropertiesInstanceView", instance_view_dict.keys())(
+ *instance_view_dict.values()
+ )
container_dict = {"instance_view": instance_view}
- container = namedtuple("Container", container_dict.keys())(*container_dict.values())
+ container = namedtuple("Containers", container_dict.keys())(*container_dict.values())
container_g_dict = {"containers": [container]}
container_g = namedtuple("ContainerGroup", container_g_dict.keys())(*container_g_dict.values())
@@ -53,23 +61,42 @@ def make_mock_cg_with_missing_events(container_state):
This can happen, when the container group is provisioned, but not started.
"""
instance_view_dict = {"current_state": container_state, "events": None}
- instance_view = namedtuple("InstanceView", instance_view_dict.keys())(*instance_view_dict.values())
+ instance_view = namedtuple("ContainerPropertiesInstanceView", instance_view_dict.keys())(
+ *instance_view_dict.values()
+ )
container_dict = {"instance_view": instance_view}
- container = namedtuple("Container", container_dict.keys())(*container_dict.values())
+ container = namedtuple("Containers", container_dict.keys())(*container_dict.values())
container_g_dict = {"containers": [container]}
container_g = namedtuple("ContainerGroup", container_g_dict.keys())(*container_g_dict.values())
return container_g
+def make_mock_container(state: str, exit_code: int, detail_status: str, events: Event | None = None):
+ container = Container(name="hello_world", image="test", resources="test")
+ container_prop = ContainerPropertiesInstanceView()
+ container_state = ContainerState()
+ container_state.state = state
+ container_state.exit_code = exit_code
+ container_state.detail_status = detail_status
+ container_prop.current_state = container_state
+ if events:
+ container_prop.events = events
+ container.instance_view = container_prop
+
+ cg = ContainerGroup(containers=[container], os_type="Linux")
+
+ return cg
+
+
class TestACIOperator:
@mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook")
def test_execute(self, aci_mock):
- expected_c_state = ContainerState(state="Terminated", exit_code=0, detail_status="test")
- expected_cg = make_mock_cg(expected_c_state)
+ expected_cg = make_mock_container(state="Terminated", exit_code=0, detail_status="test")
aci_mock.return_value.get_state.return_value = expected_cg
+
aci_mock.return_value.exists.return_value = False
aci = AzureContainerInstancesOperator(
@@ -102,10 +129,10 @@ class TestACIOperator:
@mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook")
def test_execute_with_failures(self, aci_mock):
- expected_c_state = ContainerState(state="Terminated", exit_code=1, detail_status="test")
- expected_cg = make_mock_cg(expected_c_state)
+ expected_cg = make_mock_container(state="Terminated", exit_code=1, detail_status="test")
aci_mock.return_value.get_state.return_value = expected_cg
+
aci_mock.return_value.exists.return_value = False
aci = AzureContainerInstancesOperator(
@@ -124,11 +151,11 @@ class TestACIOperator:
@mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook")
def test_execute_with_tags(self, aci_mock):
- expected_c_state = ContainerState(state="Terminated", exit_code=0, detail_status="test")
- expected_cg = make_mock_cg(expected_c_state)
- tags = {"testKey": "testValue"}
+ expected_cg = make_mock_container(state="Terminated", exit_code=0, detail_status="test")
aci_mock.return_value.get_state.return_value = expected_cg
+ tags = {"testKey": "testValue"}
+
aci_mock.return_value.exists.return_value = False
aci = AzureContainerInstancesOperator(
@@ -163,13 +190,18 @@ class TestACIOperator:
@mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook")
def test_execute_with_messages_logs(self, aci_mock):
- events = [Event(message="test"), Event(message="messages")]
- expected_c_state1 = ContainerState(state="Succeeded", exit_code=0, detail_status="test")
- expected_cg1 = make_mock_cg(expected_c_state1, events)
- expected_c_state2 = ContainerState(state="Running", exit_code=0, detail_status="test")
- expected_cg2 = make_mock_cg(expected_c_state2, events)
- expected_c_state3 = ContainerState(state="Terminated", exit_code=0, detail_status="test")
- expected_cg3 = make_mock_cg(expected_c_state3, events)
+ event1 = Event()
+ event1.message = "test"
+ event2 = Event()
+ event2.message = "messages"
+ events = [event1, event2]
+ expected_cg1 = make_mock_container(
+ state="Succeeded", exit_code=0, detail_status="test", events=events
+ )
+ expected_cg2 = make_mock_container(state="Running", exit_code=0, detail_status="test", events=events)
+ expected_cg3 = make_mock_container(
+ state="Terminated", exit_code=0, detail_status="test", events=events
+ )
aci_mock.return_value.get_state.side_effect = [expected_cg1, expected_cg2, expected_cg3]
aci_mock.return_value.get_logs.return_value = ["test", "logs"]
@@ -211,11 +243,11 @@ class TestACIOperator:
@mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook")
def test_execute_with_ipaddress(self, aci_mock):
- expected_c_state = ContainerState(state="Terminated", exit_code=0, detail_status="test")
- expected_cg = make_mock_cg(expected_c_state)
ipaddress = MagicMock()
- aci_mock.return_value.get_state.return_value = expected_cg
+ aci_mock.return_value.get_state.return_value = make_mock_container(
+ state="Terminated", exit_code=0, detail_status="test"
+ )
aci_mock.return_value.exists.return_value = False
aci = AzureContainerInstancesOperator(
@@ -236,10 +268,10 @@ class TestACIOperator:
@mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook")
def test_execute_with_windows_os_and_diff_restart_policy(self, aci_mock):
- expected_c_state = ContainerState(state="Terminated", exit_code=0, detail_status="test")
- expected_cg = make_mock_cg(expected_c_state)
- aci_mock.return_value.get_state.return_value = expected_cg
+ aci_mock.return_value.get_state.return_value = make_mock_container(
+ state="Terminated", exit_code=0, detail_status="test"
+ )
aci_mock.return_value.exists.return_value = False
aci = AzureContainerInstancesOperator(
@@ -262,10 +294,10 @@ class TestACIOperator:
@mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook")
def test_execute_fails_with_incorrect_os_type(self, aci_mock):
- expected_c_state = ContainerState(state="Terminated", exit_code=0, detail_status="test")
- expected_cg = make_mock_cg(expected_c_state)
- aci_mock.return_value.get_state.return_value = expected_cg
+ aci_mock.return_value.get_state.return_value = make_mock_container(
+ state="Terminated", exit_code=0, detail_status="test"
+ )
aci_mock.return_value.exists.return_value = False
with pytest.raises(AirflowException) as ctx:
@@ -288,10 +320,10 @@ class TestACIOperator:
@mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook")
def test_execute_fails_with_incorrect_restart_policy(self, aci_mock):
- expected_c_state = ContainerState(state="Terminated", exit_code=0, detail_status="test")
- expected_cg = make_mock_cg(expected_c_state)
- aci_mock.return_value.get_state.return_value = expected_cg
+ aci_mock.return_value.get_state.return_value = make_mock_container(
+ state="Terminated", exit_code=0, detail_status="test"
+ )
aci_mock.return_value.exists.return_value = False
with pytest.raises(AirflowException) as ctx:
@@ -315,10 +347,8 @@ class TestACIOperator:
@mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook")
@mock.patch("airflow.providers.microsoft.azure.operators.container_instances.sleep")
def test_execute_correct_sleep_cycle(self, sleep_mock, aci_mock):
- expected_c_state1 = ContainerState(state="Running", exit_code=0, detail_status="test")
- expected_cg1 = make_mock_cg(expected_c_state1)
- expected_c_state2 = ContainerState(state="Terminated", exit_code=0, detail_status="test")
- expected_cg2 = make_mock_cg(expected_c_state2)
+ expected_cg1 = make_mock_container(state="Running", exit_code=0, detail_status="test")
+ expected_cg2 = make_mock_container(state="Terminated", exit_code=0, detail_status="test")
aci_mock.return_value.get_state.side_effect = [expected_cg1, expected_cg1, expected_cg2]
aci_mock.return_value.exists.return_value = False
@@ -340,10 +370,8 @@ class TestACIOperator:
@mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook")
@mock.patch("logging.Logger.exception")
def test_execute_with_missing_events(self, log_mock, aci_mock):
- expected_c_state1 = ContainerState(state="Running", exit_code=0, detail_status="test")
- expected_cg1 = make_mock_cg_with_missing_events(expected_c_state1)
- expected_c_state2 = ContainerState(state="Terminated", exit_code=0, detail_status="test")
- expected_cg2 = make_mock_cg(expected_c_state2)
+ expected_cg1 = make_mock_container(state="Running", exit_code=0, detail_status="test")
+ expected_cg2 = make_mock_container(state="Terminated", exit_code=0, detail_status="test")
aci_mock.return_value.get_state.side_effect = [expected_cg1, expected_cg2]
aci_mock.return_value.exists.return_value = False