Fokko closed pull request #4121: [AIRFLOW-2568] Azure Container Instances operator
URL: https://github.com/apache/incubator-airflow/pull/4121
 
This is a PR merged from a forked repository. As GitHub hides the original
diff on merge, it is displayed below for the sake of provenance:

diff --git a/airflow/contrib/example_dags/example_azure_container_instances_operator.py b/airflow/contrib/example_dags/example_azure_container_instances_operator.py
new file mode 100644
index 0000000000..181a30b50e
--- /dev/null
+++ b/airflow/contrib/example_dags/example_azure_container_instances_operator.py
@@ -0,0 +1,54 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow import DAG
+from airflow.contrib.operators.azure_container_instances_operator import AzureContainerInstancesOperator
+from datetime import datetime, timedelta
+
+default_args = {
+    'owner': 'airflow',
+    'depends_on_past': False,
+    'start_date': datetime(2018, 11, 1),
+    'email': ['airf...@example.com'],
+    'email_on_failure': False,
+    'email_on_retry': False,
+    'retries': 1,
+    'retry_delay': timedelta(minutes=5),
+}
+
+dag = DAG(
+    'aci_example',
+    default_args=default_args,
+    schedule_interval=timedelta(1)
+)
+
+t1 = AzureContainerInstancesOperator(
+    ci_conn_id='azure_container_instances_default',
+    registry_conn_id=None,
+    resource_group='resource-group',
+    name='aci-test-{{ ds }}',
+    image='hello-world',
+    region='WestUS2',
+    environment_variables={},
+    volumes=[],
+    memory_in_gb=4.0,
+    cpu=1.0,
+    task_id='start_container',
+    dag=dag
+)
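
For context, the `azure_container_instances_default` connection referenced
above is seeded by `airflow initdb` with placeholder extras (see the
`airflow/utils/db.py` change further down). A minimal sketch of registering an
equivalent service-principal connection programmatically, using hypothetical
placeholder values:

    import json

    from airflow import settings
    from airflow.models import Connection

    # login = application (client) id, password = generated client secret;
    # tenantId and subscriptionId go into the JSON extras field.
    conn = Connection(
        conn_id='azure_container_instances_default',
        conn_type='azure_container_instances',
        login='<APPLICATION_ID>',
        password='<CLIENT_SECRET>',
        extra=json.dumps({'tenantId': '<TENANT_ID>',
                          'subscriptionId': '<SUBSCRIPTION_ID>'}))

    session = settings.Session()
    session.add(conn)
    session.commit()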
diff --git a/airflow/contrib/hooks/azure_container_instance_hook.py b/airflow/contrib/hooks/azure_container_instance_hook.py
new file mode 100644
index 0000000000..5ad64de6d7
--- /dev/null
+++ b/airflow/contrib/hooks/azure_container_instance_hook.py
@@ -0,0 +1,167 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+
+from airflow.hooks.base_hook import BaseHook
+from airflow.exceptions import AirflowException
+
+from azure.common.client_factory import get_client_from_auth_file
+from azure.common.credentials import ServicePrincipalCredentials
+
+from azure.mgmt.containerinstance import ContainerInstanceManagementClient
+
+
+class AzureContainerInstanceHook(BaseHook):
+    """
+    A hook to communicate with Azure Container Instances.
+
+    This hook requires a service principal in order to work.
+    After creating this service principal
+    (Azure Active Directory/App Registrations), you need to fill in the
+    client_id (Application ID) as login, the generated password as password,
+    and the tenantId and subscriptionId in the extras field as JSON.
+
+    :param conn_id: connection id of a service principal which will be used
+        to start the container instance
+    :type conn_id: str
+    """
+
+    def __init__(self, conn_id='azure_default'):
+        self.conn_id = conn_id
+        self.connection = self.get_conn()
+
+    def get_conn(self):
+        conn = self.get_connection(self.conn_id)
+        key_path = conn.extra_dejson.get('key_path', False)
+        if key_path:
+            if key_path.endswith('.json'):
+                self.log.info('Getting connection using a JSON key file.')
+                return get_client_from_auth_file(ContainerInstanceManagementClient,
+                                                 key_path)
+            else:
+                raise AirflowException('Unrecognised extension for key file.')
+
+        if os.environ.get('AZURE_AUTH_LOCATION'):
+            key_path = os.environ.get('AZURE_AUTH_LOCATION')
+            if key_path.endswith('.json'):
+                self.log.info('Getting connection using a JSON key file.')
+                return get_client_from_auth_file(ContainerInstanceManagementClient,
+                                                 key_path)
+            else:
+                raise AirflowException('Unrecognised extension for key file.')
+
+        credentials = ServicePrincipalCredentials(
+            client_id=conn.login,
+            secret=conn.password,
+            tenant=conn.extra_dejson['tenantId']
+        )
+
+        subscription_id = conn.extra_dejson['subscriptionId']
+        return ContainerInstanceManagementClient(credentials,
+                                                 str(subscription_id))
+
+    def create_or_update(self, resource_group, name, container_group):
+        """
+        Create a new container group
+
+        :param resource_group: the name of the resource group
+        :type resource_group: str
+        :param name: the name of the container group
+        :type name: str
+        :param container_group: the properties of the container group
+        :type container_group: azure.mgmt.containerinstance.models.ContainerGroup
+        """
+        self.connection.container_groups.create_or_update(resource_group,
+                                                          name,
+                                                          container_group)
+
+    def get_state_exitcode_details(self, resource_group, name):
+        """
+        Get the state and exitcode of a container group
+
+        :param resource_group: the name of the resource group
+        :type resource_group: str
+        :param name: the name of the container group
+        :type name: str
+        :return: A tuple with the state, exitcode, and details.
+            If the exitcode is unknown 0 is returned.
+        :rtype: tuple(state, exitcode, details)
+        """
+        current_state = self._get_instance_view(resource_group,
+                                                name).current_state
+        return (current_state.state,
+                current_state.exit_code,
+                current_state.detail_status)
+
+    def _get_instance_view(self, resource_group, name):
+        response = self.connection.container_groups.get(resource_group,
+                                                        name,
+                                                        raw=False)
+        return response.containers[0].instance_view
+
+    def get_messages(self, resource_group, name):
+        """
+        Get the messages of a container group
+
+        :param resource_group: the name of the resource group
+        :type resource_group: str
+        :param name: the name of the container group
+        :type name: str
+        :return: A list of the event messages
+        :rtype: list<str>
+        """
+        instance_view = self._get_instance_view(resource_group, name)
+
+        return [event.message for event in instance_view.events]
+
+    def get_logs(self, resource_group, name, tail=1000):
+        """
+        Get the tail from logs of a container group
+
+        :param resource_group: the name of the resource group
+        :type resource_group: str
+        :param name: the name of the container group
+        :type name: str
+        :param tail: the size of the tail
+        :type tail: int
+        :return: A list of log messages
+        :rtype: list<str>
+        """
+        logs = self.connection.container.list_logs(resource_group, name, name,
+                                                   tail=tail)
+        return logs.content.splitlines(True)
+
+    def delete(self, resource_group, name):
+        """
+        Delete a container group
+
+        :param resource_group: the name of the resource group
+        :type resource_group: str
+        :param name: the name of the container group
+        :type name: str
+        """
+        self.connection.container_groups.delete(resource_group, name)
+
+    def exists(self, resource_group, name):
+        """
+        Test if a container group exists
+
+        :param resource_group: the name of the resource group
+        :type resource_group: str
+        :param name: the name of the container group
+        :type name: str
+        """
+        for container in self.connection.container_groups.list_by_resource_group(resource_group):
+            if container.name == name:
+                return True
+        return False
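
For reference, the `key_path` extra and the `AZURE_AUTH_LOCATION` environment
variable both point at an SDK auth file as consumed by
`get_client_from_auth_file`. A sketch of producing such a file with
hypothetical placeholder values (the layout matches what
`az ad sp create-for-rbac --sdk-auth` emits):

    import json

    auth = {
        "clientId": "<APPLICATION_ID>",
        "clientSecret": "<CLIENT_SECRET>",
        "subscriptionId": "<SUBSCRIPTION_ID>",
        "tenantId": "<TENANT_ID>",
        "activeDirectoryEndpointUrl": "https://login.microsoftonline.com",
        "resourceManagerEndpointUrl": "https://management.azure.com/",
    }

    # The hook only accepts key files ending in .json.
    with open("/path/to/azure_auth.json", "w") as f:
        json.dump(auth, f)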
diff --git a/airflow/contrib/hooks/azure_container_registry_hook.py b/airflow/contrib/hooks/azure_container_registry_hook.py
new file mode 100644
index 0000000000..af38c1a943
--- /dev/null
+++ b/airflow/contrib/hooks/azure_container_registry_hook.py
@@ -0,0 +1,39 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow.hooks.base_hook import BaseHook
+from azure.mgmt.containerinstance.models import ImageRegistryCredential
+
+
+class AzureContainerRegistryHook(BaseHook):
+    """
+    A hook to communicate with an Azure Container Registry.
+
+    :param conn_id: connection id of a service principal which will be used
+        to start the container instance
+    :type conn_id: str
+    """
+
+    def __init__(self, conn_id='azure_registry'):
+        self.conn_id = conn_id
+        self.connection = self.get_conn()
+
+    def get_conn(self):
+        conn = self.get_connection(self.conn_id)
+        return ImageRegistryCredential(server=conn.host,
+                                       username=conn.login,
+                                       password=conn.password)
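
A minimal sketch of how the operator below consumes this hook, assuming a
connection named 'azure_registry_user' with host/login/password set:

    from airflow.contrib.hooks.azure_container_registry_hook import AzureContainerRegistryHook

    registry_hook = AzureContainerRegistryHook('azure_registry_user')
    # .connection is an ImageRegistryCredential; the operator hands it to
    # ContainerGroup as image_registry_credentials=[registry_hook.connection]
    print(registry_hook.connection.server)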
diff --git a/airflow/contrib/hooks/azure_container_volume_hook.py b/airflow/contrib/hooks/azure_container_volume_hook.py
new file mode 100644
index 0000000000..5bf3491064
--- /dev/null
+++ b/airflow/contrib/hooks/azure_container_volume_hook.py
@@ -0,0 +1,54 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow.hooks.base_hook import BaseHook
+from azure.mgmt.containerinstance.models import (Volume,
+                                                 AzureFileVolume)
+
+
+class AzureContainerVolumeHook(BaseHook):
+    """
+    A hook which wraps an Azure Volume.
+
+    :param wasb_conn_id: connection id of an Azure storage account of
+        which file shares should be mounted
+    :type wasb_conn_id: str
+    """
+
+    def __init__(self, wasb_conn_id='wasb_default'):
+        self.conn_id = wasb_conn_id
+
+    def get_storagekey(self):
+        conn = self.get_connection(self.conn_id)
+        service_options = conn.extra_dejson
+
+        if 'connection_string' in service_options:
+            for keyvalue in service_options['connection_string'].split(";"):
+                key, value = keyvalue.split("=", 1)
+                if key == "AccountKey":
+                    return value
+        return conn.password
+
+    def get_file_volume(self, mount_name, share_name,
+                        storage_account_name, read_only=False):
+        return Volume(name=mount_name,
+                      azure_file=AzureFileVolume(share_name=share_name,
+                                                 storage_account_name=storage_account_name,
+                                                 read_only=read_only,
+                                                 storage_account_key=self.get_storagekey()))
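
A worked example of the AccountKey extraction in get_storagekey above, using a
hypothetical storage connection string:

    connection_string = ("DefaultEndpointsProtocol=https;AccountName=mystore;"
                         "AccountKey=abc123==;EndpointSuffix=core.windows.net")

    # Split on ';' into key=value pairs; split each pair only on the first '='
    # so the base64 padding in the key value survives intact.
    for keyvalue in connection_string.split(";"):
        key, value = keyvalue.split("=", 1)
        if key == "AccountKey":
            print(value)  # prints: abc123==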
diff --git a/airflow/contrib/operators/azure_container_instances_operator.py b/airflow/contrib/operators/azure_container_instances_operator.py
new file mode 100644
index 0000000000..8b64bb1863
--- /dev/null
+++ b/airflow/contrib/operators/azure_container_instances_operator.py
@@ -0,0 +1,257 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from time import sleep
+
+from airflow.contrib.hooks.azure_container_instance_hook import AzureContainerInstanceHook
+from airflow.contrib.hooks.azure_container_registry_hook import AzureContainerRegistryHook
+from airflow.contrib.hooks.azure_container_volume_hook import AzureContainerVolumeHook
+
+from airflow.exceptions import AirflowException, AirflowTaskTimeout
+from airflow.models import BaseOperator
+from airflow.utils.decorators import apply_defaults
+
+from azure.mgmt.containerinstance.models import (EnvironmentVariable,
+                                                 VolumeMount,
+                                                 ResourceRequests,
+                                                 ResourceRequirements,
+                                                 Container,
+                                                 ContainerGroup)
+from msrestazure.azure_exceptions import CloudError
+
+
+DEFAULT_ENVIRONMENT_VARIABLES = {}
+DEFAULT_VOLUMES = []
+DEFAULT_MEMORY_IN_GB = 2.0
+DEFAULT_CPU = 1.0
+
+
+class AzureContainerInstancesOperator(BaseOperator):
+    """
+    Start a container on Azure Container Instances
+
+    :param ci_conn_id: connection id of a service principal which will be used
+        to start the container instance
+    :type ci_conn_id: str
+    :param registry_conn_id: connection id of a user which can login to a
+        private docker registry. If None, we assume a public registry
+    :type registry_conn_id: str
+    :param resource_group: name of the resource group wherein this container
+        instance should be started
+    :type resource_group: str
+    :param name: name of this container instance. Please note this name has
+        to be unique in order to run containers in parallel.
+    :type name: str
+    :param image: the docker image to be used
+    :type image: str
+    :param region: the region wherein this container instance should be started
+    :type region: str
+    :param environment_variables: key,value pairs containing environment
+        variables which will be passed to the running container
+    :type environment_variables: dict
+    :param volumes: list of volumes to be mounted to the container.
+        Currently only Azure Fileshares are supported.
+    :type volumes: list[<conn_id, account_name, share_name, mount_path, read_only>]
+    :param memory_in_gb: the amount of memory to allocate to this container
+    :type memory_in_gb: double
+    :param cpu: the number of cpus to allocate to this container
+    :type cpu: double
+    :param command: the command to run inside the container
+    :type command: str
+
+    :Example:
+
+    >>>  a = AzureContainerInstancesOperator(
+                'azure_service_principal',
+                'azure_registry_user',
+                'my-resource-group',
+                'my-container-name-{{ ds }}',
+                'myprivateregistry.azurecr.io/my_container:latest',
+                'westeurope',
+                {'EXECUTION_DATE': '{{ ds }}'},
+                [('azure_wasb_conn_id',
+                  'my_storage_container',
+                  'my_fileshare',
+                  '/input-data',
+                  True),],
+                memory_in_gb=14.0,
+                cpu=4.0,
+                command='python /app/myfile.py',
+                task_id='start_container'
+            )
+    """
+
+    template_fields = ('name', 'environment_variables')
+    template_ext = tuple()
+
+    @apply_defaults
+    def __init__(self, ci_conn_id, registry_conn_id, resource_group,
+                 name, image, region, environment_variables=None,
+                 volumes=None, memory_in_gb=None, cpu=None, command=None,
+                 remove_on_error=True, fail_if_exists=True, *args, **kwargs):
+        super(AzureContainerInstancesOperator, self).__init__(*args, **kwargs)
+
+        self.ci_conn_id = ci_conn_id
+        self.resource_group = resource_group
+        self.name = name
+        self.image = image
+        self.region = region
+        self.registry_conn_id = registry_conn_id
+        self.environment_variables = (environment_variables or
+                                      DEFAULT_ENVIRONMENT_VARIABLES)
+        self.volumes = volumes or DEFAULT_VOLUMES
+        self.memory_in_gb = memory_in_gb or DEFAULT_MEMORY_IN_GB
+        self.cpu = cpu or DEFAULT_CPU
+        self.command = command
+        self.remove_on_error = remove_on_error
+        self.fail_if_exists = fail_if_exists
+
+    def execute(self, context):
+        ci_hook = AzureContainerInstanceHook(self.ci_conn_id)
+
+        if self.fail_if_exists:
+            self.log.info("Testing if container group already exists")
+            if ci_hook.exists(self.resource_group, self.name):
+                raise AirflowException("Container group exists")
+
+        if self.registry_conn_id:
+            registry_hook = AzureContainerRegistryHook(self.registry_conn_id)
+            image_registry_credentials = [registry_hook.connection, ]
+        else:
+            image_registry_credentials = None
+
+        environment_variables = []
+        for key, value in self.environment_variables.items():
+            environment_variables.append(EnvironmentVariable(key, value))
+
+        volumes = []
+        volume_mounts = []
+        for (conn_id, account_name, share_name,
+             mount_path, read_only) in self.volumes:
+            hook = AzureContainerVolumeHook(conn_id)
+
+            mount_name = "mount-%d" % len(volumes)
+            volumes.append(hook.get_file_volume(mount_name,
+                                                share_name,
+                                                account_name,
+                                                read_only))
+            volume_mounts.append(VolumeMount(mount_name,
+                                             mount_path,
+                                             read_only))
+
+        exit_code = 1
+        try:
+            self.log.info("Starting container group with %.1f cpu %.1f mem",
+                          self.cpu, self.memory_in_gb)
+
+            resources = ResourceRequirements(requests=ResourceRequests(
+                memory_in_gb=self.memory_in_gb,
+                cpu=self.cpu))
+
+            container = Container(
+                name=self.name,
+                image=self.image,
+                resources=resources,
+                command=self.command,
+                environment_variables=environment_variables,
+                volume_mounts=volume_mounts)
+
+            container_group = ContainerGroup(
+                location=self.region,
+                containers=[container, ],
+                image_registry_credentials=image_registry_credentials,
+                volumes=volumes,
+                restart_policy='Never',
+                os_type='Linux')
+
+            ci_hook.create_or_update(self.resource_group,
+                                     self.name,
+                                     container_group)
+
+            self.log.info("Container group started %s/%s", 
self.resource_group, self.name)
+
+            exit_code = self._monitor_logging(ci_hook,
+                                              self.resource_group,
+                                              self.name)
+
+            self.log.info("Container had exit code: %s", exit_code)
+            if exit_code != 0:
+                raise AirflowException("Container had a non-zero exit code, %s"
+                                       % exit_code)
+
+        except CloudError:
+            self.log.exception("Could not start container group")
+            raise AirflowException("Could not start container group")
+
+        finally:
+            if exit_code == 0 or self.remove_on_error:
+                self.log.info("Deleting container group")
+                try:
+                    ci_hook.delete(self.resource_group, self.name)
+                except Exception:
+                    self.log.exception("Could not delete container group")
+
+    def _monitor_logging(self, ci_hook, resource_group, name):
+        last_state = None
+        last_message_logged = None
+        last_line_logged = None
+        for _ in range(43200):  # roughly 12 hours
+            try:
+                state, exit_code, detail_status = \
+                    ci_hook.get_state_exitcode_details(resource_group, name)
+                if state != last_state:
+                    self.log.info("Container group state changed to %s", state)
+                    last_state = state
+
+                messages = ci_hook.get_messages(resource_group, name)
+                last_message_logged = self._log_last(messages,
+                                                     last_message_logged)
+
+                if state in ["Running", "Terminated"]:
+                    try:
+                        logs = ci_hook.get_logs(resource_group, name)
+                        last_line_logged = self._log_last(logs,
+                                                          last_line_logged)
+                    except CloudError:
+                        self.log.exception("Exception while getting logs from "
+                                           "container instance, retrying...")
+
+                if state == "Terminated":
+                    self.log.info("Container exited with detail_status %s",
+                                  detail_status)
+                    return exit_code
+
+            except CloudError as err:
+                if 'ResourceNotFound' in str(err):
+                    self.log.warning("ResourceNotFound, container is probably "
+                                     "removed by another process "
+                                     "(make sure that the name is unique).")
+                    return 1
+                else:
+                    self.log.exception("Exception while getting container groups")
+            except Exception:
+                self.log.exception("Exception while getting container groups")
+
+            sleep(1)
+
+        # no return -> hence still running
+        raise AirflowTaskTimeout("Did not complete on time")
+
+    def _log_last(self, logs, last_line_logged):
+        if logs:
+            # determine the last line which was logged before
+            last_line_index = 0
+            for i in range(len(logs) - 1, -1, -1):
+                if logs[i] == last_line_logged:
+                    # this line is the same, hence print from i+1
+                    last_line_index = i + 1
+                    break
+
+            # log all new ones
+            for line in logs[last_line_index:]:
+                self.log.info(line.rstrip())
+
+            return logs[-1]
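
The _log_last bookkeeping above can be hard to follow from the diff alone. A
standalone sketch of the same dedup logic, illustrating why each poll emits
only lines that appeared after the previously logged tail:

    def log_tail(logs, last_line_logged):
        # Standalone copy of _log_last, for illustration only.
        if not logs:
            return last_line_logged
        # Scan backwards for the last line already logged; new output starts
        # just after it (or at index 0 if it is not found).
        last_line_index = 0
        for i in range(len(logs) - 1, -1, -1):
            if logs[i] == last_line_logged:
                last_line_index = i + 1
                break
        for line in logs[last_line_index:]:
            print(line.rstrip())
        return logs[-1]

    cursor = log_tail(["a\n", "b\n"], None)           # prints a, b
    cursor = log_tail(["a\n", "b\n", "c\n"], cursor)  # prints only c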
diff --git a/airflow/models/connection.py b/airflow/models/connection.py
index 9a51a43c08..ff63560020 100644
--- a/airflow/models/connection.py
+++ b/airflow/models/connection.py
@@ -83,6 +83,7 @@ class Connection(Base, LoggingMixin):
         ('snowflake', 'Snowflake',),
         ('segment', 'Segment',),
         ('azure_data_lake', 'Azure Data Lake'),
+        ('azure_container_instances', 'Azure Container Instances'),
         ('azure_cosmos', 'Azure CosmosDB'),
         ('cassandra', 'Cassandra',),
         ('qubole', 'Qubole'),
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index 7bd67ce533..e24d73d233 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -281,6 +281,10 @@ def initdb(rbac=False):
         Connection(
             conn_id='azure_cosmos_default', conn_type='azure_cosmos',
             extra='{"database_name": "<DATABASE_NAME>", "collection_name": 
"<COLLECTION_NAME>" }'))
+    merge_conn(
+        Connection(
+            conn_id='azure_container_instances_default', conn_type='azure_container_instances',
+            extra='{"tenantId": "<TENANT>", "subscriptionId": "<SUBSCRIPTION ID>" }'))
     merge_conn(
         Connection(
             conn_id='cassandra_default', conn_type='cassandra',
diff --git a/docs/integration.rst b/docs/integration.rst
index e74d8f662b..f35c8e87ea 100644
--- a/docs/integration.rst
+++ b/docs/integration.rst
@@ -161,34 +161,34 @@ Logging
 Airflow can be configured to read and write task logs in Azure Blob Storage.
 See :ref:`write-logs-azure`.
 
-Azure CosmosDB 
+Azure CosmosDB
 ''''''''''''''
- 
-AzureCosmosDBHook communicates via the Azure Cosmos library. Make sure that a 
-Airflow connection of type `azure_cosmos` exists. Authorization can be done by supplying a 
-login (=Endpoint uri), password (=secret key) and extra fields database_name and collection_name to specify the  
-default database and collection to use (see connection `azure_cosmos_default` for an example). 
- 
-- :ref:`AzureCosmosDBHook`: Interface with Azure CosmosDB. 
-- :ref:`AzureCosmosInsertDocumentOperator`: Simple operator to insert document into CosmosDB. 
-- :ref:`AzureCosmosDocumentSensor`: Simple sensor to detect document existence in CosmosDB. 
+
+AzureCosmosDBHook communicates via the Azure Cosmos library. Make sure that a
+Airflow connection of type `azure_cosmos` exists. Authorization can be done by supplying a
+login (=Endpoint uri), password (=secret key) and extra fields database_name and collection_name to specify the
+default database and collection to use (see connection `azure_cosmos_default` for an example).
+
+- :ref:`AzureCosmosDBHook`: Interface with Azure CosmosDB.
+- :ref:`AzureCosmosInsertDocumentOperator`: Simple operator to insert document into CosmosDB.
+- :ref:`AzureCosmosDocumentSensor`: Simple sensor to detect document existence in CosmosDB.
 
 .. _AzureCosmosDBHook:
 
-AzureCosmosDBHook 
+AzureCosmosDBHook
 """""""""""""""""
- 
-.. autoclass:: airflow.contrib.hooks.azure_cosmos_hook.AzureCosmosDBHook 
- 
-AzureCosmosInsertDocumentOperator 
+
+.. autoclass:: airflow.contrib.hooks.azure_cosmos_hook.AzureCosmosDBHook
+
+AzureCosmosInsertDocumentOperator
 """""""""""""""""""""""""""""""""
- 
+
 .. autoclass:: airflow.contrib.operators.azure_cosmos_operator.AzureCosmosInsertDocumentOperator
- 
-AzureCosmosDocumentSensor 
+
+AzureCosmosDocumentSensor
 """""""""""""""""""""""""
- 
-.. autoclass:: airflow.contrib.sensors.azure_cosmos_sensor.AzureCosmosDocumentSensor 
+
+.. autoclass:: airflow.contrib.sensors.azure_cosmos_sensor.AzureCosmosDocumentSensor
 
 Azure Data Lake
 '''''''''''''''
@@ -223,6 +223,51 @@ AdlsToGoogleCloudStorageOperator
 
 .. autoclass:: airflow.contrib.operators.adls_to_gcs.AdlsToGoogleCloudStorageOperator
 
+Azure Container Instances
+'''''''''''''''''''''''''
+
+Azure Container Instances provides a method to run a docker container without having to worry
+about managing infrastructure. The AzureContainerInstanceHook requires a service principal.
+The credentials for this principal can be defined in the extra field `key_path` (pointing to
+a JSON key file), in an environment variable named `AZURE_AUTH_LOCATION`, or by providing a
+login/password and tenantId in extras.
+
+The AzureContainerRegistryHook requires a host/login/password to be defined in the connection.
+
+- :ref:`AzureContainerInstancesOperator` : Start/Monitor a new ACI.
+- :ref:`AzureContainerInstanceHook` : Wrapper around a single ACI.
+- :ref:`AzureContainerRegistryHook` : Wrapper around an ACR.
+- :ref:`AzureContainerVolumeHook` : Wrapper around Container Volumes.
+
+.. _AzureContainerInstancesOperator:
+
+AzureContainerInstancesOperator
+"""""""""""""""""""""""""""""""
+
+.. autoclass:: airflow.contrib.operators.azure_container_instances_operator.AzureContainerInstancesOperator
+
+.. _AzureContainerInstanceHook:
+
+AzureContainerInstanceHook
+""""""""""""""""""""""""""
+
+.. autoclass:: airflow.contrib.hooks.azure_container_instance_hook.AzureContainerInstanceHook
+
+.. _AzureContainerRegistryHook:
+
+AzureContainerRegistryHook
+""""""""""""""""""""""""""
+
+.. autoclass:: airflow.contrib.hooks.azure_container_registry_hook.AzureContainerRegistryHook
+
+.. _AzureContainerVolumeHook:
+
+AzureContainerVolumeHook
+""""""""""""""""""""""""
+
+.. autoclass:: airflow.contrib.hooks.azure_container_volume_hook.AzureContainerVolumeHook
+
+
 .. _AWS:
 
 AWS: Amazon Web Services
diff --git a/setup.py b/setup.py
index 56b9bb2873..e62ea3f1f8 100644
--- a/setup.py
+++ b/setup.py
@@ -152,6 +152,7 @@ def write_version(filename=os.path.join(*['airflow',
     'azure-datalake-store==0.0.19'
 ]
 azure_cosmos = ['azure-cosmos>=3.0.1']
+azure_container_instances = ['azure-mgmt-containerinstance']
 cassandra = ['cassandra-driver>=3.13.0']
 celery = [
     'celery>=4.1.1, <4.2.0',
@@ -273,7 +274,7 @@ def write_version(filename=os.path.join(*['airflow',
              docker + ssh + kubernetes + celery + azure_blob_storage + redis + gcp_api +
              datadog + zendesk + jdbc + ldap + kerberos + password + webhdfs + jenkins +
              druid + pinot + segment + snowflake + elasticsearch + azure_data_lake + azure_cosmos +
-             atlas)
+             atlas + azure_container_instances)
 
 # Snakebite & Google Cloud Dataflow are not Python 3 compatible :'(
 if PY3:
@@ -348,6 +349,7 @@ def do_setup():
             'azure_blob_storage': azure_blob_storage,
             'azure_data_lake': azure_data_lake,
             'azure_cosmos': azure_cosmos,
+            'azure_container_instances': azure_container_instances,
             'cassandra': cassandra,
             'celery': celery,
             'cgroups': cgroups,
diff --git a/tests/contrib/hooks/test_azure_container_instance_hook.py b/tests/contrib/hooks/test_azure_container_instance_hook.py
new file mode 100644
index 0000000000..afaf2ff40d
--- /dev/null
+++ b/tests/contrib/hooks/test_azure_container_instance_hook.py
@@ -0,0 +1,124 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import json
+import unittest
+from collections import namedtuple
+from mock import patch
+
+from airflow import configuration
+from airflow.models.connection import Connection
+from airflow.contrib.hooks.azure_container_instance_hook import AzureContainerInstanceHook
+from airflow.utils import db
+
+from azure.mgmt.containerinstance.models import (Container,
+                                                 ContainerGroup,
+                                                 ContainerState,
+                                                 Event,
+                                                 Logs,
+                                                 ResourceRequests,
+                                                 ResourceRequirements)
+
+
+class TestAzureContainerInstanceHook(unittest.TestCase):
+
+    def setUp(self):
+        configuration.load_test_config()
+        db.merge_conn(
+            Connection(
+                conn_id='azure_container_instance_test',
+                conn_type='azure_container_instances',
+                login='login',
+                password='key',
+                extra=json.dumps({'tenantId': 'tenant_id',
+                                  'subscriptionId': 'subscription_id'})
+            )
+        )
+
+        self.resources = ResourceRequirements(requests=ResourceRequests(
+            memory_in_gb='4',
+            cpu='1'))
+        with patch('azure.common.credentials.ServicePrincipalCredentials.__init__',
+                   autospec=True, return_value=None):
+            with patch('azure.mgmt.containerinstance.ContainerInstanceManagementClient'):
+                self.testHook = AzureContainerInstanceHook(conn_id='azure_container_instance_test')
+
+    @patch('azure.mgmt.containerinstance.models.ContainerGroup')
+    @patch('azure.mgmt.containerinstance.operations.ContainerGroupsOperations.create_or_update')
+    def test_create_or_update(self, create_or_update_mock, container_group_mock):
+        self.testHook.create_or_update('resource_group', 'aci-test', container_group_mock)
+        create_or_update_mock.assert_called_with('resource_group', 'aci-test', container_group_mock)
+
+    @patch('airflow.contrib.hooks.azure_container_instance_hook'
+           '.AzureContainerInstanceHook._get_instance_view')
+    def test_get_state_exitcode_details(self, get_instance_view_mock):
+        expected_state = ContainerState(state='testing', exit_code=1, detail_status='details')
+        instance_view = {"current_state": expected_state}
+        named_instance = namedtuple("InstanceView", instance_view.keys())(*instance_view.values())
+        get_instance_view_mock.return_value = named_instance
+
+        state, exit_code, details = \
+            self.testHook.get_state_exitcode_details('resource-group', 'test')
+
+        self.assertEqual(state, expected_state.state)
+        self.assertEqual(exit_code, expected_state.exit_code)
+        self.assertEqual(details, expected_state.detail_status)
+
+    @patch('airflow.contrib.hooks.azure_container_instance_hook'
+           '.AzureContainerInstanceHook._get_instance_view')
+    def test_get_messages(self, get_instance_view_mock):
+        expected_messages = ['test1', 'test2']
+        events = [Event(message=m) for m in expected_messages]
+        instance_view = {"events": events}
+        named_instance = namedtuple("Events", instance_view.keys())(*instance_view.values())
+        get_instance_view_mock.return_value = named_instance
+
+        messages = self.testHook.get_messages('resource-group', 'test')
+
+        self.assertSequenceEqual(messages, expected_messages)
+
+    @patch('azure.mgmt.containerinstance.operations.ContainerOperations.list_logs')
+    def test_get_logs(self, list_logs_mock):
+        expected_messages = ['log line 1\n', 'log line 2\n', 'log line 3\n']
+        logs = Logs(content=''.join(expected_messages))
+        list_logs_mock.return_value = logs
+
+        logs = self.testHook.get_logs('resource_group', 'name')
+
+        self.assertSequenceEqual(logs, expected_messages)
+
+    @patch('azure.mgmt.containerinstance.operations.ContainerGroupsOperations.delete')
+    def test_delete(self, delete_mock):
+        self.testHook.delete('resource_group', 'aci-test')
+        delete_mock.assert_called_with('resource_group', 'aci-test')
+
+    @patch('azure.mgmt.containerinstance.operations.ContainerGroupsOperations.list_by_resource_group')
+    def test_exists_with_existing(self, list_mock):
+        list_mock.return_value = [ContainerGroup(os_type='Linux',
+                                                 containers=[Container(name='test1',
+                                                                       image='hello-world',
+                                                                       resources=self.resources)])]
+        self.assertFalse(self.testHook.exists('test', 'test1'))
+
+    @patch('azure.mgmt.containerinstance.operations.ContainerGroupsOperations.list_by_resource_group')
+    def test_exists_with_not_existing(self, list_mock):
+        list_mock.return_value = [ContainerGroup(os_type='Linux',
+                                                 containers=[Container(name='test1',
+                                                                       image='hello-world',
+                                                                       resources=self.resources)])]
+        self.assertFalse(self.testHook.exists('test', 'not found'))
diff --git a/tests/contrib/hooks/test_azure_container_registry_hook.py b/tests/contrib/hooks/test_azure_container_registry_hook.py
new file mode 100644
index 0000000000..02c72cce83
--- /dev/null
+++ b/tests/contrib/hooks/test_azure_container_registry_hook.py
@@ -0,0 +1,44 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+
+from airflow import configuration
+from airflow.models.connection import Connection
+from airflow.contrib.hooks.azure_container_registry_hook import AzureContainerRegistryHook
+from airflow.utils import db
+
+
+class TestAzureContainerRegistryHook(unittest.TestCase):
+
+    def test_get_conn(self):
+        configuration.load_test_config()
+        db.merge_conn(
+            Connection(
+                conn_id='azure_container_registry',
+                login='myuser',
+                password='password',
+                host='test.cr',
+            )
+        )
+        hook = AzureContainerRegistryHook(conn_id='azure_container_registry')
+        self.assertIsNotNone(hook.connection)
+        self.assertEqual(hook.connection.username, 'myuser')
+        self.assertEqual(hook.connection.password, 'password')
+        self.assertEqual(hook.connection.server, 'test.cr')
diff --git a/tests/contrib/hooks/test_azure_container_volume_hook.py b/tests/contrib/hooks/test_azure_container_volume_hook.py
new file mode 100644
index 0000000000..25b3c53ab1
--- /dev/null
+++ b/tests/contrib/hooks/test_azure_container_volume_hook.py
@@ -0,0 +1,50 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+
+from airflow import configuration
+from airflow.models.connection import Connection
+from airflow.contrib.hooks.azure_container_volume_hook import AzureContainerVolumeHook
+from airflow.utils import db
+
+
+class TestAzureContainerVolumeHook(unittest.TestCase):
+
+    def test_get_file_volume(self):
+        configuration.load_test_config()
+        db.merge_conn(
+            Connection(
+                conn_id='wasb_test_key',
+                conn_type='wasb',
+                login='login',
+                password='key'
+            )
+        )
+        hook = AzureContainerVolumeHook(wasb_conn_id='wasb_test_key')
+        volume = hook.get_file_volume(mount_name='mount',
+                                      share_name='share',
+                                      storage_account_name='storage',
+                                      read_only=True)
+        self.assertIsNotNone(volume)
+        self.assertEqual(volume.name, 'mount')
+        self.assertEqual(volume.azure_file.share_name, 'share')
+        self.assertEqual(volume.azure_file.storage_account_key, 'key')
+        self.assertEqual(volume.azure_file.storage_account_name, 'storage')
+        self.assertEqual(volume.azure_file.read_only, True)
diff --git a/tests/contrib/operators/test_azure_container_instances_operator.py b/tests/contrib/operators/test_azure_container_instances_operator.py
new file mode 100644
index 0000000000..53014de414
--- /dev/null
+++ b/tests/contrib/operators/test_azure_container_instances_operator.py
@@ -0,0 +1,115 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from airflow.exceptions import AirflowException
+from airflow.contrib.operators.azure_container_instances_operator import AzureContainerInstancesOperator
+
+import unittest
+
+try:
+    from unittest import mock
+except ImportError:
+    try:
+        import mock
+    except ImportError:
+        mock = None
+
+
+class TestACIOperator(unittest.TestCase):
+
+    @mock.patch("airflow.contrib.operators."
+                "azure_container_instances_operator.AzureContainerInstanceHook")
+    def test_execute(self, aci_mock):
+        aci_mock.return_value.get_state_exitcode_details.return_value = \
+            "Terminated", 0, "test"
+        aci_mock.return_value.exists.return_value = False
+
+        aci = AzureContainerInstancesOperator(ci_conn_id=None,
+                                              registry_conn_id=None,
+                                              resource_group='resource-group',
+                                              name='container-name',
+                                              image='container-image',
+                                              region='region',
+                                              task_id='task')
+        aci.execute(None)
+
+        self.assertEqual(aci_mock.return_value.create_or_update.call_count, 1)
+        (called_rg, called_cn, called_cg), _ = \
+            aci_mock.return_value.create_or_update.call_args
+
+        self.assertEqual(called_rg, 'resource-group')
+        self.assertEqual(called_cn, 'container-name')
+
+        self.assertEqual(called_cg.location, 'region')
+        self.assertEqual(called_cg.image_registry_credentials, None)
+        self.assertEqual(called_cg.restart_policy, 'Never')
+        self.assertEqual(called_cg.os_type, 'Linux')
+
+        called_cg_container = called_cg.containers[0]
+        self.assertEqual(called_cg_container.name, 'container-name')
+        self.assertEqual(called_cg_container.image, 'container-image')
+
+        self.assertEqual(aci_mock.return_value.delete.call_count, 1)
+
+    @mock.patch("airflow.contrib.operators."
+                "azure_container_instances_operator.AzureContainerInstanceHook")
+    def test_execute_with_failures(self, aci_mock):
+        aci_mock.return_value.get_state_exitcode_details.return_value = \
+            "Terminated", 1, "test"
+        aci_mock.return_value.exists.return_value = False
+
+        aci = AzureContainerInstancesOperator(ci_conn_id=None,
+                                              registry_conn_id=None,
+                                              resource_group='resource-group',
+                                              name='container-name',
+                                              image='container-image',
+                                              region='region',
+                                              task_id='task')
+        with self.assertRaises(AirflowException):
+            aci.execute(None)
+
+        self.assertEqual(aci_mock.return_value.delete.call_count, 1)
+
+    @mock.patch("airflow.contrib.operators."
+                "azure_container_instances_operator.AzureContainerInstanceHook")
+    def test_execute_with_messages_logs(self, aci_mock):
+        aci_mock.return_value.get_state_exitcode_details.side_effect = \
+            [("Running", 0, "test"),
+             ("Terminated", 0, "test")]
+        aci_mock.return_value.get_messages.return_value = ["test", "messages"]
+        aci_mock.return_value.get_logs.return_value = ["test", "logs"]
+        aci_mock.return_value.exists.return_value = False
+
+        aci = AzureContainerInstancesOperator(ci_conn_id=None,
+                                              registry_conn_id=None,
+                                              resource_group='resource-group',
+                                              name='container-name',
+                                              image='container-image',
+                                              region='region',
+                                              task_id='task')
+        aci.execute(None)
+
+        self.assertEqual(aci_mock.return_value.create_or_update.call_count, 1)
+        self.assertEqual(
+            aci_mock.return_value.get_state_exitcode_details.call_count, 2)
+        self.assertEqual(aci_mock.return_value.get_messages.call_count, 2)
+        self.assertEqual(aci_mock.return_value.get_logs.call_count, 2)
+
+        self.assertEqual(aci_mock.return_value.delete.call_count, 1)
+
+
+if __name__ == '__main__':
+    unittest.main()