This is an automated email from the ASF dual-hosted git repository. potiuk pushed a commit to branch v2-0-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 9ecee99ef439ba8bc7e624bea768e24445e7d841 Author: flvndh <[email protected]> AuthorDate: Fri Feb 26 17:28:21 2021 +0100 Add Azure Data Factory hook (#11015) fixes #10995 (cherry picked from commit 11d03d2f63d88a284d6aaded5f9ab6642a60561b) --- .../microsoft/azure/hooks/azure_data_factory.py | 716 +++++++++++++++++++++ airflow/providers/microsoft/azure/provider.yaml | 8 + .../integration-logos/azure/Azure Data Factory.svg | 1 + docs/spelling_wordlist.txt | 1 + setup.py | 1 + .../azure/hooks/test_azure_data_factory.py | 439 +++++++++++++ 6 files changed, 1166 insertions(+) diff --git a/airflow/providers/microsoft/azure/hooks/azure_data_factory.py b/airflow/providers/microsoft/azure/hooks/azure_data_factory.py new file mode 100644 index 0000000..d6c686b --- /dev/null +++ b/airflow/providers/microsoft/azure/hooks/azure_data_factory.py @@ -0,0 +1,716 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import inspect +from functools import wraps +from typing import Any, Callable, Optional + +from azure.mgmt.datafactory import DataFactoryManagementClient +from azure.mgmt.datafactory.models import ( + CreateRunResponse, + Dataset, + DatasetResource, + Factory, + LinkedService, + LinkedServiceResource, + PipelineResource, + PipelineRun, + Trigger, + TriggerResource, +) +from msrestazure.azure_operation import AzureOperationPoller + +from airflow.exceptions import AirflowException +from airflow.providers.microsoft.azure.hooks.base_azure import AzureBaseHook + + +def provide_targeted_factory(func: Callable) -> Callable: + """ + Provide the targeted factory to the decorated function in case it isn't specified. + + If ``resource_group_name`` or ``factory_name`` is not provided it defaults to the value specified in + the connection extras. + """ + signature = inspect.signature(func) + + @wraps(func) + def wrapper(*args, **kwargs) -> Callable: + bound_args = signature.bind(*args, **kwargs) + + def bind_argument(arg, default_key): + if arg not in bound_args.arguments: + self = args[0] + conn = self.get_connection(self.conn_id) + default_value = conn.extra_dejson.get(default_key) + + if not default_value: + raise AirflowException("Could not determine the targeted data factory.") + + bound_args.arguments[arg] = conn.extra_dejson[default_key] + + bind_argument("resource_group_name", "resourceGroup") + bind_argument("factory_name", "factory") + + return func(*bound_args.args, **bound_args.kwargs) + + return wrapper + + +class AzureDataFactoryHook(AzureBaseHook): # pylint: disable=too-many-public-methods + """ + A hook to interact with Azure Data Factory. + + :param conn_id: The Azure Data Factory connection id. 
+ """ + + def __init__(self, conn_id: str = "azure_data_factory_default"): + super().__init__(sdk_client=DataFactoryManagementClient, conn_id=conn_id) + self._conn: DataFactoryManagementClient = None + + def get_conn(self) -> DataFactoryManagementClient: + if not self._conn: + self._conn = super().get_conn() + + return self._conn + + @provide_targeted_factory + def get_factory( + self, resource_group_name: Optional[str] = None, factory_name: Optional[str] = None, **config: Any + ) -> Factory: + """ + Get the factory. + + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :return: The factory. + """ + return self.get_conn().factories.get(resource_group_name, factory_name, **config) + + def _factory_exists(self, resource_group_name, factory_name) -> bool: + """Return whether or not the factory already exists.""" + factories = { + factory.name for factory in self.get_conn().factories.list_by_resource_group(resource_group_name) + } + + return factory_name in factories + + @provide_targeted_factory + def update_factory( + self, + factory: Factory, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> Factory: + """ + Update the factory. + + :param factory: The factory resource definition. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :raise AirflowException: If the factory does not exist. + :return: The factory. 
+ """ + if not self._factory_exists(resource_group_name, factory_name): + raise AirflowException(f"Factory {factory_name!r} does not exist.") + + return self.get_conn().factories.create_or_update( + resource_group_name, factory_name, factory, **config + ) + + @provide_targeted_factory + def create_factory( + self, + factory: Factory, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> Factory: + """ + Create the factory. + + :param factory: The factory resource definition. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :raise AirflowException: If the factory already exists. + :return: The factory. + """ + if self._factory_exists(resource_group_name, factory_name): + raise AirflowException(f"Factory {factory_name!r} already exists.") + + return self.get_conn().factories.create_or_update( + resource_group_name, factory_name, factory, **config + ) + + @provide_targeted_factory + def delete_factory( + self, resource_group_name: Optional[str] = None, factory_name: Optional[str] = None, **config: Any + ) -> None: + """ + Delete the factory. + + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + """ + self.get_conn().factories.delete(resource_group_name, factory_name, **config) + + @provide_targeted_factory + def get_linked_service( + self, + linked_service_name: str, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> LinkedServiceResource: + """ + Get the linked service. + + :param linked_service_name: The linked service name. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :return: The linked service.
+ """ + return self.get_conn().linked_services.get( + resource_group_name, factory_name, linked_service_name, **config + ) + + def _linked_service_exists(self, resource_group_name, factory_name, linked_service_name) -> bool: + """Return whether or not the linked service already exists.""" + linked_services = { + linked_service.name + for linked_service in self.get_conn().linked_services.list_by_factory( + resource_group_name, factory_name + ) + } + + return linked_service_name in linked_services + + @provide_targeted_factory + def update_linked_service( + self, + linked_service_name: str, + linked_service: LinkedService, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> LinkedServiceResource: + """ + Update the linked service. + + :param linked_service_name: The linked service name. + :param linked_service: The linked service resource definition. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :raise AirflowException: If the linked service does not exist. + :return: The linked service. + """ + if not self._linked_service_exists(resource_group_name, factory_name, linked_service_name): + raise AirflowException(f"Linked service {linked_service_name!r} does not exist.") + + return self.get_conn().linked_services.create_or_update( + resource_group_name, factory_name, linked_service_name, linked_service, **config + ) + + @provide_targeted_factory + def create_linked_service( + self, + linked_service_name: str, + linked_service: LinkedService, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> LinkedServiceResource: + """ + Create the linked service. + + :param linked_service_name: The linked service name. + :param linked_service: The linked service resource definition. + :param resource_group_name: The resource group name. 
+ :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :raise AirflowException: If the linked service already exists. + :return: The linked service. + """ + if self._linked_service_exists(resource_group_name, factory_name, linked_service_name): + raise AirflowException(f"Linked service {linked_service_name!r} already exists.") + + return self.get_conn().linked_services.create_or_update( + resource_group_name, factory_name, linked_service_name, linked_service, **config + ) + + @provide_targeted_factory + def delete_linked_service( + self, + linked_service_name: str, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> None: + """ + Delete the linked service. + + :param linked_service_name: The linked service name. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + """ + self.get_conn().linked_services.delete( + resource_group_name, factory_name, linked_service_name, **config + ) + + @provide_targeted_factory + def get_dataset( + self, + dataset_name: str, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> DatasetResource: + """ + Get the dataset. + + :param dataset_name: The dataset name. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :return: The dataset.
+ """ + return self.get_conn().datasets.get(resource_group_name, factory_name, dataset_name, **config) + + def _dataset_exists(self, resource_group_name, factory_name, dataset_name) -> bool: + """Return whether or not the dataset already exists.""" + datasets = { + dataset.name + for dataset in self.get_conn().datasets.list_by_factory(resource_group_name, factory_name) + } + + return dataset_name in datasets + + @provide_targeted_factory + def update_dataset( + self, + dataset_name: str, + dataset: Dataset, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> DatasetResource: + """ + Update the dataset. + + :param dataset_name: The dataset name. + :param dataset: The dataset resource definition. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :raise AirflowException: If the dataset does not exist. + :return: The dataset. + """ + if not self._dataset_exists(resource_group_name, factory_name, dataset_name): + raise AirflowException(f"Dataset {dataset_name!r} does not exist.") + + return self.get_conn().datasets.create_or_update( + resource_group_name, factory_name, dataset_name, dataset, **config + ) + + @provide_targeted_factory + def create_dataset( + self, + dataset_name: str, + dataset: Dataset, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> DatasetResource: + """ + Create the dataset. + + :param dataset_name: The dataset name. + :param dataset: The dataset resource definition. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :raise AirflowException: If the dataset already exists. + :return: The dataset. 
+ """ + if self._dataset_exists(resource_group_name, factory_name, dataset_name): + raise AirflowException(f"Dataset {dataset_name!r} already exists.") + + return self.get_conn().datasets.create_or_update( + resource_group_name, factory_name, dataset_name, dataset, **config + ) + + @provide_targeted_factory + def delete_dataset( + self, + dataset_name: str, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> None: + """ + Delete the dataset. + + :param dataset_name: The dataset name. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + """ + self.get_conn().datasets.delete(resource_group_name, factory_name, dataset_name, **config) + + @provide_targeted_factory + def get_pipeline( + self, + pipeline_name: str, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> PipelineResource: + """ + Get the pipeline. + + :param pipeline_name: The pipeline name. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :return: The pipeline. + """ + return self.get_conn().pipelines.get(resource_group_name, factory_name, pipeline_name, **config) + + def _pipeline_exists(self, resource_group_name, factory_name, pipeline_name) -> bool: + """Return whether or not the pipeline already exists.""" + pipelines = { + pipeline.name + for pipeline in self.get_conn().pipelines.list_by_factory(resource_group_name, factory_name) + } + + return pipeline_name in pipelines + + @provide_targeted_factory + def update_pipeline( + self, + pipeline_name: str, + pipeline: PipelineResource, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> PipelineResource: + """ + Update the pipeline. + + :param pipeline_name: The pipeline name.
+ :param pipeline: The pipeline resource definition. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :raise AirflowException: If the pipeline does not exist. + :return: The pipeline. + """ + if not self._pipeline_exists(resource_group_name, factory_name, pipeline_name): + raise AirflowException(f"Pipeline {pipeline_name!r} does not exist.") + + return self.get_conn().pipelines.create_or_update( + resource_group_name, factory_name, pipeline_name, pipeline, **config + ) + + @provide_targeted_factory + def create_pipeline( + self, + pipeline_name: str, + pipeline: PipelineResource, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> PipelineResource: + """ + Create the pipeline. + + :param pipeline_name: The pipeline name. + :param pipeline: The pipeline resource definition. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :raise AirflowException: If the pipeline already exists. + :return: The pipeline. + """ + if self._pipeline_exists(resource_group_name, factory_name, pipeline_name): + raise AirflowException(f"Pipeline {pipeline_name!r} already exists.") + + return self.get_conn().pipelines.create_or_update( + resource_group_name, factory_name, pipeline_name, pipeline, **config + ) + + @provide_targeted_factory + def delete_pipeline( + self, + pipeline_name: str, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> None: + """ + Delete the pipeline. + + :param pipeline_name: The pipeline name. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client.
+ """ + self.get_conn().pipelines.delete(resource_group_name, factory_name, pipeline_name, **config) + + @provide_targeted_factory + def run_pipeline( + self, + pipeline_name: str, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> CreateRunResponse: + """ + Run a pipeline. + + :param pipeline_name: The pipeline name. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :return: The pipeline run. + """ + return self.get_conn().pipelines.create_run( + resource_group_name, factory_name, pipeline_name, **config + ) + + @provide_targeted_factory + def get_pipeline_run( + self, + run_id: str, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> PipelineRun: + """ + Get the pipeline run. + + :param run_id: The pipeline run identifier. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :return: The pipeline run. + """ + return self.get_conn().pipeline_runs.get(resource_group_name, factory_name, run_id, **config) + + @provide_targeted_factory + def cancel_pipeline_run( + self, + run_id: str, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> None: + """ + Cancel the pipeline run. + + :param run_id: The pipeline run identifier. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + """ + self.get_conn().pipeline_runs.cancel(resource_group_name, factory_name, run_id, **config) + + @provide_targeted_factory + def get_trigger( + self, + trigger_name: str, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> TriggerResource: + """ + Get the trigger. 
+ + :param trigger_name: The trigger name. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :return: The trigger. + """ + return self.get_conn().triggers.get(resource_group_name, factory_name, trigger_name, **config) + + def _trigger_exists(self, resource_group_name, factory_name, trigger_name) -> bool: + """Return whether or not the trigger already exists.""" + triggers = { + trigger.name + for trigger in self.get_conn().triggers.list_by_factory(resource_group_name, factory_name) + } + + return trigger_name in triggers + + @provide_targeted_factory + def update_trigger( + self, + trigger_name: str, + trigger: Trigger, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> TriggerResource: + """ + Update the trigger. + + :param trigger_name: The trigger name. + :param trigger: The trigger resource definition. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :raise AirflowException: If the trigger does not exist. + :return: The trigger. + """ + if not self._trigger_exists(resource_group_name, factory_name, trigger_name): + raise AirflowException(f"Trigger {trigger_name!r} does not exist.") + + return self.get_conn().triggers.create_or_update( + resource_group_name, factory_name, trigger_name, trigger, **config + ) + + @provide_targeted_factory + def create_trigger( + self, + trigger_name: str, + trigger: Trigger, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> TriggerResource: + """ + Create the trigger. + + :param trigger_name: The trigger name. + :param trigger: The trigger resource definition. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. 
+ :raise AirflowException: If the trigger already exists. + :return: The trigger. + """ + if self._trigger_exists(resource_group_name, factory_name, trigger_name): + raise AirflowException(f"Trigger {trigger_name!r} already exists.") + + return self.get_conn().triggers.create_or_update( + resource_group_name, factory_name, trigger_name, trigger, **config + ) + + @provide_targeted_factory + def delete_trigger( + self, + trigger_name: str, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> None: + """ + Delete the trigger. + + :param trigger_name: The trigger name. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + """ + self.get_conn().triggers.delete(resource_group_name, factory_name, trigger_name, **config) + + @provide_targeted_factory + def start_trigger( + self, + trigger_name: str, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> AzureOperationPoller: + """ + Start the trigger. + + :param trigger_name: The trigger name. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :return: An Azure operation poller. + """ + return self.get_conn().triggers.start(resource_group_name, factory_name, trigger_name, **config) + + @provide_targeted_factory + def stop_trigger( + self, + trigger_name: str, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> AzureOperationPoller: + """ + Stop the trigger. + + :param trigger_name: The trigger name. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :return: An Azure operation poller. 
+ """ + return self.get_conn().triggers.stop(resource_group_name, factory_name, trigger_name, **config) + + @provide_targeted_factory + def rerun_trigger( + self, + trigger_name: str, + run_id: str, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> None: + """ + Rerun the trigger. + + :param trigger_name: The trigger name. + :param run_id: The trigger run identifier. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + """ + return self.get_conn().trigger_runs.rerun( + resource_group_name, factory_name, trigger_name, run_id, **config + ) + + @provide_targeted_factory + def cancel_trigger( + self, + trigger_name: str, + run_id: str, + resource_group_name: Optional[str] = None, + factory_name: Optional[str] = None, + **config: Any, + ) -> None: + """ + Cancel the trigger. + + :param trigger_name: The trigger name. + :param run_id: The trigger run identifier. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. 
+ """ + self.get_conn().trigger_runs.cancel(resource_group_name, factory_name, trigger_name, run_id, **config) diff --git a/airflow/providers/microsoft/azure/provider.yaml b/airflow/providers/microsoft/azure/provider.yaml index fa0d112..da7b330 100644 --- a/airflow/providers/microsoft/azure/provider.yaml +++ b/airflow/providers/microsoft/azure/provider.yaml @@ -54,6 +54,10 @@ integrations: - integration-name: Microsoft Azure FileShare external-doc-url: https://cloud.google.com/storage/ tags: [azure] + - integration-name: Microsoft Azure Data Factory + external-doc-url: https://azure.microsoft.com/en-us/services/data-factory/ + logo: /integration-logos/azure/Azure Data Factory.svg + tags: [azure] - integration-name: Microsoft Azure external-doc-url: https://azure.microsoft.com/ tags: [azure] @@ -113,6 +117,9 @@ hooks: - integration-name: Microsoft Azure Blob Storage python-modules: - airflow.providers.microsoft.azure.hooks.wasb + - integration-name: Microsoft Azure Data Factory + python-modules: + - airflow.providers.microsoft.azure.hooks.azure_data_factory transfers: - source-integration-name: Local @@ -138,3 +145,4 @@ hook-class-names: - airflow.providers.microsoft.azure.hooks.azure_data_lake.AzureDataLakeHook - airflow.providers.microsoft.azure.hooks.azure_container_instance.AzureContainerInstanceHook - airflow.providers.microsoft.azure.hooks.wasb.WasbHook + - airflow.providers.microsoft.azure.hooks.azure_data_factory.AzureDataFactoryHook diff --git a/docs/integration-logos/azure/Azure Data Factory.svg b/docs/integration-logos/azure/Azure Data Factory.svg new file mode 100644 index 0000000..481d3d4 --- /dev/null +++ b/docs/integration-logos/azure/Azure Data Factory.svg @@ -0,0 +1 @@ +<svg id="f9ed9690-6753-43a7-8b32-d66ac7b8a99a" xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 18 18"><defs><linearGradient id="f710a364-083f-494c-9d96-89b92ee2d5a8" x1="0.5" y1="9.77" x2="9" y2="9.77" gradientUnits="userSpaceOnUse"><stop offset="0" 
stop-color="#005ba1" /><stop offset="0.07" stop-color="#0060a9" /><stop offset="0.36" stop-color="#0071c8" /><stop offset="0.52" stop-color="#0078d4" /><stop offset="0.64" stop-color="#0074cd" /><stop offset="0.81" stop [...] diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 0e89285..238021e 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -1062,6 +1062,7 @@ png podName podSpec podspec +poller polyfill postMessage postfix diff --git a/setup.py b/setup.py index 4ee7a5c..0846ec9 100644 --- a/setup.py +++ b/setup.py @@ -217,6 +217,7 @@ azure = [ 'azure-keyvault>=4.1.0', 'azure-kusto-data>=0.0.43,<0.1', 'azure-mgmt-containerinstance>=1.5.0,<2.0', + 'azure-mgmt-datafactory>=0.13.0', 'azure-mgmt-datalake-store>=0.5.0', 'azure-mgmt-resource>=2.2.0', 'azure-storage-blob>=12.7.0', diff --git a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py new file mode 100644 index 0000000..ea445ec --- /dev/null +++ b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py @@ -0,0 +1,439 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# pylint: disable=redefined-outer-name,unused-argument + +import json +from unittest.mock import MagicMock, Mock + +import pytest +from pytest import fixture + +from airflow.exceptions import AirflowException +from airflow.models.connection import Connection +from airflow.providers.microsoft.azure.hooks.azure_data_factory import ( + AzureDataFactoryHook, + provide_targeted_factory, +) +from airflow.utils import db + +DEFAULT_RESOURCE_GROUP = "defaultResourceGroup" +RESOURCE_GROUP = "testResourceGroup" + +DEFAULT_FACTORY = "defaultFactory" +FACTORY = "testFactory" + +MODEL = object() +NAME = "testName" +ID = "testId" + + +def setup_module(): + connection = Connection( + conn_id="azure_data_factory_test", + conn_type="azure_data_factory", + login="clientId", + password="clientSecret", + extra=json.dumps( + { + "tenantId": "tenantId", + "subscriptionId": "subscriptionId", + "resourceGroup": DEFAULT_RESOURCE_GROUP, + "factory": DEFAULT_FACTORY, + } + ), + ) + + db.merge_conn(connection) + + +@fixture +def hook(): + client = AzureDataFactoryHook(conn_id="azure_data_factory_test") + client._conn = MagicMock( + spec=[ + "factories", + "linked_services", + "datasets", + "pipelines", + "pipeline_runs", + "triggers", + "trigger_runs", + ] + ) + + return client + + +def parametrize(explicit_factory, implicit_factory): + def wrapper(func): + return pytest.mark.parametrize( + ("user_args", "sdk_args"), + (explicit_factory, implicit_factory), + ids=("explicit factory", "implicit factory"), + )(func) + + return wrapper + + +def test_provide_targeted_factory(): + def echo(_, resource_group_name=None, factory_name=None): + return resource_group_name, factory_name + + conn = MagicMock() + hook = MagicMock() + hook.get_connection.return_value = conn + + conn.extra_dejson = {} + assert provide_targeted_factory(echo)(hook, RESOURCE_GROUP, FACTORY) == (RESOURCE_GROUP, FACTORY) + + conn.extra_dejson = {"resourceGroup": DEFAULT_RESOURCE_GROUP, "factory": DEFAULT_FACTORY} + assert 
provide_targeted_factory(echo)(hook) == (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY) + + with pytest.raises(AirflowException): + conn.extra_dejson = {} + provide_targeted_factory(echo)(hook) + + +@parametrize( + explicit_factory=((RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY)), + implicit_factory=((), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY)), +) +def test_get_factory(hook: AzureDataFactoryHook, user_args, sdk_args): + hook.get_factory(*user_args) + + hook._conn.factories.get.assert_called_with(*sdk_args) + + +@parametrize( + explicit_factory=((MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, MODEL)), + implicit_factory=((MODEL,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, MODEL)), +) +def test_create_factory(hook: AzureDataFactoryHook, user_args, sdk_args): + hook.create_factory(*user_args) + + hook._conn.factories.create_or_update.assert_called_with(*sdk_args) + + +@parametrize( + explicit_factory=((MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, MODEL)), + implicit_factory=((MODEL,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, MODEL)), +) +def test_update_factory(hook: AzureDataFactoryHook, user_args, sdk_args): + hook._factory_exists = Mock(return_value=True) + hook.update_factory(*user_args) + + hook._conn.factories.create_or_update.assert_called_with(*sdk_args) + + +@parametrize( + explicit_factory=((MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, MODEL)), + implicit_factory=((MODEL,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, MODEL)), +) +def test_update_factory_non_existent(hook: AzureDataFactoryHook, user_args, sdk_args): + hook._factory_exists = Mock(return_value=False) + + with pytest.raises(AirflowException, match=r"Factory .+ does not exist"): + hook.update_factory(*user_args) + + +@parametrize( + explicit_factory=((RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY)), + implicit_factory=((), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY)), +) +def test_delete_factory(hook: AzureDataFactoryHook, user_args, sdk_args): + 
 hook.delete_factory(*user_args) + + hook._conn.factories.delete.assert_called_with(*sdk_args) + + +@parametrize( + explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)), + implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)), +) +def test_get_linked_service(hook: AzureDataFactoryHook, user_args, sdk_args): + hook.get_linked_service(*user_args) + + hook._conn.linked_services.get.assert_called_with(*sdk_args) + + +@parametrize( + explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)), + implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)), +) +def test_create_linked_service(hook: AzureDataFactoryHook, user_args, sdk_args): + hook.create_linked_service(*user_args) + + hook._conn.linked_services.create_or_update.assert_called_with(*sdk_args) + + +@parametrize( + explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)), + implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)), +) +def test_update_linked_service(hook: AzureDataFactoryHook, user_args, sdk_args): + hook._linked_service_exists = Mock(return_value=True) + hook.update_linked_service(*user_args) + + hook._conn.linked_services.create_or_update.assert_called_with(*sdk_args) + + +@parametrize( + explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)), + implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)), +) +def test_update_linked_service_non_existent(hook: AzureDataFactoryHook, user_args, sdk_args): + hook._linked_service_exists = Mock(return_value=False) + + with pytest.raises(AirflowException, match=r"Linked service .+ does not exist"): + hook.update_linked_service(*user_args) + + +@parametrize( + explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)), + implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)), +) +def
test_delete_linked_service(hook: AzureDataFactoryHook, user_args, sdk_args): + hook.delete_linked_service(*user_args) + + hook._conn.linked_services.delete.assert_called_with(*sdk_args) + + +@parametrize( + explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)), + implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)), +) +def test_get_dataset(hook: AzureDataFactoryHook, user_args, sdk_args): + hook.get_dataset(*user_args) + + hook._conn.datasets.get.assert_called_with(*sdk_args) + + +@parametrize( + explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)), + implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)), +) +def test_create_dataset(hook: AzureDataFactoryHook, user_args, sdk_args): + hook.create_dataset(*user_args) + + hook._conn.datasets.create_or_update.assert_called_with(*sdk_args) + + +@parametrize( + explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)), + implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)), +) +def test_update_dataset(hook: AzureDataFactoryHook, user_args, sdk_args): + hook._dataset_exists = Mock(return_value=True) + hook.update_dataset(*user_args) + + hook._conn.datasets.create_or_update.assert_called_with(*sdk_args) + + +@parametrize( + explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)), + implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)), +) +def test_update_dataset_non_existent(hook: AzureDataFactoryHook, user_args, sdk_args): + hook._dataset_exists = Mock(return_value=False) + + with pytest.raises(AirflowException, match=r"Dataset .+ does not exist"): + hook.update_dataset(*user_args) + + +@parametrize( + explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)), + implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, 
DEFAULT_FACTORY, NAME)), +) +def test_delete_dataset(hook: AzureDataFactoryHook, user_args, sdk_args): + hook.delete_dataset(*user_args) + + hook._conn.datasets.delete.assert_called_with(*sdk_args) + + +@parametrize( + explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)), + implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)), +) +def test_get_pipeline(hook: AzureDataFactoryHook, user_args, sdk_args): + hook.get_pipeline(*user_args) + + hook._conn.pipelines.get.assert_called_with(*sdk_args) + + +@parametrize( + explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)), + implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)), +) +def test_create_pipeline(hook: AzureDataFactoryHook, user_args, sdk_args): + hook.create_pipeline(*user_args) + + hook._conn.pipelines.create_or_update.assert_called_with(*sdk_args) + + +@parametrize( + explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)), + implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)), +) +def test_update_pipeline(hook: AzureDataFactoryHook, user_args, sdk_args): + hook._pipeline_exists = Mock(return_value=True) + hook.update_pipeline(*user_args) + + hook._conn.pipelines.create_or_update.assert_called_with(*sdk_args) + + +@parametrize( + explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)), + implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)), +) +def test_update_pipeline_non_existent(hook: AzureDataFactoryHook, user_args, sdk_args): + hook._pipeline_exists = Mock(return_value=False) + + with pytest.raises(AirflowException, match=r"Pipeline .+ does not exist"): + hook.update_pipeline(*user_args) + + +@parametrize( + explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)), + implicit_factory=((NAME,), 
(DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)), +) +def test_delete_pipeline(hook: AzureDataFactoryHook, user_args, sdk_args): + hook.delete_pipeline(*user_args) + + hook._conn.pipelines.delete.assert_called_with(*sdk_args) + + +@parametrize( + explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)), + implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)), +) +def test_run_pipeline(hook: AzureDataFactoryHook, user_args, sdk_args): + hook.run_pipeline(*user_args) + + hook._conn.pipelines.create_run.assert_called_with(*sdk_args) + + +@parametrize( + explicit_factory=((ID, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, ID)), + implicit_factory=((ID,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, ID)), +) +def test_get_pipeline_run(hook: AzureDataFactoryHook, user_args, sdk_args): + hook.get_pipeline_run(*user_args) + + hook._conn.pipeline_runs.get.assert_called_with(*sdk_args) + + +@parametrize( + explicit_factory=((ID, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, ID)), + implicit_factory=((ID,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, ID)), +) +def test_cancel_pipeline_run(hook: AzureDataFactoryHook, user_args, sdk_args): + hook.cancel_pipeline_run(*user_args) + + hook._conn.pipeline_runs.cancel.assert_called_with(*sdk_args) + + +@parametrize( + explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)), + implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)), +) +def test_get_trigger(hook: AzureDataFactoryHook, user_args, sdk_args): + hook.get_trigger(*user_args) + + hook._conn.triggers.get.assert_called_with(*sdk_args) + + +@parametrize( + explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)), + implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)), +) +def test_create_trigger(hook: AzureDataFactoryHook, user_args, sdk_args): + hook.create_trigger(*user_args) + + 
hook._conn.triggers.create_or_update.assert_called_with(*sdk_args) + + +@parametrize( + explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)), + implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)), +) +def test_update_trigger(hook: AzureDataFactoryHook, user_args, sdk_args): + hook._trigger_exists = Mock(return_value=True) + hook.update_trigger(*user_args) + + hook._conn.triggers.create_or_update.assert_called_with(*sdk_args) + + +@parametrize( + explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)), + implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)), +) +def test_update_trigger_non_existent(hook: AzureDataFactoryHook, user_args, sdk_args): + hook._trigger_exists = Mock(return_value=False) + + with pytest.raises(AirflowException, match=r"Trigger .+ does not exist"): + hook.update_trigger(*user_args) + + +@parametrize( + explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)), + implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)), +) +def test_delete_trigger(hook: AzureDataFactoryHook, user_args, sdk_args): + hook.delete_trigger(*user_args) + + hook._conn.triggers.delete.assert_called_with(*sdk_args) + + +@parametrize( + explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)), + implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)), +) +def test_start_trigger(hook: AzureDataFactoryHook, user_args, sdk_args): + hook.start_trigger(*user_args) + + hook._conn.triggers.start.assert_called_with(*sdk_args) + + +@parametrize( + explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)), + implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)), +) +def test_stop_trigger(hook: AzureDataFactoryHook, user_args, sdk_args): + hook.stop_trigger(*user_args) + + 
hook._conn.triggers.stop.assert_called_with(*sdk_args) + + +@parametrize( + explicit_factory=((NAME, ID, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, ID)), + implicit_factory=((NAME, ID), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, ID)), +) +def test_rerun_trigger(hook: AzureDataFactoryHook, user_args, sdk_args): + hook.rerun_trigger(*user_args) + + hook._conn.trigger_runs.rerun.assert_called_with(*sdk_args) + + +@parametrize( + explicit_factory=((NAME, ID, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, ID)), + implicit_factory=((NAME, ID), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, ID)), +) +def test_cancel_trigger(hook: AzureDataFactoryHook, user_args, sdk_args): + hook.cancel_trigger(*user_args) + + hook._conn.trigger_runs.cancel.assert_called_with(*sdk_args)
