This is an automated email from the ASF dual-hosted git repository.
turbaszek pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new 686e0ee Fix incorrect typing, remove hardcoded argument values and improve code in AzureContainerInstancesOperator (#11408)
686e0ee is described below
commit 686e0ee7dfb26224e2f91c9af6ef41d59e2f2e96
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Sun Oct 11 15:48:51 2020 +0100
Fix incorrect typing, remove hardcoded argument values and improve code in AzureContainerInstancesOperator (#11408)
---
.../azure/operators/azure_container_instances.py | 71 +++++++++----
.../operators/test_azure_container_instances.py | 114 +++++++++++++++++++++
2 files changed, 167 insertions(+), 18 deletions(-)
diff --git a/airflow/providers/microsoft/azure/operators/azure_container_instances.py b/airflow/providers/microsoft/azure/operators/azure_container_instances.py
index b0ff593..3bf30d9 100644
--- a/airflow/providers/microsoft/azure/operators/azure_container_instances.py
+++ b/airflow/providers/microsoft/azure/operators/azure_container_instances.py
@@ -28,6 +28,8 @@ from azure.mgmt.containerinstance.models import (
ResourceRequests,
ResourceRequirements,
VolumeMount,
+ IpAddress,
+ ContainerPort,
)
from msrestazure.azure_exceptions import CloudError
@@ -88,37 +90,44 @@ class AzureContainerInstancesOperator(BaseOperator):
:param gpu: GPU Resource for the container.
:type gpu: azure.mgmt.containerinstance.models.GpuResource
:param command: the command to run inside the container
- :type command: Optional[str]
+ :type command: Optional[List[str]]
:param container_timeout: max time allowed for the execution of
the container instance.
:type container_timeout: datetime.timedelta
:param tags: azure tags as dict of str:str
:type tags: Optional[dict[str, str]]
+ :param os_type: The operating system type required by the containers
+ in the container group. Possible values include: 'Windows', 'Linux'
+ :type os_type: str
+ :param restart_policy: Restart policy for all containers within the
container group.
+ Possible values include: 'Always', 'OnFailure', 'Never'
+ :type restart_policy: str
+ :param ip_address: The IP address type of the container group.
+ :type ip_address: IpAddress
**Example**::
AzureContainerInstancesOperator(
- "azure_service_principal",
- "azure_registry_user",
- "my-resource-group",
- "my-container-name-{{ ds }}",
- "myprivateregistry.azurecr.io/my_container:latest",
- "westeurope",
- {"MODEL_PATH": "my_value",
+ ci_conn_id = "azure_service_principal",
+ registry_conn_id = "azure_registry_user",
+ resource_group = "my-resource-group",
+ name = "my-container-name-{{ ds }}",
+ image = "myprivateregistry.azurecr.io/my_container:latest",
+ region = "westeurope",
+ environment_variables = {"MODEL_PATH": "my_value",
"POSTGRES_LOGIN": "{{
macros.connection('postgres_default').login }}",
"POSTGRES_PASSWORD": "{{
macros.connection('postgres_default').password }}",
"JOB_GUID": "{{ ti.xcom_pull(task_ids='task1',
key='guid') }}" },
- ['POSTGRES_PASSWORD'],
- [("azure_wasb_conn_id",
- "my_storage_container",
- "my_fileshare",
- "/input-data",
- True),],
+ secured_variables = ['POSTGRES_PASSWORD'],
+ volumes = [("azure_wasb_conn_id",
+ "my_storage_container",
+ "my_fileshare",
+ "/input-data",
+ True),],
memory_in_gb=14.0,
cpu=4.0,
gpu=GpuResource(count=1, sku='K80'),
command=["/bin/echo", "world"],
- container_timeout=timedelta(hours=2),
task_id="start_container"
)
"""
@@ -142,10 +151,14 @@ class AzureContainerInstancesOperator(BaseOperator):
memory_in_gb: Optional[Any] = None,
cpu: Optional[Any] = None,
gpu: Optional[Any] = None,
- command: Optional[str] = None,
+ command: Optional[List[str]] = None,
remove_on_error: bool = True,
fail_if_exists: bool = True,
tags: Optional[Dict[str, str]] = None,
+ os_type: str = 'Linux',
+ restart_policy: str = 'Never',
+ ip_address: Optional[IpAddress] = None,
+ ports: Optional[List[ContainerPort]] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -167,6 +180,22 @@ class AzureContainerInstancesOperator(BaseOperator):
self.fail_if_exists = fail_if_exists
self._ci_hook: Any = None
self.tags = tags
+ self.os_type = os_type
+ if self.os_type not in ['Linux', 'Windows']:
+ raise AirflowException(
+ "Invalid value for the os_type argument. "
+ "Please set 'Linux' or 'Windows' as the os_type. "
+ f"Found `{self.os_type}`."
+ )
+ self.restart_policy = restart_policy
+ if self.restart_policy not in ['Always', 'OnFailure', 'Never']:
+ raise AirflowException(
+ "Invalid value for the restart_policy argument. "
+ "Please set one of 'Always', 'OnFailure','Never' as the
restart_policy. "
+ f"Found `{self.restart_policy}`"
+ )
+ self.ip_address = ip_address
+ self.ports = ports
def execute(self, context: dict) -> int:
# Check name again in case it was templated.
@@ -214,6 +243,10 @@ class AzureContainerInstancesOperator(BaseOperator):
requests=ResourceRequests(memory_in_gb=self.memory_in_gb,
cpu=self.cpu, gpu=self.gpu)
)
+ if self.ip_address and not self.ports:
+ self.ports = [ContainerPort(port=80)]
+ self.log.info("Default port set. Container will listen on port
80")
+
container = Container(
name=self.name,
image=self.image,
@@ -221,6 +254,7 @@ class AzureContainerInstancesOperator(BaseOperator):
command=self.command,
environment_variables=environment_variables,
volume_mounts=volume_mounts,
+ ports=self.ports,
)
container_group = ContainerGroup(
@@ -230,9 +264,10 @@ class AzureContainerInstancesOperator(BaseOperator):
],
image_registry_credentials=image_registry_credentials,
volumes=volumes,
- restart_policy='Never',
- os_type='Linux',
+ restart_policy=self.restart_policy,
+ os_type=self.os_type,
tags=self.tags,
+ ip_address=self.ip_address,
)
self._ci_hook.create_or_update(self.resource_group, self.name,
container_group)
diff --git a/tests/providers/microsoft/azure/operators/test_azure_container_instances.py b/tests/providers/microsoft/azure/operators/test_azure_container_instances.py
index dd15558..cf20b91 100644
--- a/tests/providers/microsoft/azure/operators/test_azure_container_instances.py
+++ b/tests/providers/microsoft/azure/operators/test_azure_container_instances.py
@@ -19,6 +19,7 @@
import unittest
from collections import namedtuple
+from unittest.mock import MagicMock
import mock
from azure.mgmt.containerinstance.models import ContainerState, Event
@@ -197,3 +198,116 @@ class TestACIOperator(unittest.TestCase):
for name in valid_names:
checked_name = AzureContainerInstancesOperator._check_name(name)
self.assertEqual(checked_name, name)
+
+ @mock.patch(
+
"airflow.providers.microsoft.azure.operators.azure_container_instances.AzureContainerInstanceHook"
+ )
+ def test_execute_with_ipaddress(self, aci_mock):
+ expected_c_state = ContainerState(state='Terminated', exit_code=0,
detail_status='test')
+ expected_cg = make_mock_cg(expected_c_state)
+ ipaddress = MagicMock()
+
+ aci_mock.return_value.get_state.return_value = expected_cg
+ aci_mock.return_value.exists.return_value = False
+
+ aci = AzureContainerInstancesOperator(
+ ci_conn_id=None,
+ registry_conn_id=None,
+ resource_group='resource-group',
+ name='container-name',
+ image='container-image',
+ region='region',
+ task_id='task',
+ ip_address=ipaddress,
+ )
+ aci.execute(None)
+ self.assertEqual(aci_mock.return_value.create_or_update.call_count, 1)
+ (_, _, called_cg), _ = aci_mock.return_value.create_or_update.call_args
+
+ self.assertEqual(called_cg.ip_address, ipaddress)
+
+ @mock.patch(
+
"airflow.providers.microsoft.azure.operators.azure_container_instances.AzureContainerInstanceHook"
+ )
+ def test_execute_with_windows_os_and_diff_restart_policy(self, aci_mock):
+ expected_c_state = ContainerState(state='Terminated', exit_code=0,
detail_status='test')
+ expected_cg = make_mock_cg(expected_c_state)
+
+ aci_mock.return_value.get_state.return_value = expected_cg
+ aci_mock.return_value.exists.return_value = False
+
+ aci = AzureContainerInstancesOperator(
+ ci_conn_id=None,
+ registry_conn_id=None,
+ resource_group='resource-group',
+ name='container-name',
+ image='container-image',
+ region='region',
+ task_id='task',
+ restart_policy="Always",
+ os_type='Windows',
+ )
+ aci.execute(None)
+ self.assertEqual(aci_mock.return_value.create_or_update.call_count, 1)
+ (_, _, called_cg), _ = aci_mock.return_value.create_or_update.call_args
+
+ self.assertEqual(called_cg.restart_policy, 'Always')
+ self.assertEqual(called_cg.os_type, 'Windows')
+
+ @mock.patch(
+
"airflow.providers.microsoft.azure.operators.azure_container_instances.AzureContainerInstanceHook"
+ )
+ def test_execute_fails_with_incorrect_os_type(self, aci_mock):
+ expected_c_state = ContainerState(state='Terminated', exit_code=0,
detail_status='test')
+ expected_cg = make_mock_cg(expected_c_state)
+
+ aci_mock.return_value.get_state.return_value = expected_cg
+ aci_mock.return_value.exists.return_value = False
+
+ with self.assertRaises(AirflowException) as e:
+ AzureContainerInstancesOperator(
+ ci_conn_id=None,
+ registry_conn_id=None,
+ resource_group='resource-group',
+ name='container-name',
+ image='container-image',
+ region='region',
+ task_id='task',
+ os_type='MacOs',
+ )
+
+ self.assertEqual(
+ str(e.exception),
+ "Invalid value for the os_type argument. "
+ "Please set 'Linux' or 'Windows' as the os_type. "
+ "Found `MacOs`.",
+ )
+
+ @mock.patch(
+
"airflow.providers.microsoft.azure.operators.azure_container_instances.AzureContainerInstanceHook"
+ )
+ def test_execute_fails_with_incorrect_restart_policy(self, aci_mock):
+ expected_c_state = ContainerState(state='Terminated', exit_code=0,
detail_status='test')
+ expected_cg = make_mock_cg(expected_c_state)
+
+ aci_mock.return_value.get_state.return_value = expected_cg
+ aci_mock.return_value.exists.return_value = False
+
+ with self.assertRaises(AirflowException) as e:
+ AzureContainerInstancesOperator(
+ ci_conn_id=None,
+ registry_conn_id=None,
+ resource_group='resource-group',
+ name='container-name',
+ image='container-image',
+ region='region',
+ task_id='task',
+ restart_policy='Everyday',
+ )
+
+ self.assertEqual(
+ str(e.exception),
+ "Invalid value for the restart_policy argument. "
+ "Please set one of 'Always', 'OnFailure','Never' as the
restart_policy. "
+ "Found `Everyday`",
+ )