This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 630aeff  Fix AzureDataFactoryHook failing to instantiate its connection (#14565)
630aeff is described below

commit 630aeff72c7903ae8d4608f3530057bb6255e10b
Author: flvndh <[email protected]>
AuthorDate: Wed Mar 3 00:22:32 2021 +0100

    Fix AzureDataFactoryHook failing to instantiate its connection (#14565)
    
    closes #14557
---
 .../microsoft/azure/hooks/azure_data_factory.py    | 41 ++++++++++++++--------
 setup.py                                           |  2 +-
 .../azure/hooks/test_azure_data_factory.py         |  4 +--
 3 files changed, 29 insertions(+), 18 deletions(-)

diff --git a/airflow/providers/microsoft/azure/hooks/azure_data_factory.py b/airflow/providers/microsoft/azure/hooks/azure_data_factory.py
index d6c686b..4da6a25 100644
--- a/airflow/providers/microsoft/azure/hooks/azure_data_factory.py
+++ b/airflow/providers/microsoft/azure/hooks/azure_data_factory.py
@@ -18,6 +18,8 @@ import inspect
 from functools import wraps
 from typing import Any, Callable, Optional
 
+from azure.core.polling import LROPoller
+from azure.identity import ClientSecretCredential
 from azure.mgmt.datafactory import DataFactoryManagementClient
 from azure.mgmt.datafactory.models import (
     CreateRunResponse,
@@ -31,10 +33,9 @@ from azure.mgmt.datafactory.models import (
     Trigger,
     TriggerResource,
 )
-from msrestazure.azure_operation import AzureOperationPoller
 
 from airflow.exceptions import AirflowException
-from airflow.providers.microsoft.azure.hooks.base_azure import AzureBaseHook
+from airflow.hooks.base import BaseHook
 
 
 def provide_targeted_factory(func: Callable) -> Callable:
@@ -69,7 +70,7 @@ def provide_targeted_factory(func: Callable) -> Callable:
     return wrapper
 
 
-class AzureDataFactoryHook(AzureBaseHook):  # pylint: 
disable=too-many-public-methods
+class AzureDataFactoryHook(BaseHook):  # pylint: 
disable=too-many-public-methods
     """
     A hook to interact with Azure Data Factory.
 
@@ -77,12 +78,22 @@ class AzureDataFactoryHook(AzureBaseHook):  # pylint: 
disable=too-many-public-me
     """
 
     def __init__(self, conn_id: str = "azure_data_factory_default"):
-        super().__init__(sdk_client=DataFactoryManagementClient, 
conn_id=conn_id)
         self._conn: DataFactoryManagementClient = None
+        self.conn_id = conn_id
+        super().__init__()
 
     def get_conn(self) -> DataFactoryManagementClient:
-        if not self._conn:
-            self._conn = super().get_conn()
+        if self._conn is not None:
+            return self._conn
+
+        conn = self.get_connection(self.conn_id)
+
+        self._conn = DataFactoryManagementClient(
+            credential=ClientSecretCredential(
+                client_id=conn.login, client_secret=conn.password, 
tenant_id=conn.extra_dejson.get("tenantId")
+            ),
+            subscription_id=conn.extra_dejson.get("subscriptionId"),
+        )
 
         return self._conn
 
@@ -126,7 +137,7 @@ class AzureDataFactoryHook(AzureBaseHook):  # pylint: 
disable=too-many-public-me
         :raise AirflowException: If the factory does not exist.
         :return: The factory.
         """
-        if not self._factory_exists(resource_group_name, factory):
+        if not self._factory_exists(resource_group_name, factory_name):
             raise AirflowException(f"Factory {factory!r} does not exist.")
 
         return self.get_conn().factories.create_or_update(
@@ -151,7 +162,7 @@ class AzureDataFactoryHook(AzureBaseHook):  # pylint: 
disable=too-many-public-me
         :raise AirflowException: If the factory already exists.
         :return: The factory.
         """
-        if self._factory_exists(resource_group_name, factory):
+        if self._factory_exists(resource_group_name, factory_name):
             raise AirflowException(f"Factory {factory!r} already exists.")
 
         return self.get_conn().factories.create_or_update(
@@ -266,7 +277,7 @@ class AzureDataFactoryHook(AzureBaseHook):  # pylint: 
disable=too-many-public-me
         **config: Any,
     ) -> None:
         """
-        Delete the linked service:
+        Delete the linked service.
 
         :param linked_service_name: The linked service name.
         :param resource_group_name: The linked service name.
@@ -368,7 +379,7 @@ class AzureDataFactoryHook(AzureBaseHook):  # pylint: 
disable=too-many-public-me
         **config: Any,
     ) -> None:
         """
-        Delete the dataset:
+        Delete the dataset.
 
         :param dataset_name: The dataset name.
         :param resource_group_name: The dataset name.
@@ -468,7 +479,7 @@ class AzureDataFactoryHook(AzureBaseHook):  # pylint: 
disable=too-many-public-me
         **config: Any,
     ) -> None:
         """
-        Delete the pipeline:
+        Delete the pipeline.
 
         :param pipeline_name: The pipeline name.
         :param resource_group_name: The pipeline name.
@@ -642,7 +653,7 @@ class AzureDataFactoryHook(AzureBaseHook):  # pylint: 
disable=too-many-public-me
         resource_group_name: Optional[str] = None,
         factory_name: Optional[str] = None,
         **config: Any,
-    ) -> AzureOperationPoller:
+    ) -> LROPoller:
         """
         Start the trigger.
 
@@ -652,7 +663,7 @@ class AzureDataFactoryHook(AzureBaseHook):  # pylint: 
disable=too-many-public-me
         :param config: Extra parameters for the ADF client.
         :return: An Azure operation poller.
         """
-        return self.get_conn().triggers.start(resource_group_name, 
factory_name, trigger_name, **config)
+        return self.get_conn().triggers.begin_start(resource_group_name, 
factory_name, trigger_name, **config)
 
     @provide_targeted_factory
     def stop_trigger(
@@ -661,7 +672,7 @@ class AzureDataFactoryHook(AzureBaseHook):  # pylint: 
disable=too-many-public-me
         resource_group_name: Optional[str] = None,
         factory_name: Optional[str] = None,
         **config: Any,
-    ) -> AzureOperationPoller:
+    ) -> LROPoller:
         """
         Stop the trigger.
 
@@ -671,7 +682,7 @@ class AzureDataFactoryHook(AzureBaseHook):  # pylint: 
disable=too-many-public-me
         :param config: Extra parameters for the ADF client.
         :return: An Azure operation poller.
         """
-        return self.get_conn().triggers.stop(resource_group_name, 
factory_name, trigger_name, **config)
+        return self.get_conn().triggers.begin_stop(resource_group_name, 
factory_name, trigger_name, **config)
 
     @provide_targeted_factory
     def rerun_trigger(
diff --git a/setup.py b/setup.py
index a08a4d7..8c0e617 100644
--- a/setup.py
+++ b/setup.py
@@ -217,7 +217,7 @@ azure = [
     'azure-keyvault>=4.1.0',
     'azure-kusto-data>=0.0.43,<0.1',
     'azure-mgmt-containerinstance>=1.5.0,<2.0',
-    'azure-mgmt-datafactory>=0.13.0',
+    'azure-mgmt-datafactory>=1.0.0,<2.0',
     'azure-mgmt-datalake-store>=0.5.0',
     'azure-mgmt-resource>=2.2.0',
     'azure-storage-blob>=12.7.0',
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
index ea445ec..e771b48 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
@@ -406,7 +406,7 @@ def test_delete_trigger(hook: AzureDataFactoryHook, 
user_args, sdk_args):
 def test_start_trigger(hook: AzureDataFactoryHook, user_args, sdk_args):
     hook.start_trigger(*user_args)
 
-    hook._conn.triggers.start.assert_called_with(*sdk_args)
+    hook._conn.triggers.begin_start.assert_called_with(*sdk_args)
 
 
 @parametrize(
@@ -416,7 +416,7 @@ def test_start_trigger(hook: AzureDataFactoryHook, 
user_args, sdk_args):
 def test_stop_trigger(hook: AzureDataFactoryHook, user_args, sdk_args):
     hook.stop_trigger(*user_args)
 
-    hook._conn.triggers.stop.assert_called_with(*sdk_args)
+    hook._conn.triggers.begin_stop.assert_called_with(*sdk_args)
 
 
 @parametrize(

Reply via email to