This is an automated email from the ASF dual-hosted git repository.

turbaszek pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 686e0ee  Fix incorrect typing, remove hardcoded argument values and improve code in AzureContainerInstancesOperator (#11408)
686e0ee is described below

commit 686e0ee7dfb26224e2f91c9af6ef41d59e2f2e96
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Sun Oct 11 15:48:51 2020 +0100

    Fix incorrect typing, remove hardcoded argument values and improve code in AzureContainerInstancesOperator (#11408)
---
 .../azure/operators/azure_container_instances.py   |  71 +++++++++----
 .../operators/test_azure_container_instances.py    | 114 +++++++++++++++++++++
 2 files changed, 167 insertions(+), 18 deletions(-)

diff --git a/airflow/providers/microsoft/azure/operators/azure_container_instances.py b/airflow/providers/microsoft/azure/operators/azure_container_instances.py
index b0ff593..3bf30d9 100644
--- a/airflow/providers/microsoft/azure/operators/azure_container_instances.py
+++ b/airflow/providers/microsoft/azure/operators/azure_container_instances.py
@@ -28,6 +28,8 @@ from azure.mgmt.containerinstance.models import (
     ResourceRequests,
     ResourceRequirements,
     VolumeMount,
+    IpAddress,
+    ContainerPort,
 )
 from msrestazure.azure_exceptions import CloudError
 
@@ -88,37 +90,44 @@ class AzureContainerInstancesOperator(BaseOperator):
     :param gpu: GPU Resource for the container.
     :type gpu: azure.mgmt.containerinstance.models.GpuResource
     :param command: the command to run inside the container
-    :type command: Optional[str]
+    :type command: Optional[List[str]]
     :param container_timeout: max time allowed for the execution of
         the container instance.
     :type container_timeout: datetime.timedelta
     :param tags: azure tags as dict of str:str
     :type tags: Optional[dict[str, str]]
+    :param os_type: The operating system type required by the containers
+        in the container group. Possible values include: 'Windows', 'Linux'
+    :type os_type: str
+    :param restart_policy: Restart policy for all containers within the container group.
+        Possible values include: 'Always', 'OnFailure', 'Never'
+    :type restart_policy: str
+    :param ip_address: The IP address type of the container group.
+    :type ip_address: IpAddress
 
     **Example**::
 
                 AzureContainerInstancesOperator(
-                    "azure_service_principal",
-                    "azure_registry_user",
-                    "my-resource-group",
-                    "my-container-name-{{ ds }}",
-                    "myprivateregistry.azurecr.io/my_container:latest",
-                    "westeurope",
-                    {"MODEL_PATH":  "my_value",
+                    ci_conn_id = "azure_service_principal",
+                    registry_conn_id = "azure_registry_user",
+                    resource_group = "my-resource-group",
+                    name = "my-container-name-{{ ds }}",
+                    image = "myprivateregistry.azurecr.io/my_container:latest",
+                    region = "westeurope",
+                    environment_variables = {"MODEL_PATH":  "my_value",
                      "POSTGRES_LOGIN": "{{ macros.connection('postgres_default').login }}",
                      "POSTGRES_PASSWORD": "{{ macros.connection('postgres_default').password }}",
                      "JOB_GUID": "{{ ti.xcom_pull(task_ids='task1', key='guid') }}" },
-                    ['POSTGRES_PASSWORD'],
-                    [("azure_wasb_conn_id",
-                    "my_storage_container",
-                    "my_fileshare",
-                    "/input-data",
-                    True),],
+                    secured_variables = ['POSTGRES_PASSWORD'],
+                    volumes = [("azure_wasb_conn_id",
+                            "my_storage_container",
+                            "my_fileshare",
+                            "/input-data",
+                        True),],
                     memory_in_gb=14.0,
                     cpu=4.0,
                     gpu=GpuResource(count=1, sku='K80'),
                     command=["/bin/echo", "world"],
-                    container_timeout=timedelta(hours=2),
                     task_id="start_container"
                 )
     """
@@ -142,10 +151,14 @@ class AzureContainerInstancesOperator(BaseOperator):
         memory_in_gb: Optional[Any] = None,
         cpu: Optional[Any] = None,
         gpu: Optional[Any] = None,
-        command: Optional[str] = None,
+        command: Optional[List[str]] = None,
         remove_on_error: bool = True,
         fail_if_exists: bool = True,
         tags: Optional[Dict[str, str]] = None,
+        os_type: str = 'Linux',
+        restart_policy: str = 'Never',
+        ip_address: Optional[IpAddress] = None,
+        ports: Optional[List[ContainerPort]] = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -167,6 +180,22 @@ class AzureContainerInstancesOperator(BaseOperator):
         self.fail_if_exists = fail_if_exists
         self._ci_hook: Any = None
         self.tags = tags
+        self.os_type = os_type
+        if self.os_type not in ['Linux', 'Windows']:
+            raise AirflowException(
+                "Invalid value for the os_type argument. "
+                "Please set 'Linux' or 'Windows' as the os_type. "
+                f"Found `{self.os_type}`."
+            )
+        self.restart_policy = restart_policy
+        if self.restart_policy not in ['Always', 'OnFailure', 'Never']:
+            raise AirflowException(
+                "Invalid value for the restart_policy argument. "
+                "Please set one of 'Always', 'OnFailure','Never' as the restart_policy. "
+                f"Found `{self.restart_policy}`"
+            )
+        self.ip_address = ip_address
+        self.ports = ports
 
     def execute(self, context: dict) -> int:
         # Check name again in case it was templated.
@@ -214,6 +243,10 @@ class AzureContainerInstancesOperator(BaseOperator):
                 requests=ResourceRequests(memory_in_gb=self.memory_in_gb, cpu=self.cpu, gpu=self.gpu)
             )
 
+            if self.ip_address and not self.ports:
+                self.ports = [ContainerPort(port=80)]
+                self.log.info("Default port set. Container will listen on port 80")
+
             container = Container(
                 name=self.name,
                 image=self.image,
@@ -221,6 +254,7 @@ class AzureContainerInstancesOperator(BaseOperator):
                 command=self.command,
                 environment_variables=environment_variables,
                 volume_mounts=volume_mounts,
+                ports=self.ports,
             )
 
             container_group = ContainerGroup(
@@ -230,9 +264,10 @@ class AzureContainerInstancesOperator(BaseOperator):
                 ],
                 image_registry_credentials=image_registry_credentials,
                 volumes=volumes,
-                restart_policy='Never',
-                os_type='Linux',
+                restart_policy=self.restart_policy,
+                os_type=self.os_type,
                 tags=self.tags,
+                ip_address=self.ip_address,
             )
 
             self._ci_hook.create_or_update(self.resource_group, self.name, container_group)
diff --git a/tests/providers/microsoft/azure/operators/test_azure_container_instances.py b/tests/providers/microsoft/azure/operators/test_azure_container_instances.py
index dd15558..cf20b91 100644
--- a/tests/providers/microsoft/azure/operators/test_azure_container_instances.py
+++ b/tests/providers/microsoft/azure/operators/test_azure_container_instances.py
@@ -19,6 +19,7 @@
 
 import unittest
 from collections import namedtuple
+from unittest.mock import MagicMock
 
 import mock
 from azure.mgmt.containerinstance.models import ContainerState, Event
@@ -197,3 +198,116 @@ class TestACIOperator(unittest.TestCase):
         for name in valid_names:
             checked_name = AzureContainerInstancesOperator._check_name(name)
             self.assertEqual(checked_name, name)
+
+    @mock.patch(
+        "airflow.providers.microsoft.azure.operators.azure_container_instances.AzureContainerInstanceHook"
+    )
+    def test_execute_with_ipaddress(self, aci_mock):
+        expected_c_state = ContainerState(state='Terminated', exit_code=0, detail_status='test')
+        expected_cg = make_mock_cg(expected_c_state)
+        ipaddress = MagicMock()
+
+        aci_mock.return_value.get_state.return_value = expected_cg
+        aci_mock.return_value.exists.return_value = False
+
+        aci = AzureContainerInstancesOperator(
+            ci_conn_id=None,
+            registry_conn_id=None,
+            resource_group='resource-group',
+            name='container-name',
+            image='container-image',
+            region='region',
+            task_id='task',
+            ip_address=ipaddress,
+        )
+        aci.execute(None)
+        self.assertEqual(aci_mock.return_value.create_or_update.call_count, 1)
+        (_, _, called_cg), _ = aci_mock.return_value.create_or_update.call_args
+
+        self.assertEqual(called_cg.ip_address, ipaddress)
+
+    @mock.patch(
+        "airflow.providers.microsoft.azure.operators.azure_container_instances.AzureContainerInstanceHook"
+    )
+    def test_execute_with_windows_os_and_diff_restart_policy(self, aci_mock):
+        expected_c_state = ContainerState(state='Terminated', exit_code=0, detail_status='test')
+        expected_cg = make_mock_cg(expected_c_state)
+
+        aci_mock.return_value.get_state.return_value = expected_cg
+        aci_mock.return_value.exists.return_value = False
+
+        aci = AzureContainerInstancesOperator(
+            ci_conn_id=None,
+            registry_conn_id=None,
+            resource_group='resource-group',
+            name='container-name',
+            image='container-image',
+            region='region',
+            task_id='task',
+            restart_policy="Always",
+            os_type='Windows',
+        )
+        aci.execute(None)
+        self.assertEqual(aci_mock.return_value.create_or_update.call_count, 1)
+        (_, _, called_cg), _ = aci_mock.return_value.create_or_update.call_args
+
+        self.assertEqual(called_cg.restart_policy, 'Always')
+        self.assertEqual(called_cg.os_type, 'Windows')
+
+    @mock.patch(
+        "airflow.providers.microsoft.azure.operators.azure_container_instances.AzureContainerInstanceHook"
+    )
+    def test_execute_fails_with_incorrect_os_type(self, aci_mock):
+        expected_c_state = ContainerState(state='Terminated', exit_code=0, detail_status='test')
+        expected_cg = make_mock_cg(expected_c_state)
+
+        aci_mock.return_value.get_state.return_value = expected_cg
+        aci_mock.return_value.exists.return_value = False
+
+        with self.assertRaises(AirflowException) as e:
+            AzureContainerInstancesOperator(
+                ci_conn_id=None,
+                registry_conn_id=None,
+                resource_group='resource-group',
+                name='container-name',
+                image='container-image',
+                region='region',
+                task_id='task',
+                os_type='MacOs',
+            )
+
+        self.assertEqual(
+            str(e.exception),
+            "Invalid value for the os_type argument. "
+            "Please set 'Linux' or 'Windows' as the os_type. "
+            "Found `MacOs`.",
+        )
+
+    @mock.patch(
+        "airflow.providers.microsoft.azure.operators.azure_container_instances.AzureContainerInstanceHook"
+    )
+    def test_execute_fails_with_incorrect_restart_policy(self, aci_mock):
+        expected_c_state = ContainerState(state='Terminated', exit_code=0, detail_status='test')
+        expected_cg = make_mock_cg(expected_c_state)
+
+        aci_mock.return_value.get_state.return_value = expected_cg
+        aci_mock.return_value.exists.return_value = False
+
+        with self.assertRaises(AirflowException) as e:
+            AzureContainerInstancesOperator(
+                ci_conn_id=None,
+                registry_conn_id=None,
+                resource_group='resource-group',
+                name='container-name',
+                image='container-image',
+                region='region',
+                task_id='task',
+                restart_policy='Everyday',
+            )
+
+        self.assertEqual(
+            str(e.exception),
+            "Invalid value for the restart_policy argument. "
+            "Please set one of 'Always', 'OnFailure','Never' as the restart_policy. "
+            "Found `Everyday`",
+        )

Reply via email to