This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new 630aeff Fix AzureDataFactoryHook failing to instantiate its connection (#14565)
630aeff is described below
commit 630aeff72c7903ae8d4608f3530057bb6255e10b
Author: flvndh <[email protected]>
AuthorDate: Wed Mar 3 00:22:32 2021 +0100
Fix AzureDataFactoryHook failing to instantiate its connection (#14565)
closes #14557
---
.../microsoft/azure/hooks/azure_data_factory.py | 41 ++++++++++++++--------
setup.py | 2 +-
.../azure/hooks/test_azure_data_factory.py | 4 +--
3 files changed, 29 insertions(+), 18 deletions(-)
diff --git a/airflow/providers/microsoft/azure/hooks/azure_data_factory.py b/airflow/providers/microsoft/azure/hooks/azure_data_factory.py
index d6c686b..4da6a25 100644
--- a/airflow/providers/microsoft/azure/hooks/azure_data_factory.py
+++ b/airflow/providers/microsoft/azure/hooks/azure_data_factory.py
@@ -18,6 +18,8 @@ import inspect
from functools import wraps
from typing import Any, Callable, Optional
+from azure.core.polling import LROPoller
+from azure.identity import ClientSecretCredential
from azure.mgmt.datafactory import DataFactoryManagementClient
from azure.mgmt.datafactory.models import (
CreateRunResponse,
@@ -31,10 +33,9 @@ from azure.mgmt.datafactory.models import (
Trigger,
TriggerResource,
)
-from msrestazure.azure_operation import AzureOperationPoller
from airflow.exceptions import AirflowException
-from airflow.providers.microsoft.azure.hooks.base_azure import AzureBaseHook
+from airflow.hooks.base import BaseHook
def provide_targeted_factory(func: Callable) -> Callable:
@@ -69,7 +70,7 @@ def provide_targeted_factory(func: Callable) -> Callable:
return wrapper
-class AzureDataFactoryHook(AzureBaseHook): # pylint:
disable=too-many-public-methods
+class AzureDataFactoryHook(BaseHook): # pylint:
disable=too-many-public-methods
"""
A hook to interact with Azure Data Factory.
@@ -77,12 +78,22 @@ class AzureDataFactoryHook(AzureBaseHook): # pylint:
disable=too-many-public-me
"""
def __init__(self, conn_id: str = "azure_data_factory_default"):
- super().__init__(sdk_client=DataFactoryManagementClient,
conn_id=conn_id)
self._conn: DataFactoryManagementClient = None
+ self.conn_id = conn_id
+ super().__init__()
def get_conn(self) -> DataFactoryManagementClient:
- if not self._conn:
- self._conn = super().get_conn()
+ if self._conn is not None:
+ return self._conn
+
+ conn = self.get_connection(self.conn_id)
+
+ self._conn = DataFactoryManagementClient(
+ credential=ClientSecretCredential(
+ client_id=conn.login, client_secret=conn.password,
tenant_id=conn.extra_dejson.get("tenantId")
+ ),
+ subscription_id=conn.extra_dejson.get("subscriptionId"),
+ )
return self._conn
@@ -126,7 +137,7 @@ class AzureDataFactoryHook(AzureBaseHook): # pylint:
disable=too-many-public-me
:raise AirflowException: If the factory does not exist.
:return: The factory.
"""
- if not self._factory_exists(resource_group_name, factory):
+ if not self._factory_exists(resource_group_name, factory_name):
raise AirflowException(f"Factory {factory!r} does not exist.")
return self.get_conn().factories.create_or_update(
@@ -151,7 +162,7 @@ class AzureDataFactoryHook(AzureBaseHook): # pylint:
disable=too-many-public-me
:raise AirflowException: If the factory already exists.
:return: The factory.
"""
- if self._factory_exists(resource_group_name, factory):
+ if self._factory_exists(resource_group_name, factory_name):
raise AirflowException(f"Factory {factory!r} already exists.")
return self.get_conn().factories.create_or_update(
@@ -266,7 +277,7 @@ class AzureDataFactoryHook(AzureBaseHook): # pylint:
disable=too-many-public-me
**config: Any,
) -> None:
"""
- Delete the linked service:
+ Delete the linked service.
:param linked_service_name: The linked service name.
:param resource_group_name: The linked service name.
@@ -368,7 +379,7 @@ class AzureDataFactoryHook(AzureBaseHook): # pylint:
disable=too-many-public-me
**config: Any,
) -> None:
"""
- Delete the dataset:
+ Delete the dataset.
:param dataset_name: The dataset name.
:param resource_group_name: The dataset name.
@@ -468,7 +479,7 @@ class AzureDataFactoryHook(AzureBaseHook): # pylint:
disable=too-many-public-me
**config: Any,
) -> None:
"""
- Delete the pipeline:
+ Delete the pipeline.
:param pipeline_name: The pipeline name.
:param resource_group_name: The pipeline name.
@@ -642,7 +653,7 @@ class AzureDataFactoryHook(AzureBaseHook): # pylint:
disable=too-many-public-me
resource_group_name: Optional[str] = None,
factory_name: Optional[str] = None,
**config: Any,
- ) -> AzureOperationPoller:
+ ) -> LROPoller:
"""
Start the trigger.
@@ -652,7 +663,7 @@ class AzureDataFactoryHook(AzureBaseHook): # pylint:
disable=too-many-public-me
:param config: Extra parameters for the ADF client.
:return: An Azure operation poller.
"""
- return self.get_conn().triggers.start(resource_group_name,
factory_name, trigger_name, **config)
+ return self.get_conn().triggers.begin_start(resource_group_name,
factory_name, trigger_name, **config)
@provide_targeted_factory
def stop_trigger(
@@ -661,7 +672,7 @@ class AzureDataFactoryHook(AzureBaseHook): # pylint:
disable=too-many-public-me
resource_group_name: Optional[str] = None,
factory_name: Optional[str] = None,
**config: Any,
- ) -> AzureOperationPoller:
+ ) -> LROPoller:
"""
Stop the trigger.
@@ -671,7 +682,7 @@ class AzureDataFactoryHook(AzureBaseHook): # pylint:
disable=too-many-public-me
:param config: Extra parameters for the ADF client.
:return: An Azure operation poller.
"""
- return self.get_conn().triggers.stop(resource_group_name,
factory_name, trigger_name, **config)
+ return self.get_conn().triggers.begin_stop(resource_group_name,
factory_name, trigger_name, **config)
@provide_targeted_factory
def rerun_trigger(
diff --git a/setup.py b/setup.py
index a08a4d7..8c0e617 100644
--- a/setup.py
+++ b/setup.py
@@ -217,7 +217,7 @@ azure = [
'azure-keyvault>=4.1.0',
'azure-kusto-data>=0.0.43,<0.1',
'azure-mgmt-containerinstance>=1.5.0,<2.0',
- 'azure-mgmt-datafactory>=0.13.0',
+ 'azure-mgmt-datafactory>=1.0.0,<2.0',
'azure-mgmt-datalake-store>=0.5.0',
'azure-mgmt-resource>=2.2.0',
'azure-storage-blob>=12.7.0',
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
index ea445ec..e771b48 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
@@ -406,7 +406,7 @@ def test_delete_trigger(hook: AzureDataFactoryHook,
user_args, sdk_args):
def test_start_trigger(hook: AzureDataFactoryHook, user_args, sdk_args):
hook.start_trigger(*user_args)
- hook._conn.triggers.start.assert_called_with(*sdk_args)
+ hook._conn.triggers.begin_start.assert_called_with(*sdk_args)
@parametrize(
@@ -416,7 +416,7 @@ def test_start_trigger(hook: AzureDataFactoryHook,
user_args, sdk_args):
def test_stop_trigger(hook: AzureDataFactoryHook, user_args, sdk_args):
hook.stop_trigger(*user_args)
- hook._conn.triggers.stop.assert_called_with(*sdk_args)
+ hook._conn.triggers.begin_stop.assert_called_with(*sdk_args)
@parametrize(