apraovjr closed pull request #4288: [AIRFLOW-3282] Implement an Azure Kubernetes Service Operator
URL: https://github.com/apache/incubator-airflow/pull/4288
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:


diff --git a/airflow/contrib/example_dags/example_azure_kubernetes_container_operator.py b/airflow/contrib/example_dags/example_azure_kubernetes_container_operator.py
new file mode 100644
index 0000000000..79fa5c5c16
--- /dev/null
+++ b/airflow/contrib/example_dags/example_azure_kubernetes_container_operator.py
@@ -0,0 +1,52 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow import DAG
+from airflow.contrib.operators.aks_operator import AzureKubernetesOperator
+from datetime import datetime, timedelta
+
+seven_days_ago = datetime.combine(datetime.today() - timedelta(7),
+                                  datetime.min.time())
+default_args = {
+    'owner': 'airflow',
+    'depends_on_past': False,
+    'start_date': seven_days_ago,
+    'email': ['em...@microsoft.com'],
+    'email_on_failure': False,
+    'email_on_retry': False,
+    'retries': 1,
+    'retry_delay': timedelta(minutes=5),
+}
+
+dag = DAG(
+    dag_id='aks_container',
+    default_args=default_args,
+    schedule_interval=None,
+)
+
+start_aks_container = AzureKubernetesOperator(
+    task_id="start_aks_container",
+    ci_conn_id='azure_kubernetes_default',
+    resource_group="apraotest1",
+    name="akres1",
+    ssh_key_value=None,
+    dns_name_prefix=None,
+    location="eastus",
+    tags=None,
+    dag=dag)
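
Note: the example DAG above assumes an 'azure_kubernetes_default' connection whose 'key_path' extra
points at an Azure auth JSON file. A minimal sketch of such a file, limited to the fields the hook
actually reads (all values are placeholders):

    {
        "clientId": "<SERVICE_PRINCIPAL_APP_ID>",
        "clientSecret": "<SERVICE_PRINCIPAL_SECRET>",
        "subscriptionId": "<AZURE_SUBSCRIPTION_ID>",
        "tenantId": "<AZURE_TENANT_ID>"
    }

With the new extra installed (pip install apache-airflow[azure_kubernetes_service]), the task can be
exercised locally with something like `airflow test aks_container start_aks_container 2018-01-01`.
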
diff --git a/airflow/contrib/hooks/azure_kubernetes_hook.py b/airflow/contrib/hooks/azure_kubernetes_hook.py
new file mode 100644
index 0000000000..dfd9200f64
--- /dev/null
+++ b/airflow/contrib/hooks/azure_kubernetes_hook.py
@@ -0,0 +1,75 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+
+from airflow.hooks.base_hook import BaseHook
+from airflow.exceptions import AirflowException
+
+from azure.common.credentials import ServicePrincipalCredentials
+from azure.mgmt.containerservice import ContainerServiceClient
+from azure.mgmt.resource import ResourceManagementClient
+from airflow.contrib.utils.aks_utils import load_json
+
+
+class AzureKubernetesServiceHook(BaseHook):
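+    """
+    Interact with Azure Kubernetes Service (AKS).
+
+    Credentials are read from a JSON key file referenced either by the
+    connection's 'key_path' extra or by the AZURE_AUTH_LOCATION environment
+    variable.
+
+    :param conn_id: connection id of the Azure service principal to use
+    :type conn_id: str
+    """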
+
+    def __init__(self, conn_id=None):
+        self.conn_id = conn_id
+        self.configData = None
+        self.credentials = None
+        self.subscription_id = None
+        self.clientId = None
+        self.clientSecret = None
+        self.connection = self.get_conn()
+
+    def get_conn(self):
+        if self.conn_id:
+            conn = self.get_connection(self.conn_id)
+            key_path = conn.extra_dejson.get('key_path', False)
+            if key_path:
+                if key_path.endswith('.json'):
+                    self.log.info('Getting connection using a JSON key file.')
+
+                    self.configData = load_json(self, key_path)
+                else:
+                    raise AirflowException('Unrecognised extension for key file.')
+
+        if os.environ.get('AZURE_AUTH_LOCATION'):
+            key_path = os.environ.get('AZURE_AUTH_LOCATION')
+            if key_path.endswith('.json'):
+                self.log.info('Getting connection using a JSON key file.')
+                self.configData = load_json(self, key_path)
+            else:
+                raise AirflowException('Unrecognised extension for key file.')
+
+        if not self.configData:
+            raise AirflowException(
+                'Unable to load Azure credentials. Provide a JSON key file via the '
+                'connection extra "key_path" or the AZURE_AUTH_LOCATION environment variable.')
+
+        self.credentials = ServicePrincipalCredentials(
+            client_id=self.configData['clientId'],
+            secret=self.configData['clientSecret'],
+            tenant=self.configData['tenantId']
+        )
+
+        self.subscription_id = self.configData['subscriptionId']
+        self.clientId = self.configData["clientId"]
+        self.clientSecret = self.configData["clientSecret"]
+        return ContainerServiceClient(self.credentials, str(self.subscription_id))
+
+    def check_resource(self, credentials, subscriptionId, resource_group):
+        resource_grp = ResourceManagementClient(credentials, subscriptionId)
+        return resource_grp.resource_groups.check_existence(resource_group)
+
+    def create_resource(self, credentials, subscriptionId, resource_group, location):
+        resource_grp = ResourceManagementClient(credentials, subscriptionId)
+        return resource_grp.resource_groups.create_or_update(resource_group, location)
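
Note: a quick sketch of using the hook on its own (the connection id, resource group and region below
are placeholders):

    from airflow.contrib.hooks.azure_kubernetes_hook import AzureKubernetesServiceHook

    hook = AzureKubernetesServiceHook(conn_id='azure_kubernetes_default')
    client = hook.connection  # azure.mgmt.containerservice.ContainerServiceClient

    # create the resource group first if it does not exist yet
    if not hook.check_resource(hook.credentials, hook.subscription_id, 'my-rg'):
        hook.create_resource(hook.credentials, hook.subscription_id, 'my-rg',
                             {'location': 'eastus'})
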
diff --git a/airflow/contrib/operators/aks_operator.py b/airflow/contrib/operators/aks_operator.py
new file mode 100644
index 0000000000..6ffe25ecd6
--- /dev/null
+++ b/airflow/contrib/operators/aks_operator.py
@@ -0,0 +1,200 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow.contrib.hooks.azure_kubernetes_hook import AzureKubernetesServiceHook
+from airflow.models import BaseOperator
+
+from azure.mgmt.containerservice.models import ContainerServiceLinuxProfile
+from azure.mgmt.containerservice.models import ContainerServiceServicePrincipalProfile
+from azure.mgmt.containerservice.models import ContainerServiceSshConfiguration
+from azure.mgmt.containerservice.models import ContainerServiceSshPublicKey
+from azure.mgmt.containerservice.models import ContainerServiceStorageProfileTypes
+from azure.mgmt.containerservice.models import ManagedCluster
+from azure.mgmt.containerservice.models import ManagedClusterAgentPoolProfile
+
+from airflow.contrib.utils.aks_utils import \
+    is_valid_ssh_rsa_public_key, get_poller_result, get_public_key, get_default_dns_prefix
+from knack.log import get_logger
+from msrestazure.azure_exceptions import CloudError
+
+logger = get_logger(__name__)
+
+
+class AzureKubernetesOperator(BaseOperator):
+    """
+    Start a Azure Kubernetes Service
+
+    :param ci_conn_id: connection id of a
+        service principal which will be used to start the azure kubernetes 
service
+    :type ci_conn_id: str
+    :param resource_group: Required name of the resource group
+        wherein this container instance should be started
+    :type resource_group: str
+    :param name: Required name of this container. Please note this name
+        has to be unique in order to run containers in parallel.
+    :type name: str
+    :param ssh_key_value: the ssh value used to connect to machine to be used
+    :type ssh_key_value: str
+    :param dns_name_prefix: DNS prefix specified when creating the managed 
cluster.
+    :type region: dns_name_prefix
+    :param admin_username: The administrator username to use for Linux VMs.
+    :type admin_username: str
+    :param kubernetes_version: Version of Kubernetes specified when creating 
the managed cluster.
+    :type kubernetes_version: str
+    :param node_vm_size: Vm to be spin up.
+    :type node_vm_size: str or ContainerServiceVMSizeTypes Enum
+    :param node_osdisk_size: Size in GB to be used to specify the disk size for
+        every machine in this master/agent pool. If you specify 0, it will 
apply the default
+        osDisk size according to the vmSize specified.
+    :type node_osdisk_size: int
+    :param node_count: Number of agents (VMs) to host docker containers.
+        Allowed values must be in the range of 1 to 100 (inclusive). The 
default value is 1.
+    :type node_count: int
+    :param no_ssh_key: Specified if it is linuxprofile.
+    :type no_ssh_key: boolean
+    :param vnet_subnet_id: VNet SubnetID specifies the vnet's subnet 
identifier.
+    :type vnet_subnet_id: str
+    :param max_pods: Maximum number of pods that can run on a node.
+    :type max_pods: int
+    :param os_type: OsType to be used to specify os type. Choose from Linux 
and Windows.
+        Default to Linux.
+    :type os_type: str or OSType Enum
+    :param tags: Resource tags.
+    :type tags: dict[str, str]
+    :param location: Required resource location
+    :type location: str
+
+    :Example:
+
+    >>> a = AzureKubernetesOperator(task_id="task", ci_conn_id='azure_kubernetes_default',
+            resource_group="my_resource_group",
+            name="my_aks_container",
+            ssh_key_value=None,
+            dns_name_prefix=None,
+            location="my_region",
+            tags=None
+        )
+    """
+
+    def __init__(self, ci_conn_id, resource_group, name, ssh_key_value,
+                 dns_name_prefix=None,
+                 location=None,
+                 admin_username="azureuser",
+                 kubernetes_version='',
+                 node_vm_size="Standard_DS2_v2",
+                 node_osdisk_size=0,
+                 node_count=3,
+                 no_ssh_key=False,
+                 vnet_subnet_id=None,
+                 max_pods=None,
+                 os_type="Linux",
+                 tags=None,
+                 *args, **kwargs):
+
+        self.ci_conn_id = ci_conn_id
+        self.resource_group = resource_group
+        self.name = name
+        self.no_ssh_key = no_ssh_key
+        self.dns_name_prefix = dns_name_prefix
+        self.location = location
+        self.admin_username = admin_username
+        self.node_vm_size = node_vm_size
+        self.node_count = node_count
+        self.ssh_key_value = ssh_key_value
+        self.vnet_subnet_id = vnet_subnet_id
+        self.max_pods = max_pods
+        self.os_type = os_type
+        self.tags = tags
+        self.node_osdisk_size = node_osdisk_size
+        self.kubernetes_version = kubernetes_version
+
+        super(AzureKubernetesOperator, self).__init__(*args, **kwargs)
+
+    def execute(self, context):
+        ci_hook = AzureKubernetesServiceHook(self.ci_conn_id)
+
+        containerservice = ci_hook.get_conn()
+
+        if not self.no_ssh_key:
+            try:
+                if not self.ssh_key_value or not is_valid_ssh_rsa_public_key(self.ssh_key_value):
+                    raise ValueError()
+            except (TypeError, ValueError):
+                self.ssh_key_value = get_public_key(self)
+
+        # dns_prefix
+        if not self.dns_name_prefix:
+            self.dns_name_prefix = get_default_dns_prefix(
+                self, self.name, self.resource_group, ci_hook.subscription_id)
+
+        # Check if the resource group exists
+        if ci_hook.check_resource(ci_hook.credentials, ci_hook.subscription_id, self.resource_group):
+            logger.info("Resource group already exists: {0}".format(self.resource_group))
+        else:
+            logger.info("Creating resource group {0}".format(self.resource_group))
+            created_resource_group = ci_hook.create_resource(
+                ci_hook.credentials, ci_hook.subscription_id, self.resource_group,
+                {'location': self.location})
+            logger.info("Got resource group: {0}".format(created_resource_group.name))
+
+        # Add agent_pool_profile
+        agent_pool_profile = ManagedClusterAgentPoolProfile(
+            name='nodepool1',  # Must be 12 chars or less before ACS RP adds to it
+            count=int(self.node_count),
+            vm_size=self.node_vm_size,
+            os_type=self.os_type,
+            storage_profile=ContainerServiceStorageProfileTypes.managed_disks,
+            vnet_subnet_id=self.vnet_subnet_id,
+            max_pods=int(self.max_pods) if self.max_pods else None
+        )
+
+        if self.node_osdisk_size:
+            agent_pool_profile.os_disk_size_gb = int(self.node_osdisk_size)
+
+        linux_profile = None
+
+        # LinuxProfile is just used for SSH access to VMs, so omit it if no_ssh_key was specified.
+        if not self.no_ssh_key:
+            ssh_config = ContainerServiceSshConfiguration(
+                public_keys=[ContainerServiceSshPublicKey(key_data=self.ssh_key_value)])
+            linux_profile = ContainerServiceLinuxProfile(
+                admin_username=self.admin_username, ssh=ssh_config)
+
+        service_profile = ContainerServiceServicePrincipalProfile(
+            client_id=ci_hook.clientId, secret=ci_hook.clientSecret, key_vault_secret_ref=None)
+
+        mc = ManagedCluster(
+            location=self.location, tags=self.tags,
+            dns_prefix=self.dns_name_prefix,
+            kubernetes_version=self.kubernetes_version,
+            agent_pool_profiles=[agent_pool_profile],
+            linux_profile=linux_profile,
+            service_principal_profile=service_profile)
+
+        try:
+            logger.info("Checking if the AKS instance {0} is present".format(self.name))
+            response = containerservice.managed_clusters.get(self.resource_group, self.name)
+            logger.info("Response : {0}".format(response))
+            logger.info("AKS instance : {0} found".format(response.name))
+            return response
+        except CloudError:
+            poller = containerservice.managed_clusters.create_or_update(
+                resource_group_name=self.resource_group, resource_name=self.name, parameters=mc)
+            response = get_poller_result(self, poller)
+            logger.info("AKS instance created {0}".format(self.name))
+            return response
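
Note: beyond the minimal docstring example, a fuller instantiation exercising the optional arguments
might look roughly like this (the names, region and Kubernetes version are placeholders):

    create_cluster = AzureKubernetesOperator(
        task_id="create_aks_cluster",
        ci_conn_id='azure_kubernetes_default',
        resource_group="my_resource_group",
        name="my-aks-cluster",
        ssh_key_value=None,           # a fresh RSA key is generated when omitted
        dns_name_prefix=None,         # derived from name/resource group when omitted
        location="eastus",
        kubernetes_version='1.11.3',
        node_vm_size="Standard_DS2_v2",
        node_count=3,
        os_type="Linux",
        tags={"team": "data-eng"},
        dag=dag)
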
diff --git a/airflow/contrib/utils/aks_utils.py b/airflow/contrib/utils/aks_utils.py
new file mode 100644
index 0000000000..3363ac52c7
--- /dev/null
+++ b/airflow/contrib/utils/aks_utils.py
@@ -0,0 +1,92 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import sys
+import json
+import re
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives.asymmetric import rsa
+from cryptography.hazmat.primitives import serialization
+
+from knack.log import get_logger
+logger = get_logger(__name__)
+
+
+def is_valid_ssh_rsa_public_key(openssh_pubkey):
+    # http://stackoverflow.com/questions/2494450/ssh-rsa-public-key-validation-using-a-regular-expression
+    # A "good enough" check is to see if the key starts with the correct header.
+    import struct
+    try:
+        from base64 import decodebytes as base64_decode
+    except ImportError:
+        # deprecated and redirected to decodebytes in Python 3
+        from base64 import decodestring as base64_decode
+
+    parts = openssh_pubkey.split()
+    if len(parts) < 2:
+        return False
+    key_type = parts[0]
+    key_string = parts[1]
+
+    data = base64_decode(key_string.encode())  # pylint:disable=deprecated-method
+    int_len = 4
+    str_len = struct.unpack('>I', data[:int_len])[0]  # this should return 7
+    return data[int_len:int_len + str_len] == key_type.encode()
+
+
+def get_public_key(self):
+    key = rsa.generate_private_key(backend=default_backend(), public_exponent=65537, key_size=2048)
+    return key.public_key().public_bytes(
+        serialization.Encoding.OpenSSH,
+        serialization.PublicFormat.OpenSSH).decode('utf-8')
+
+
+def load_json(self, file_path):
+    try:
+        with open(file_path) as configFile:
+            configData = json.load(configFile)
+    except IOError:
+        print("Error: Expecting the Azure credentials file at {0}".format(file_path))
+        sys.exit()
+    return configData
+
+
+def get_poller_result(self, poller, wait=5):
+    '''
+    Consistent method of waiting on and retrieving results from Azure's long poller
+    :param poller Azure poller object
+    :return object resulting from the original request
+    '''
+    try:
+        delay = wait
+        while not poller.done():
+            logger.info("Waiting for {0} sec".format(delay))
+            poller.wait(timeout=delay)
+        return poller.result()
+    except Exception as exc:
+        logger.info("Exception while waiting on poller: {0}".format(exc))
+        raise
+
+
+def get_default_dns_prefix(self, name, resource_group, subscription_id):
+    # Use subscription id to provide uniqueness and prevent DNS name clashes
+    name_part = re.sub('[^A-Za-z0-9-]', '', name)[0:10]
+    if not name_part[0].isalpha():
+        name_part = (str('a') + name_part)[0:10]
+    resource_group_part = re.sub('[^A-Za-z0-9-]', '', resource_group)[0:16]
+    return '{}-{}-{}'.format(name_part, resource_group_part, subscription_id[0:6])
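
Note: a quick sketch of how these helpers behave (values are illustrative; the unused first argument
mirrors the current signatures):

    from airflow.contrib.utils.aks_utils import (
        is_valid_ssh_rsa_public_key, get_public_key, get_default_dns_prefix)

    key = get_public_key(None)               # generates a fresh 2048-bit RSA public key
    assert is_valid_ssh_rsa_public_key(key)  # passes the 'ssh-rsa' header check

    # e.g. 'myaks-myresourcegroup-abc123' for the inputs below
    prefix = get_default_dns_prefix(None, 'my_aks', 'my_resource_group', 'abc123def456')
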
diff --git a/airflow/models.py b/airflow/models.py
index 1bca27cbc8..7b7fb2f86a 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -668,6 +668,7 @@ class Connection(Base, LoggingMixin):
         ('segment', 'Segment',),
         ('azure_data_lake', 'Azure Data Lake'),
         ('cassandra', 'Cassandra',),
+        ('azure_kubernetes_instances', 'Azure Kubernetes Instances'),
         ('qubole', 'Qubole'),
         ('mongo', 'MongoDB'),
         ('gcpcloudsql', 'Google Cloud SQL'),
@@ -808,6 +809,9 @@ def get_hook(self):
             elif self.conn_type == 'azure_data_lake':
                 from airflow.contrib.hooks.azure_data_lake_hook import AzureDataLakeHook
                 return AzureDataLakeHook(azure_data_lake_conn_id=self.conn_id)
+            elif self.conn_type == 'azure_kubernetes_instances':
+                from airflow.contrib.hooks.azure_kubernetes_hook import AzureKubernetesServiceHook
+                return AzureKubernetesServiceHook(conn_id=self.conn_id)
             elif self.conn_type == 'cassandra':
                 from airflow.contrib.hooks.cassandra_hook import CassandraHook
                 return CassandraHook(cassandra_conn_id=self.conn_id)
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index b6a807c86c..dd03f4b683 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -277,6 +277,10 @@ def initdb(rbac=False):
         models.Connection(
             conn_id='azure_data_lake_default', conn_type='azure_data_lake',
             extra='{"tenant": "<TENANT>", "account_name": "<ACCOUNTNAME>" }'))
+    merge_conn(
+        models.Connection(
+            conn_id='azure_kubernetes_default', conn_type='azure_kubernetes_instances',
+            extra='{"key_path": "<AZURE_AUTH_LOCATION>" }'))
     merge_conn(
         models.Connection(
             conn_id='cassandra_default', conn_type='cassandra',
diff --git a/docs/integration.rst b/docs/integration.rst
index 8fbc4be764..4e68b701ef 100644
--- a/docs/integration.rst
+++ b/docs/integration.rst
@@ -196,6 +196,23 @@ AdlsToGoogleCloudStorageOperator
 
 .. _AWS:
 
+Azure Kubernetes Service
+'''''''''''''''''''''''''
+Azure Kubernetes Service (AKS) simplifies the deployment and operations of Kubernetes and enables you to dynamically scale your application infrastructure.
+The AzureKubernetesServiceHook requires a service principal. The credentials for this principal can be defined either in the extra field ``key_path`` or as an
+environment variable named ``AZURE_AUTH_LOCATION``.
+
+- :ref:`AzureKubernetesOperator` : Start a new AKS cluster.
+- :ref:`AzureKubernetesServiceHook` : Wrapper around an AKS connection.
+
+AzureKubernetesOperator
+"""""""""""""""""""""""""""""""
+.. autoclass:: airflow.contrib.operators.aks_operator.AzureKubernetesOperator
+
+AzureKubernetesServiceHook
+""""""""""""""""""""""""""
+.. autoclass:: airflow.contrib.hooks.azure_kubernetes_hook.AzureKubernetesServiceHook
+
 AWS: Amazon Web Services
 ------------------------
 
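Note: as sketched in the airflow/utils/db.py change above, pointing the default connection at a
credentials file could look roughly like this (the path is a placeholder):

    from airflow import models
    from airflow.utils import db

    db.merge_conn(
        models.Connection(
            conn_id='azure_kubernetes_default',
            conn_type='azure_kubernetes_instances',
            extra='{"key_path": "/path/to/azureauth.json"}'))
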
diff --git a/setup.py b/setup.py
index 9170ce7f6c..564d49207c 100644
--- a/setup.py
+++ b/setup.py
@@ -151,6 +151,11 @@ def write_version(filename=os.path.join(*['airflow',
     'azure-mgmt-datalake-store==0.4.0',
     'azure-datalake-store==0.0.19'
 ]
+azure_kubernetes_service = [
+    'azure-mgmt-containerservice==4.2.2',
+    'azure-mgmt-resource==2.0.0',
+    'knack==0.4.4'
+]
 cassandra = ['cassandra-driver>=3.13.0']
 celery = [
     'celery>=4.1.1, <4.2.0',
@@ -269,7 +274,7 @@ def write_version(filename=os.path.join(*['airflow',
              docker + ssh + kubernetes + celery + azure_blob_storage + redis + gcp_api +
              datadog + zendesk + jdbc + ldap + kerberos + password + webhdfs + jenkins +
              druid + pinot + segment + snowflake + elasticsearch + azure_data_lake +
-             atlas)
+             azure_kubernetes_service + atlas)
 
 # Snakebite & Google Cloud Dataflow are not Python 3 compatible :'(
 if PY3:
@@ -343,6 +348,7 @@ def do_setup():
             'async': async_packages,
             'azure_blob_storage': azure_blob_storage,
             'azure_data_lake': azure_data_lake,
+            'azure_kubernetes_service': azure_kubernetes_service,
             'cassandra': cassandra,
             'celery': celery,
             'cgroups': cgroups,
diff --git a/tests/contrib/hooks/test_azure_kubernetes_hook.py b/tests/contrib/hooks/test_azure_kubernetes_hook.py
new file mode 100644
index 0000000000..7fdaf5dc36
--- /dev/null
+++ b/tests/contrib/hooks/test_azure_kubernetes_hook.py
@@ -0,0 +1,84 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import json
+import unittest
+
+from airflow import configuration
+from airflow.exceptions import AirflowException
+from airflow import models
+from airflow.contrib.hooks.azure_kubernetes_hook import AzureKubernetesServiceHook
+from airflow.utils import db
+
+try:
+    from unittest import mock
+except ImportError:
+    try:
+        import mock
+    except ImportError:
+        mock = None
+
+CONFIG_DATA = {
+    "clientId": "Id",
+    "clientSecret": "secret",
+    "subscriptionId": "subscription",
+    "tenantId": "tenant"
+}
+
+
+class TestAzureKubernetesHook(unittest.TestCase):
+    def setUp(self):
+        configuration.load_test_config()
+        db.merge_conn(
+            models.Connection(
+                conn_id='azure_default',
+                extra=json.dumps({"key_path": "azureauth.json"})
+            )
+        )
+
+    @mock.patch('airflow.contrib.hooks.azure_kubernetes_hook.load_json')
+    @mock.patch('airflow.contrib.hooks.azure_kubernetes_hook.ServicePrincipalCredentials')
+    def test_conn(self, mock_service, mock_json):
+        from azure.mgmt.containerservice import ContainerServiceClient
+        mock_json.return_value = CONFIG_DATA
+        hook = AzureKubernetesServiceHook(conn_id='azure_default')
+        self.assertEqual(hook.conn_id, 'azure_default')
+        self.assertIsInstance(hook.connection, ContainerServiceClient)
+
+    @mock.patch('airflow.contrib.hooks.azure_kubernetes_hook.load_json')
+    @mock.patch('airflow.contrib.hooks.azure_kubernetes_hook.ServicePrincipalCredentials')
+    @mock.patch('os.environ.get', new={'AZURE_AUTH_LOCATION': 'azureauth.json'}.get, spec_set=True)
+    def test_conn_env(self, mock_service, mock_json):
+        from azure.mgmt.containerservice import ContainerServiceClient
+        mock_json.return_value = CONFIG_DATA
+        hook = AzureKubernetesServiceHook(conn_id=None)
+        self.assertEqual(hook.conn_id, None)
+        self.assertIsInstance(hook.connection, ContainerServiceClient)
+
+    @mock.patch('os.environ.get', new={'AZURE_AUTH_LOCATION': 'azureauth.jpeg'}.get, spec_set=True)
+    def test_conn_with_failures(self):
+        with self.assertRaises(AirflowException) as ex:
+            AzureKubernetesServiceHook(conn_id=None)
+
+        self.assertEqual(str(ex.exception), "Unrecognised extension for key file.")
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/contrib/operators/test_aks_operator.py b/tests/contrib/operators/test_aks_operator.py
new file mode 100644
index 0000000000..159b0a8ea0
--- /dev/null
+++ b/tests/contrib/operators/test_aks_operator.py
@@ -0,0 +1,78 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+
+from airflow.contrib.operators.aks_operator import AzureKubernetesOperator
+from msrestazure.azure_exceptions import CloudError
+
+try:
+    from unittest import mock
+except ImportError:
+    try:
+        import mock
+    except ImportError:
+        mock = None
+
+
+class AzureKubernetesOperatorTest(unittest.TestCase):
+    @mock.patch('airflow.contrib.operators.aks_operator.AzureKubernetesServiceHook')
+    def test_execute_existing_kubernetes(self, aks_hook_mock):
+
+        aks = AzureKubernetesOperator(ci_conn_id=None,
+                                      resource_group="resource_group",
+                                      name="name",
+                                      ssh_key_value=None,
+                                      dns_name_prefix=None,
+                                      location="location",
+                                      tags=None,
+                                      task_id='task')
+
+        client_hook = aks_hook_mock.return_value.get_conn.return_value
+
+        aks.execute(None)
+
+        client_hook.managed_clusters.get.assert_called_once_with('resource_group', 'name')
+        self.assertEqual(client_hook.managed_clusters.create_or_update.call_count, 0)
+
+    @mock.patch('airflow.contrib.operators.aks_operator.AzureKubernetesServiceHook')
+    def test_execute_create_kubernetes(self, aks_hook_mock):
+
+        aks = AzureKubernetesOperator(ci_conn_id=None,
+                                      resource_group="resource_group",
+                                      name="name",
+                                      ssh_key_value=None,
+                                      dns_name_prefix=None,
+                                      location="location",
+                                      tags=None,
+                                      task_id='task')
+
+        client_hook = aks_hook_mock.return_value.get_conn.return_value
+
+        resp = mock.MagicMock()
+        resp.status_code = 404
+        resp.text = '{"Message": "The Resource Microsoft.ContainerService/managedClusters/name under resource group \
+        resource_group was not found."}'
+
+        client_hook.managed_clusters.get.side_effect = CloudError(resp, error="Not found")
+
+        aks.execute(None)
+
+        self.assertEqual(client_hook.managed_clusters.get.call_count, 1)
+        self.assertEqual(client_hook.managed_clusters.create_or_update.call_count, 1)


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
