This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 1fc286791f Upgrade azure-mgmt-datafactory in microsoft azure provider
(#34040)
1fc286791f is described below
commit 1fc286791f54e4b8ac8349c5b93456dd65e62d98
Author: Pankaj Singh <[email protected]>
AuthorDate: Fri Oct 6 13:20:53 2023 +0530
Upgrade azure-mgmt-datafactory in microsoft azure provider (#34040)
* Upgrade azure-mgmt-datafactory in microsoft azure provider
---
airflow/providers/microsoft/azure/CHANGELOG.rst | 18 +
.../microsoft/azure/hooks/data_factory.py | 275 +++++++--------
.../microsoft/azure/operators/data_factory.py | 26 +-
airflow/providers/microsoft/azure/provider.yaml | 3 +-
.../microsoft/azure/sensors/data_factory.py | 4 +-
.../microsoft/azure/triggers/data_factory.py | 8 +-
docs/spelling_wordlist.txt | 1 +
generated/provider_dependencies.json | 2 +-
.../azure/hooks/test_azure_data_factory.py | 382 +++++++--------------
.../azure/operators/test_azure_data_factory.py | 31 +-
.../azure/sensors/test_azure_data_factory.py | 6 +-
11 files changed, 305 insertions(+), 451 deletions(-)
diff --git a/airflow/providers/microsoft/azure/CHANGELOG.rst
b/airflow/providers/microsoft/azure/CHANGELOG.rst
index f2bdd9f6de..b355c60f81 100644
--- a/airflow/providers/microsoft/azure/CHANGELOG.rst
+++ b/airflow/providers/microsoft/azure/CHANGELOG.rst
@@ -27,6 +27,24 @@
Changelog
---------
+8.0.0
+.....
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+.. warning::
+ AzureDataFactoryHook methods and AzureDataFactoryRunPipelineOperator
arguments resource_group_name and factory_name is
+ now required instead of kwargs
+
+* resource_group_name and factory_name is now required argument in
AzureDataFactoryHook method get_factory, update_factory,
+ create_factory, delete_factory, get_linked_service, delete_linked_service,
get_dataset, delete_dataset, get_dataflow,
+ update_dataflow, create_dataflow, delete_dataflow, get_pipeline,
delete_pipeline, run_pipeline, get_pipeline_run,
+ get_trigger, get_pipeline_run_status, cancel_pipeline_run, create_trigger,
delete_trigger, start_trigger,
+ stop_trigger, get_adf_pipeline_run_status, cancel_pipeline_run
+* resource_group_name and factory_name is now required in
AzureDataFactoryRunPipelineOperator
+* Remove class ``PipelineRunInfo`` from
``airflow.providers.microsoft.azure.hooks.data_factory``
+
7.0.0
.....
diff --git a/airflow/providers/microsoft/azure/hooks/data_factory.py
b/airflow/providers/microsoft/azure/hooks/data_factory.py
index b4516ccfd7..656aa54a6a 100644
--- a/airflow/providers/microsoft/azure/hooks/data_factory.py
+++ b/airflow/providers/microsoft/azure/hooks/data_factory.py
@@ -26,6 +26,7 @@
TriggerResource
datafactory
DataFlow
+ DataFlowResource
mgmt
"""
from __future__ import annotations
@@ -34,10 +35,9 @@ import inspect
import time
import warnings
from functools import wraps
-from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, cast
+from typing import IO, TYPE_CHECKING, Any, Callable, TypeVar, Union, cast
from asgiref.sync import sync_to_async
-from azure.core.exceptions import ServiceRequestError
from azure.identity import ClientSecretCredential, DefaultAzureCredential
from azure.identity.aio import (
ClientSecretCredential as AsyncClientSecretCredential,
@@ -48,13 +48,12 @@ from azure.mgmt.datafactory.aio import
DataFactoryManagementClient as AsyncDataF
from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning
from airflow.hooks.base import BaseHook
-from airflow.typing_compat import TypedDict
if TYPE_CHECKING:
from azure.core.polling import LROPoller
from azure.mgmt.datafactory.models import (
CreateRunResponse,
- DataFlow,
+ DataFlowResource,
DatasetResource,
Factory,
LinkedServiceResource,
@@ -88,15 +87,9 @@ def provide_targeted_factory(func: Callable) -> Callable:
self = args[0]
conn = self.get_connection(self.conn_id)
extras = conn.extra_dejson
- default_value = extras.get(default_key)
- if not default_value and
extras.get(f"extra__azure_data_factory__{default_key}"):
- warnings.warn(
- f"`extra__azure_data_factory__{default_key}` is
deprecated in azure connection extra,"
- f" please use `{default_key}` instead",
- AirflowProviderDeprecationWarning,
- stacklevel=2,
- )
- default_value =
extras.get(f"extra__azure_data_factory__{default_key}")
+ default_value = extras.get(default_key) or extras.get(
+ f"extra__azure_data_factory__{default_key}"
+ )
if not default_value:
raise AirflowException("Could not determine the targeted
data factory.")
@@ -110,14 +103,6 @@ def provide_targeted_factory(func: Callable) -> Callable:
return wrapper
-class PipelineRunInfo(TypedDict):
- """Type class for the pipeline run info dictionary."""
-
- run_id: str
- factory_name: str | None
- resource_group_name: str | None
-
-
class AzureDataFactoryPipelineRunStatus:
"""Azure Data Factory pipeline operation statuses."""
@@ -127,6 +112,7 @@ class AzureDataFactoryPipelineRunStatus:
FAILED = "Failed"
CANCELING = "Canceling"
CANCELLED = "Cancelled"
+
TERMINAL_STATUSES = {CANCELLED, FAILED, SUCCEEDED}
INTERMEDIATE_STATES = {QUEUED, IN_PROGRESS, CANCELING}
FAILURE_STATES = {FAILED, CANCELLED}
@@ -148,12 +134,6 @@ def get_field(extras: dict, field_name: str, strict: bool
= False):
return extras[field_name] or None
prefixed_name = f"{backcompat_prefix}{field_name}"
if prefixed_name in extras:
- warnings.warn(
- f"`{prefixed_name}` is deprecated in azure connection extra,"
- f" please use `{field_name}` instead",
- AirflowProviderDeprecationWarning,
- stacklevel=2,
- )
return extras[prefixed_name] or None
if strict:
raise KeyError(f"Field {field_name} not found in extras")
@@ -199,7 +179,7 @@ class AzureDataFactoryHook(BaseHook):
}
def __init__(self, azure_data_factory_conn_id: str = default_conn_name):
- self._conn: DataFactoryManagementClient = None
+ self._conn: DataFactoryManagementClient | None = None
self.conn_id = azure_data_factory_conn_id
super().__init__()
@@ -235,9 +215,7 @@ class AzureDataFactoryHook(BaseHook):
return self.get_conn()
@provide_targeted_factory
- def get_factory(
- self, resource_group_name: str | None = None, factory_name: str | None
= None, **config: Any
- ) -> Factory:
+ def get_factory(self, resource_group_name: str, factory_name: str,
**config: Any) -> Factory | None:
"""
Get the factory.
@@ -267,8 +245,9 @@ class AzureDataFactoryHook(BaseHook):
def update_factory(
self,
factory: Factory,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ resource_group_name: str,
+ factory_name: str,
+ if_match: str | None = None,
**config: Any,
) -> Factory:
"""
@@ -277,6 +256,8 @@ class AzureDataFactoryHook(BaseHook):
:param factory: The factory resource definition.
:param resource_group_name: The resource group name.
:param factory_name: The factory name.
+ :param if_match: ETag of the factory entity. Should only be specified
for update, for which it
+ should match existing entity or can be * for unconditional update.
Default value is None.
:param config: Extra parameters for the ADF client.
:raise AirflowException: If the factory does not exist.
:return: The factory.
@@ -285,15 +266,15 @@ class AzureDataFactoryHook(BaseHook):
raise AirflowException(f"Factory {factory!r} does not exist.")
return self.get_conn().factories.create_or_update(
- resource_group_name, factory_name, factory, **config
+ resource_group_name, factory_name, factory, if_match, **config
)
@provide_targeted_factory
def create_factory(
self,
factory: Factory,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ resource_group_name: str,
+ factory_name: str,
**config: Any,
) -> Factory:
"""
@@ -314,9 +295,7 @@ class AzureDataFactoryHook(BaseHook):
)
@provide_targeted_factory
- def delete_factory(
- self, resource_group_name: str | None = None, factory_name: str | None
= None, **config: Any
- ) -> None:
+ def delete_factory(self, resource_group_name: str, factory_name: str,
**config: Any) -> None:
"""
Delete the factory.
@@ -330,21 +309,25 @@ class AzureDataFactoryHook(BaseHook):
def get_linked_service(
self,
linked_service_name: str,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ resource_group_name: str,
+ factory_name: str,
+ if_none_match: str | None = None,
**config: Any,
- ) -> LinkedServiceResource:
+ ) -> LinkedServiceResource | None:
"""
Get the linked service.
:param linked_service_name: The linked service name.
:param resource_group_name: The resource group name.
:param factory_name: The factory name.
+ :param if_none_match: ETag of the linked service entity. Should only
be specified for get. If
+ the ETag matches the existing entity tag, or if * was provided, then
no content will be
+ returned. Default value is None.
:param config: Extra parameters for the ADF client.
:return: The linked service.
"""
return self.get_conn().linked_services.get(
- resource_group_name, factory_name, linked_service_name, **config
+ resource_group_name, factory_name, linked_service_name,
if_none_match, **config
)
def _linked_service_exists(self, resource_group_name, factory_name,
linked_service_name) -> bool:
@@ -363,8 +346,8 @@ class AzureDataFactoryHook(BaseHook):
self,
linked_service_name: str,
linked_service: LinkedServiceResource,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ resource_group_name: str,
+ factory_name: str,
**config: Any,
) -> LinkedServiceResource:
"""
@@ -390,8 +373,8 @@ class AzureDataFactoryHook(BaseHook):
self,
linked_service_name: str,
linked_service: LinkedServiceResource,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ resource_group_name: str,
+ factory_name: str,
**config: Any,
) -> LinkedServiceResource:
"""
@@ -416,8 +399,8 @@ class AzureDataFactoryHook(BaseHook):
def delete_linked_service(
self,
linked_service_name: str,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ resource_group_name: str,
+ factory_name: str,
**config: Any,
) -> None:
"""
@@ -436,10 +419,10 @@ class AzureDataFactoryHook(BaseHook):
def get_dataset(
self,
dataset_name: str,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ resource_group_name: str,
+ factory_name: str,
**config: Any,
- ) -> DatasetResource:
+ ) -> DatasetResource | None:
"""
Get the dataset.
@@ -465,8 +448,8 @@ class AzureDataFactoryHook(BaseHook):
self,
dataset_name: str,
dataset: DatasetResource,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ resource_group_name: str,
+ factory_name: str,
**config: Any,
) -> DatasetResource:
"""
@@ -492,8 +475,8 @@ class AzureDataFactoryHook(BaseHook):
self,
dataset_name: str,
dataset: DatasetResource,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ resource_group_name: str,
+ factory_name: str,
**config: Any,
) -> DatasetResource:
"""
@@ -518,8 +501,8 @@ class AzureDataFactoryHook(BaseHook):
def delete_dataset(
self,
dataset_name: str,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ resource_group_name: str,
+ factory_name: str,
**config: Any,
) -> None:
"""
@@ -536,26 +519,32 @@ class AzureDataFactoryHook(BaseHook):
def get_dataflow(
self,
dataflow_name: str,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ resource_group_name: str,
+ factory_name: str,
+ if_none_match: str | None = None,
**config: Any,
- ) -> DataFlow:
+ ) -> DataFlowResource:
"""
Get the dataflow.
:param dataflow_name: The dataflow name.
:param resource_group_name: The resource group name.
:param factory_name: The factory name.
+ :param if_none_match: ETag of the data flow entity. Should only be
specified for get. If the
+ ETag matches the existing entity tag, or if * was provided, then no
content will be returned.
+ Default value is None.
:param config: Extra parameters for the ADF client.
- :return: The dataflow.
+ :return: The DataFlowResource.
"""
- return self.get_conn().data_flows.get(resource_group_name,
factory_name, dataflow_name, **config)
+ return self.get_conn().data_flows.get(
+ resource_group_name, factory_name, dataflow_name, if_none_match,
**config
+ )
def _dataflow_exists(
self,
dataflow_name: str,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ resource_group_name: str,
+ factory_name: str,
) -> bool:
"""Return whether the dataflow already exists."""
dataflows = {
@@ -569,11 +558,12 @@ class AzureDataFactoryHook(BaseHook):
def update_dataflow(
self,
dataflow_name: str,
- dataflow: DataFlow,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ dataflow: DataFlowResource | IO,
+ resource_group_name: str,
+ factory_name: str,
+ if_match: str | None = None,
**config: Any,
- ) -> DataFlow:
+ ) -> DataFlowResource:
"""
Update the dataflow.
@@ -581,9 +571,11 @@ class AzureDataFactoryHook(BaseHook):
:param dataflow: The dataflow resource definition.
:param resource_group_name: The resource group name.
:param factory_name: The factory name.
+ :param if_match: ETag of the data flow entity. Should only be
specified for update, for which
+ it should match existing entity or can be * for unconditional update.
Default value is None.
:param config: Extra parameters for the ADF client.
:raise AirflowException: If the dataset does not exist.
- :return: The dataflow.
+ :return: DataFlowResource.
"""
if not self._dataflow_exists(
dataflow_name,
@@ -593,18 +585,19 @@ class AzureDataFactoryHook(BaseHook):
raise AirflowException(f"Dataflow {dataflow_name!r} does not
exist.")
return self.get_conn().data_flows.create_or_update(
- resource_group_name, factory_name, dataflow_name, dataflow,
**config
+ resource_group_name, factory_name, dataflow_name, dataflow,
if_match, **config
)
@provide_targeted_factory
def create_dataflow(
self,
dataflow_name: str,
- dataflow: DataFlow,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ dataflow: DataFlowResource,
+ resource_group_name: str,
+ factory_name: str,
+ if_match: str | None = None,
**config: Any,
- ) -> DataFlow:
+ ) -> DataFlowResource:
"""
Create the dataflow.
@@ -612,6 +605,8 @@ class AzureDataFactoryHook(BaseHook):
:param dataflow: The dataflow resource definition.
:param resource_group_name: The resource group name.
:param factory_name: The factory name.
+ :param if_match: ETag of the factory entity. Should only be specified
for update, for which it
+ should match existing entity or can be * for unconditional update.
Default value is None.
:param config: Extra parameters for the ADF client.
:raise AirflowException: If the dataset already exists.
:return: The dataset.
@@ -620,15 +615,15 @@ class AzureDataFactoryHook(BaseHook):
raise AirflowException(f"Dataflow {dataflow_name!r} already
exists.")
return self.get_conn().data_flows.create_or_update(
- resource_group_name, factory_name, dataflow_name, dataflow,
**config
+ resource_group_name, factory_name, dataflow_name, dataflow,
if_match, **config
)
@provide_targeted_factory
def delete_dataflow(
self,
dataflow_name: str,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ resource_group_name: str,
+ factory_name: str,
**config: Any,
) -> None:
"""
@@ -645,10 +640,10 @@ class AzureDataFactoryHook(BaseHook):
def get_pipeline(
self,
pipeline_name: str,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ resource_group_name: str,
+ factory_name: str,
**config: Any,
- ) -> PipelineResource:
+ ) -> PipelineResource | None:
"""
Get the pipeline.
@@ -674,8 +669,8 @@ class AzureDataFactoryHook(BaseHook):
self,
pipeline_name: str,
pipeline: PipelineResource,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ resource_group_name: str,
+ factory_name: str,
**config: Any,
) -> PipelineResource:
"""
@@ -701,8 +696,8 @@ class AzureDataFactoryHook(BaseHook):
self,
pipeline_name: str,
pipeline: PipelineResource,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ resource_group_name: str,
+ factory_name: str,
**config: Any,
) -> PipelineResource:
"""
@@ -727,8 +722,8 @@ class AzureDataFactoryHook(BaseHook):
def delete_pipeline(
self,
pipeline_name: str,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ resource_group_name: str,
+ factory_name: str,
**config: Any,
) -> None:
"""
@@ -745,8 +740,8 @@ class AzureDataFactoryHook(BaseHook):
def run_pipeline(
self,
pipeline_name: str,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ resource_group_name: str,
+ factory_name: str,
**config: Any,
) -> CreateRunResponse:
"""
@@ -766,8 +761,8 @@ class AzureDataFactoryHook(BaseHook):
def get_pipeline_run(
self,
run_id: str,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ resource_group_name: str,
+ factory_name: str,
**config: Any,
) -> PipelineRun:
"""
@@ -784,8 +779,8 @@ class AzureDataFactoryHook(BaseHook):
def get_pipeline_run_status(
self,
run_id: str,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ resource_group_name: str,
+ factory_name: str,
) -> str:
"""
Get a pipeline run's current status.
@@ -796,11 +791,7 @@ class AzureDataFactoryHook(BaseHook):
:return: The status of the pipeline run.
"""
self.log.info("Getting the status of run ID %s.", run_id)
- pipeline_run_status = self.get_pipeline_run(
- run_id=run_id,
- factory_name=factory_name,
- resource_group_name=resource_group_name,
- ).status
+ pipeline_run_status = self.get_pipeline_run(run_id,
resource_group_name, factory_name).status
self.log.info("Current status of pipeline run %s: %s", run_id,
pipeline_run_status)
return pipeline_run_status
@@ -809,8 +800,8 @@ class AzureDataFactoryHook(BaseHook):
self,
run_id: str,
expected_statuses: str | set[str],
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ resource_group_name: str,
+ factory_name: str,
check_interval: int = 60,
timeout: int = 60 * 60 * 24 * 7,
) -> bool:
@@ -826,13 +817,7 @@ class AzureDataFactoryHook(BaseHook):
status.
:return: Boolean indicating if the pipeline run has reached the
``expected_status``.
"""
- pipeline_run_info = PipelineRunInfo(
- run_id=run_id,
- factory_name=factory_name,
- resource_group_name=resource_group_name,
- )
- pipeline_run_status = self.get_pipeline_run_status(**pipeline_run_info)
- executed_after_token_refresh = True
+ pipeline_run_status = self.get_pipeline_run_status(run_id,
resource_group_name, factory_name)
start_time = time.monotonic()
@@ -849,14 +834,7 @@ class AzureDataFactoryHook(BaseHook):
# Wait to check the status of the pipeline run based on the
``check_interval`` configured.
time.sleep(check_interval)
- try:
- pipeline_run_status =
self.get_pipeline_run_status(**pipeline_run_info)
- executed_after_token_refresh = True
- except ServiceRequestError:
- if executed_after_token_refresh:
- self.refresh_conn()
- else:
- raise
+ pipeline_run_status = self.get_pipeline_run_status(run_id,
resource_group_name, factory_name)
return pipeline_run_status in expected_statuses
@@ -864,8 +842,8 @@ class AzureDataFactoryHook(BaseHook):
def cancel_pipeline_run(
self,
run_id: str,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ resource_group_name: str,
+ factory_name: str,
**config: Any,
) -> None:
"""
@@ -882,10 +860,10 @@ class AzureDataFactoryHook(BaseHook):
def get_trigger(
self,
trigger_name: str,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ resource_group_name: str,
+ factory_name: str,
**config: Any,
- ) -> TriggerResource:
+ ) -> TriggerResource | None:
"""
Get the trigger.
@@ -911,8 +889,9 @@ class AzureDataFactoryHook(BaseHook):
self,
trigger_name: str,
trigger: TriggerResource,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ resource_group_name: str,
+ factory_name: str,
+ if_match: str | None = None,
**config: Any,
) -> TriggerResource:
"""
@@ -922,6 +901,8 @@ class AzureDataFactoryHook(BaseHook):
:param trigger: The trigger resource definition.
:param resource_group_name: The resource group name.
:param factory_name: The factory name.
+ :param if_match: ETag of the trigger entity. Should only be specified
for update, for which it
+ should match existing entity or can be * for unconditional update.
Default value is None.
:param config: Extra parameters for the ADF client.
:raise AirflowException: If the trigger does not exist.
:return: The trigger.
@@ -930,7 +911,7 @@ class AzureDataFactoryHook(BaseHook):
raise AirflowException(f"Trigger {trigger_name!r} does not exist.")
return self.get_conn().triggers.create_or_update(
- resource_group_name, factory_name, trigger_name, trigger, **config
+ resource_group_name, factory_name, trigger_name, trigger,
if_match, **config
)
@provide_targeted_factory
@@ -938,8 +919,8 @@ class AzureDataFactoryHook(BaseHook):
self,
trigger_name: str,
trigger: TriggerResource,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ resource_group_name: str,
+ factory_name: str,
**config: Any,
) -> TriggerResource:
"""
@@ -964,8 +945,8 @@ class AzureDataFactoryHook(BaseHook):
def delete_trigger(
self,
trigger_name: str,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ resource_group_name: str,
+ factory_name: str,
**config: Any,
) -> None:
"""
@@ -982,8 +963,8 @@ class AzureDataFactoryHook(BaseHook):
def start_trigger(
self,
trigger_name: str,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ resource_group_name: str,
+ factory_name: str,
**config: Any,
) -> LROPoller:
"""
@@ -1001,8 +982,8 @@ class AzureDataFactoryHook(BaseHook):
def stop_trigger(
self,
trigger_name: str,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ resource_group_name: str,
+ factory_name: str,
**config: Any,
) -> LROPoller:
"""
@@ -1021,8 +1002,8 @@ class AzureDataFactoryHook(BaseHook):
self,
trigger_name: str,
run_id: str,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ resource_group_name: str,
+ factory_name: str,
**config: Any,
) -> None:
"""
@@ -1043,8 +1024,8 @@ class AzureDataFactoryHook(BaseHook):
self,
trigger_name: str,
run_id: str,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ resource_group_name: str,
+ factory_name: str,
**config: Any,
) -> None:
"""
@@ -1068,7 +1049,7 @@ class AzureDataFactoryHook(BaseHook):
# DataFactoryManagementClient with incorrect values but then will
fail properly once items are
# retrieved using the client. We need to _actually_ try to
retrieve an object to properly test the
# connection.
- next(self.get_conn().factories.list())
+ self.get_conn().factories.list()
return success
except StopIteration:
# If the iterator returned is empty it should still be considered
a successful connection since
@@ -1132,7 +1113,7 @@ class AzureDataFactoryAsyncHook(AzureDataFactoryHook):
default_conn_name: str = "azure_data_factory_default"
def __init__(self, azure_data_factory_conn_id: str = default_conn_name):
- self._async_conn: AsyncDataFactoryManagementClient = None
+ self._async_conn: AsyncDataFactoryManagementClient | None = None
self.conn_id = azure_data_factory_conn_id
super().__init__(azure_data_factory_conn_id=azure_data_factory_conn_id)
@@ -1168,7 +1149,7 @@ class AzureDataFactoryAsyncHook(AzureDataFactoryHook):
return self._async_conn
- async def refresh_conn(self) -> AsyncDataFactoryManagementClient:
+ async def refresh_conn(self) -> AsyncDataFactoryManagementClient: # type:
ignore[override]
self._conn = None
return await self.get_async_conn()
@@ -1176,8 +1157,8 @@ class AzureDataFactoryAsyncHook(AzureDataFactoryHook):
async def get_pipeline_run(
self,
run_id: str,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ resource_group_name: str,
+ factory_name: str,
**config: Any,
) -> PipelineRun:
"""
@@ -1193,7 +1174,7 @@ class AzureDataFactoryAsyncHook(AzureDataFactoryHook):
return pipeline_run
async def get_adf_pipeline_run_status(
- self, run_id: str, resource_group_name: str | None = None,
factory_name: str | None = None
+ self, run_id: str, resource_group_name: str, factory_name: str
) -> str:
"""
Connect to Azure Data Factory asynchronously and get the pipeline
status by run_id.
@@ -1202,20 +1183,16 @@ class AzureDataFactoryAsyncHook(AzureDataFactoryHook):
:param resource_group_name: The resource group name.
:param factory_name: The factory name.
"""
- pipeline_run = await self.get_pipeline_run(
- run_id=run_id,
- factory_name=factory_name,
- resource_group_name=resource_group_name,
- )
- status: str = pipeline_run.status
+ pipeline_run = await self.get_pipeline_run(run_id,
resource_group_name, factory_name)
+ status: str = cast(str, pipeline_run.status)
return status
@provide_targeted_factory_async
async def cancel_pipeline_run(
self,
run_id: str,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ resource_group_name: str,
+ factory_name: str,
**config: Any,
) -> None:
"""
diff --git a/airflow/providers/microsoft/azure/operators/data_factory.py
b/airflow/providers/microsoft/azure/operators/data_factory.py
index 12962e5610..2aac723f86 100644
--- a/airflow/providers/microsoft/azure/operators/data_factory.py
+++ b/airflow/providers/microsoft/azure/operators/data_factory.py
@@ -29,7 +29,6 @@ from airflow.providers.microsoft.azure.hooks.data_factory
import (
AzureDataFactoryHook,
AzureDataFactoryPipelineRunException,
AzureDataFactoryPipelineRunStatus,
- PipelineRunInfo,
get_field,
)
from airflow.providers.microsoft.azure.triggers.data_factory import
AzureDataFactoryTrigger
@@ -132,9 +131,9 @@ class AzureDataFactoryRunPipelineOperator(BaseOperator):
*,
pipeline_name: str,
azure_data_factory_conn_id: str =
AzureDataFactoryHook.default_conn_name,
+ resource_group_name: str,
+ factory_name: str,
wait_for_termination: bool = True,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
reference_pipeline_run_id: str | None = None,
is_recovery: bool | None = None,
start_activity_name: str | None = None,
@@ -168,9 +167,9 @@ class AzureDataFactoryRunPipelineOperator(BaseOperator):
def execute(self, context: Context) -> None:
self.log.info("Executing the %s pipeline.", self.pipeline_name)
response = self.hook.run_pipeline(
- pipeline_name=self.pipeline_name,
- resource_group_name=self.resource_group_name,
- factory_name=self.factory_name,
+ self.pipeline_name,
+ self.resource_group_name,
+ self.factory_name,
reference_pipeline_run_id=self.reference_pipeline_run_id,
is_recovery=self.is_recovery,
start_activity_name=self.start_activity_name,
@@ -188,12 +187,12 @@ class AzureDataFactoryRunPipelineOperator(BaseOperator):
self.log.info("Waiting for pipeline run %s to terminate.",
self.run_id)
if self.hook.wait_for_pipeline_run_status(
- run_id=self.run_id,
-
expected_statuses=AzureDataFactoryPipelineRunStatus.SUCCEEDED,
+ self.run_id,
+ AzureDataFactoryPipelineRunStatus.SUCCEEDED,
+ self.resource_group_name,
+ self.factory_name,
check_interval=self.check_interval,
timeout=self.timeout,
- resource_group_name=self.resource_group_name,
- factory_name=self.factory_name,
):
self.log.info("Pipeline run %s has completed
successfully.", self.run_id)
else:
@@ -202,12 +201,9 @@ class AzureDataFactoryRunPipelineOperator(BaseOperator):
)
else:
end_time = time.time() + self.timeout
- pipeline_run_info = PipelineRunInfo(
- run_id=self.run_id,
- factory_name=self.factory_name,
- resource_group_name=self.resource_group_name,
+ pipeline_run_status = self.hook.get_pipeline_run_status(
+ self.run_id, self.resource_group_name, self.factory_name
)
- pipeline_run_status =
self.hook.get_pipeline_run_status(**pipeline_run_info)
if pipeline_run_status not in
AzureDataFactoryPipelineRunStatus.TERMINAL_STATUSES:
self.defer(
timeout=self.execution_timeout,
diff --git a/airflow/providers/microsoft/azure/provider.yaml
b/airflow/providers/microsoft/azure/provider.yaml
index e2822f3822..d01ee268ef 100644
--- a/airflow/providers/microsoft/azure/provider.yaml
+++ b/airflow/providers/microsoft/azure/provider.yaml
@@ -21,6 +21,7 @@ description: |
`Microsoft Azure <https://azure.microsoft.com/>`__
suspended: false
versions:
+ - 8.0.0
- 7.0.0
- 6.3.0
- 6.2.4
@@ -81,11 +82,11 @@ dependencies:
- adal>=1.2.7
- azure-storage-file-datalake>=12.9.1
- azure-kusto-data>=4.1.0
+ - azure-mgmt-datafactory>=2.0.0
- azure-mgmt-containerregistry>=8.0.0
# TODO: upgrade to newer versions of all the below libraries.
# See issue https://github.com/apache/airflow/issues/30199
- azure-mgmt-containerinstance>=7.0.0,<9.0.0
- - azure-mgmt-datafactory>=1.0.0,<2.0
integrations:
- integration-name: Microsoft Azure Batch
diff --git a/airflow/providers/microsoft/azure/sensors/data_factory.py
b/airflow/providers/microsoft/azure/sensors/data_factory.py
index 5cede76ad9..f5ca765e85 100644
--- a/airflow/providers/microsoft/azure/sensors/data_factory.py
+++ b/airflow/providers/microsoft/azure/sensors/data_factory.py
@@ -59,8 +59,8 @@ class
AzureDataFactoryPipelineRunStatusSensor(BaseSensorOperator):
*,
run_id: str,
azure_data_factory_conn_id: str =
AzureDataFactoryHook.default_conn_name,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ resource_group_name: str,
+ factory_name: str,
deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
**kwargs,
) -> None:
diff --git a/airflow/providers/microsoft/azure/triggers/data_factory.py
b/airflow/providers/microsoft/azure/triggers/data_factory.py
index e087e3556d..b550c2b1d1 100644
--- a/airflow/providers/microsoft/azure/triggers/data_factory.py
+++ b/airflow/providers/microsoft/azure/triggers/data_factory.py
@@ -44,8 +44,8 @@ class ADFPipelineRunStatusSensorTrigger(BaseTrigger):
run_id: str,
azure_data_factory_conn_id: str,
poke_interval: float,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ resource_group_name: str,
+ factory_name: str,
):
super().__init__()
self.run_id = run_id
@@ -128,8 +128,8 @@ class AzureDataFactoryTrigger(BaseTrigger):
run_id: str,
azure_data_factory_conn_id: str,
end_time: float,
- resource_group_name: str | None = None,
- factory_name: str | None = None,
+ resource_group_name: str,
+ factory_name: str,
wait_for_termination: bool = True,
check_interval: int = 60,
):
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 549492a113..14d7d52dd0 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -11,6 +11,7 @@ actionCard
Acyclic
acyclic
AddressesType
+adf
adhoc
adls
adobjects
diff --git a/generated/provider_dependencies.json
b/generated/provider_dependencies.json
index 028caf09ae..212f308989 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -558,7 +558,7 @@
"azure-mgmt-containerinstance>=7.0.0,<9.0.0",
"azure-mgmt-containerregistry>=8.0.0",
"azure-mgmt-cosmosdb",
- "azure-mgmt-datafactory>=1.0.0,<2.0",
+ "azure-mgmt-datafactory>=2.0.0",
"azure-mgmt-datalake-store>=0.5.0",
"azure-mgmt-resource>=2.2.0",
"azure-mgmt-storage>=16.0.0",
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
index 508bc2ee78..34d0bcb481 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
@@ -183,301 +183,186 @@ def
test_get_connection_by_credential_client_secret(connection_id: str, credenti
assert mock_create_client.call_args.args[1] == "subscriptionId"
-@parametrize(
- explicit_factory=((RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY)),
- implicit_factory=((), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY)),
-)
-def test_get_factory(hook: AzureDataFactoryHook, user_args, sdk_args):
- hook.get_factory(*user_args)
+def test_get_factory(hook: AzureDataFactoryHook):
+ hook.get_factory(RESOURCE_GROUP, FACTORY)
- hook._conn.factories.get.assert_called_with(*sdk_args)
+ hook._conn.factories.get.assert_called_with(RESOURCE_GROUP, FACTORY)
-@parametrize(
- explicit_factory=((MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP,
FACTORY, MODEL)),
- implicit_factory=((MODEL,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY,
MODEL)),
-)
-def test_create_factory(hook: AzureDataFactoryHook, user_args, sdk_args):
- hook.create_factory(*user_args)
+def test_create_factory(hook: AzureDataFactoryHook):
+ hook.create_factory(MODEL, RESOURCE_GROUP, FACTORY)
- hook._conn.factories.create_or_update.assert_called_with(*sdk_args)
+ hook._conn.factories.create_or_update.assert_called_with(RESOURCE_GROUP,
FACTORY, MODEL)
-@parametrize(
- explicit_factory=((MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP,
FACTORY, MODEL)),
- implicit_factory=((MODEL,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY,
MODEL)),
-)
-def test_update_factory(hook: AzureDataFactoryHook, user_args, sdk_args):
+def test_update_factory(hook: AzureDataFactoryHook):
with patch.object(hook, "_factory_exists") as mock_factory_exists:
mock_factory_exists.return_value = True
- hook.update_factory(*user_args)
+ hook.update_factory(MODEL, RESOURCE_GROUP, FACTORY)
- hook._conn.factories.create_or_update.assert_called_with(*sdk_args)
+ hook._conn.factories.create_or_update.assert_called_with(RESOURCE_GROUP,
FACTORY, MODEL, None)
-@parametrize(
- explicit_factory=((MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP,
FACTORY, MODEL)),
- implicit_factory=((MODEL,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY,
MODEL)),
-)
-def test_update_factory_non_existent(hook: AzureDataFactoryHook, user_args,
sdk_args):
+def test_update_factory_non_existent(hook: AzureDataFactoryHook):
with patch.object(hook, "_factory_exists") as mock_factory_exists:
mock_factory_exists.return_value = False
with pytest.raises(AirflowException, match=r"Factory .+ does not exist"):
- hook.update_factory(*user_args)
+ hook.update_factory(MODEL, RESOURCE_GROUP, FACTORY)
-@parametrize(
- explicit_factory=((RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY)),
- implicit_factory=((), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY)),
-)
-def test_delete_factory(hook: AzureDataFactoryHook, user_args, sdk_args):
- hook.delete_factory(*user_args)
+def test_delete_factory(hook: AzureDataFactoryHook):
+ hook.delete_factory(RESOURCE_GROUP, FACTORY)
- hook._conn.factories.delete.assert_called_with(*sdk_args)
+ hook._conn.factories.delete.assert_called_with(RESOURCE_GROUP, FACTORY)
-@parametrize(
- explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP,
FACTORY, NAME)),
- implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY,
NAME)),
-)
-def test_get_linked_service(hook: AzureDataFactoryHook, user_args, sdk_args):
- hook.get_linked_service(*user_args)
+def test_get_linked_service(hook: AzureDataFactoryHook):
+ hook.get_linked_service(NAME, RESOURCE_GROUP, FACTORY)
- hook._conn.linked_services.get.assert_called_with(*sdk_args)
+ hook._conn.linked_services.get.assert_called_with(RESOURCE_GROUP, FACTORY,
NAME, None)
-@parametrize(
- explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP,
FACTORY, NAME, MODEL)),
- implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY,
NAME, MODEL)),
-)
-def test_create_linked_service(hook: AzureDataFactoryHook, user_args,
sdk_args):
- hook.create_linked_service(*user_args)
+def test_create_linked_service(hook: AzureDataFactoryHook):
+ hook.create_linked_service(NAME, MODEL, RESOURCE_GROUP, FACTORY)
- hook._conn.linked_services.create_or_update(*sdk_args)
+ hook._conn.linked_services.create_or_update(RESOURCE_GROUP, FACTORY, NAME,
MODEL)
-@parametrize(
- explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP,
FACTORY, NAME, MODEL)),
- implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY,
NAME, MODEL)),
-)
-def test_update_linked_service(hook: AzureDataFactoryHook, user_args,
sdk_args):
+def test_update_linked_service(hook: AzureDataFactoryHook):
with patch.object(hook, "_linked_service_exists") as
mock_linked_service_exists:
mock_linked_service_exists.return_value = True
- hook.update_linked_service(*user_args)
+ hook.update_linked_service(NAME, MODEL, RESOURCE_GROUP, FACTORY)
- hook._conn.linked_services.create_or_update(*sdk_args)
+ hook._conn.linked_services.create_or_update(RESOURCE_GROUP, FACTORY, NAME,
MODEL)
-@parametrize(
- explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP,
FACTORY, NAME, MODEL)),
- implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY,
NAME, MODEL)),
-)
-def test_update_linked_service_non_existent(hook: AzureDataFactoryHook,
user_args, sdk_args):
+def test_update_linked_service_non_existent(hook: AzureDataFactoryHook):
with patch.object(hook, "_linked_service_exists") as
mock_linked_service_exists:
mock_linked_service_exists.return_value = False
with pytest.raises(AirflowException, match=r"Linked service .+ does not
exist"):
- hook.update_linked_service(*user_args)
+ hook.update_linked_service(NAME, MODEL, RESOURCE_GROUP, FACTORY)
-@parametrize(
- explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP,
FACTORY, NAME)),
- implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY,
NAME)),
-)
-def test_delete_linked_service(hook: AzureDataFactoryHook, user_args,
sdk_args):
- hook.delete_linked_service(*user_args)
+def test_delete_linked_service(hook: AzureDataFactoryHook):
+ hook.delete_linked_service(NAME, RESOURCE_GROUP, FACTORY)
- hook._conn.linked_services.delete.assert_called_with(*sdk_args)
+ hook._conn.linked_services.delete.assert_called_with(RESOURCE_GROUP,
FACTORY, NAME)
-@parametrize(
- explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP,
FACTORY, NAME)),
- implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY,
NAME)),
-)
-def test_get_dataset(hook: AzureDataFactoryHook, user_args, sdk_args):
- hook.get_dataset(*user_args)
+def test_get_dataset(hook: AzureDataFactoryHook):
+ hook.get_dataset(NAME, RESOURCE_GROUP, FACTORY)
- hook._conn.datasets.get.assert_called_with(*sdk_args)
+ hook._conn.datasets.get.assert_called_with(RESOURCE_GROUP, FACTORY, NAME)
-@parametrize(
- explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP,
FACTORY, NAME, MODEL)),
- implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY,
NAME, MODEL)),
-)
-def test_create_dataset(hook: AzureDataFactoryHook, user_args, sdk_args):
- hook.create_dataset(*user_args)
+def test_create_dataset(hook: AzureDataFactoryHook):
+ hook.create_dataset(NAME, MODEL, RESOURCE_GROUP, FACTORY)
- hook._conn.datasets.create_or_update.assert_called_with(*sdk_args)
+ hook._conn.datasets.create_or_update.assert_called_with(RESOURCE_GROUP,
FACTORY, NAME, MODEL)
-@parametrize(
- explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP,
FACTORY, NAME, MODEL)),
- implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY,
NAME, MODEL)),
-)
-def test_update_dataset(hook: AzureDataFactoryHook, user_args, sdk_args):
+def test_update_dataset(hook: AzureDataFactoryHook):
with patch.object(hook, "_dataset_exists") as mock_dataset_exists:
mock_dataset_exists.return_value = True
- hook.update_dataset(*user_args)
+ hook.update_dataset(NAME, MODEL, RESOURCE_GROUP, FACTORY)
- hook._conn.datasets.create_or_update.assert_called_with(*sdk_args)
+ hook._conn.datasets.create_or_update.assert_called_with(RESOURCE_GROUP,
FACTORY, NAME, MODEL)
-@parametrize(
- explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP,
FACTORY, NAME, MODEL)),
- implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY,
NAME, MODEL)),
-)
-def test_update_dataset_non_existent(hook: AzureDataFactoryHook, user_args,
sdk_args):
+def test_update_dataset_non_existent(hook: AzureDataFactoryHook):
with patch.object(hook, "_dataset_exists") as mock_dataset_exists:
mock_dataset_exists.return_value = False
with pytest.raises(AirflowException, match=r"Dataset .+ does not exist"):
- hook.update_dataset(*user_args)
+ hook.update_dataset(NAME, MODEL, RESOURCE_GROUP, FACTORY)
-@parametrize(
- explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP,
FACTORY, NAME)),
- implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY,
NAME)),
-)
-def test_delete_dataset(hook: AzureDataFactoryHook, user_args, sdk_args):
- hook.delete_dataset(*user_args)
+def test_delete_dataset(hook: AzureDataFactoryHook):
+ hook.delete_dataset(NAME, RESOURCE_GROUP, FACTORY)
- hook._conn.datasets.delete.assert_called_with(*sdk_args)
+ hook._conn.datasets.delete.assert_called_with(RESOURCE_GROUP, FACTORY,
NAME)
-@parametrize(
- explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP,
FACTORY, NAME)),
- implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY,
NAME)),
-)
-def test_get_dataflow(hook: AzureDataFactoryHook, user_args, sdk_args):
- hook.get_dataflow(*user_args)
+def test_get_dataflow(hook: AzureDataFactoryHook):
+ hook.get_dataflow(NAME, RESOURCE_GROUP, FACTORY)
- hook._conn.data_flows.get.assert_called_with(*sdk_args)
+ hook._conn.data_flows.get.assert_called_with(RESOURCE_GROUP, FACTORY,
NAME, None)
-@parametrize(
- explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP,
FACTORY, NAME, MODEL)),
- implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY,
NAME, MODEL)),
-)
-def test_create_dataflow(hook: AzureDataFactoryHook, user_args, sdk_args):
- hook.create_dataflow(*user_args)
+def test_create_dataflow(hook: AzureDataFactoryHook):
+ hook.create_dataflow(NAME, MODEL, RESOURCE_GROUP, FACTORY)
- hook._conn.data_flows.create_or_update.assert_called_with(*sdk_args)
+ hook._conn.data_flows.create_or_update.assert_called_with(RESOURCE_GROUP,
FACTORY, NAME, MODEL, None)
-@parametrize(
- explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP,
FACTORY, NAME, MODEL)),
- implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY,
NAME, MODEL)),
-)
-def test_update_dataflow(hook: AzureDataFactoryHook, user_args, sdk_args):
+def test_update_dataflow(hook: AzureDataFactoryHook):
with patch.object(hook, "_dataflow_exists") as mock_dataflow_exists:
mock_dataflow_exists.return_value = True
- hook.update_dataflow(*user_args)
+ hook.update_dataflow(NAME, MODEL, RESOURCE_GROUP, FACTORY)
- hook._conn.data_flows.create_or_update.assert_called_with(*sdk_args)
+ hook._conn.data_flows.create_or_update.assert_called_with(RESOURCE_GROUP,
FACTORY, NAME, MODEL, None)
-@parametrize(
- explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP,
FACTORY, NAME, MODEL)),
- implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY,
NAME, MODEL)),
-)
-def test_update_dataflow_non_existent(hook: AzureDataFactoryHook, user_args,
sdk_args):
+def test_update_dataflow_non_existent(hook: AzureDataFactoryHook):
with patch.object(hook, "_dataflow_exists") as mock_dataflow_exists:
mock_dataflow_exists.return_value = False
with pytest.raises(AirflowException, match=r"Dataflow .+ does not exist"):
- hook.update_dataflow(*user_args)
+ hook.update_dataflow(NAME, MODEL, RESOURCE_GROUP, FACTORY)
-@parametrize(
- explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP,
FACTORY, NAME)),
- implicit_factory=(
- (NAME,),
- (
- DEFAULT_RESOURCE_GROUP,
- DEFAULT_FACTORY,
- NAME,
- ),
- ),
-)
-def test_delete_dataflow(hook: AzureDataFactoryHook, user_args, sdk_args):
- hook.delete_dataflow(*user_args)
+def test_delete_dataflow(hook: AzureDataFactoryHook):
+ hook.delete_dataflow(NAME, RESOURCE_GROUP, FACTORY)
- hook._conn.data_flows.delete.assert_called_with(*sdk_args)
+ hook._conn.data_flows.delete.assert_called_with(RESOURCE_GROUP, FACTORY,
NAME)
-@parametrize(
- explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP,
FACTORY, NAME)),
- implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY,
NAME)),
-)
-def test_get_pipeline(hook: AzureDataFactoryHook, user_args, sdk_args):
- hook.get_pipeline(*user_args)
+def test_get_pipeline(hook: AzureDataFactoryHook):
+ hook.get_pipeline(NAME, RESOURCE_GROUP, FACTORY)
- hook._conn.pipelines.get.assert_called_with(*sdk_args)
+ hook._conn.pipelines.get.assert_called_with(RESOURCE_GROUP, FACTORY, NAME)
-@parametrize(
- explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP,
FACTORY, NAME, MODEL)),
- implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY,
NAME, MODEL)),
-)
-def test_create_pipeline(hook: AzureDataFactoryHook, user_args, sdk_args):
- hook.create_pipeline(*user_args)
+def test_create_pipeline(hook: AzureDataFactoryHook):
+ hook.create_pipeline(NAME, MODEL, RESOURCE_GROUP, FACTORY)
- hook._conn.pipelines.create_or_update.assert_called_with(*sdk_args)
+ hook._conn.pipelines.create_or_update.assert_called_with(RESOURCE_GROUP,
FACTORY, NAME, MODEL)
-@parametrize(
- explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP,
FACTORY, NAME, MODEL)),
- implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY,
NAME, MODEL)),
-)
-def test_update_pipeline(hook: AzureDataFactoryHook, user_args, sdk_args):
+def test_update_pipeline(hook: AzureDataFactoryHook):
with patch.object(hook, "_pipeline_exists") as mock_pipeline_exists:
mock_pipeline_exists.return_value = True
- hook.update_pipeline(*user_args)
+ hook.update_pipeline(NAME, MODEL, RESOURCE_GROUP, FACTORY)
- hook._conn.pipelines.create_or_update.assert_called_with(*sdk_args)
+ hook._conn.pipelines.create_or_update.assert_called_with(RESOURCE_GROUP,
FACTORY, NAME, MODEL)
-@parametrize(
- explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP,
FACTORY, NAME, MODEL)),
- implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY,
NAME, MODEL)),
-)
-def test_update_pipeline_non_existent(hook: AzureDataFactoryHook, user_args,
sdk_args):
+def test_update_pipeline_non_existent(hook: AzureDataFactoryHook):
with patch.object(hook, "_pipeline_exists") as mock_pipeline_exists:
mock_pipeline_exists.return_value = False
with pytest.raises(AirflowException, match=r"Pipeline .+ does not exist"):
- hook.update_pipeline(*user_args)
+ hook.update_pipeline(NAME, MODEL, RESOURCE_GROUP, FACTORY)
-@parametrize(
- explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP,
FACTORY, NAME)),
- implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY,
NAME)),
-)
-def test_delete_pipeline(hook: AzureDataFactoryHook, user_args, sdk_args):
- hook.delete_pipeline(*user_args)
+def test_delete_pipeline(hook: AzureDataFactoryHook):
+ hook.delete_pipeline(NAME, RESOURCE_GROUP, FACTORY)
- hook._conn.pipelines.delete.assert_called_with(*sdk_args)
+ hook._conn.pipelines.delete.assert_called_with(RESOURCE_GROUP, FACTORY,
NAME)
-@parametrize(
- explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP,
FACTORY, NAME)),
- implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY,
NAME)),
-)
-def test_run_pipeline(hook: AzureDataFactoryHook, user_args, sdk_args):
- hook.run_pipeline(*user_args)
+def test_run_pipeline(hook: AzureDataFactoryHook):
+ hook.run_pipeline(NAME, RESOURCE_GROUP, FACTORY)
- hook._conn.pipelines.create_run.assert_called_with(*sdk_args)
+ hook._conn.pipelines.create_run.assert_called_with(RESOURCE_GROUP,
FACTORY, NAME)
-@parametrize(
- explicit_factory=((ID, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY,
ID)),
- implicit_factory=((ID,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, ID)),
-)
-def test_get_pipeline_run(hook: AzureDataFactoryHook, user_args, sdk_args):
- hook.get_pipeline_run(*user_args)
+def test_get_pipeline_run(hook: AzureDataFactoryHook):
+ hook.get_pipeline_run(ID, RESOURCE_GROUP, FACTORY)
- hook._conn.pipeline_runs.get.assert_called_with(*sdk_args)
+ hook._conn.pipeline_runs.get.assert_called_with(RESOURCE_GROUP, FACTORY,
ID)
_wait_for_pipeline_run_status_test_args = [
@@ -504,7 +389,14 @@ _wait_for_pipeline_run_status_test_args = [
],
)
def test_wait_for_pipeline_run_status(hook, pipeline_run_status,
expected_status, expected_output):
- config = {"run_id": ID, "timeout": 3, "check_interval": 1,
"expected_statuses": expected_status}
+ config = {
+ "resource_group_name": RESOURCE_GROUP,
+ "factory_name": FACTORY,
+ "run_id": ID,
+ "timeout": 3,
+ "check_interval": 1,
+ "expected_statuses": expected_status,
+ }
with patch.object(AzureDataFactoryHook, "get_pipeline_run") as
mock_pipeline_run:
mock_pipeline_run.return_value.status = pipeline_run_status
@@ -516,108 +408,68 @@ def test_wait_for_pipeline_run_status(hook,
pipeline_run_status, expected_status
hook.wait_for_pipeline_run_status(**config)
-@parametrize(
- explicit_factory=((ID, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY,
ID)),
- implicit_factory=((ID,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, ID)),
-)
-def test_cancel_pipeline_run(hook: AzureDataFactoryHook, user_args, sdk_args):
- hook.cancel_pipeline_run(*user_args)
+def test_cancel_pipeline_run(hook: AzureDataFactoryHook):
+ hook.cancel_pipeline_run(ID, RESOURCE_GROUP, FACTORY)
- hook._conn.pipeline_runs.cancel.assert_called_with(*sdk_args)
+ hook._conn.pipeline_runs.cancel.assert_called_with(RESOURCE_GROUP,
FACTORY, ID)
-@parametrize(
- explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP,
FACTORY, NAME)),
- implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY,
NAME)),
-)
-def test_get_trigger(hook: AzureDataFactoryHook, user_args, sdk_args):
- hook.get_trigger(*user_args)
+def test_get_trigger(hook: AzureDataFactoryHook):
+ hook.get_trigger(NAME, RESOURCE_GROUP, FACTORY)
- hook._conn.triggers.get.assert_called_with(*sdk_args)
+ hook._conn.triggers.get.assert_called_with(RESOURCE_GROUP, FACTORY, NAME)
-@parametrize(
- explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP,
FACTORY, NAME, MODEL)),
- implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY,
NAME, MODEL)),
-)
-def test_create_trigger(hook: AzureDataFactoryHook, user_args, sdk_args):
- hook.create_trigger(*user_args)
+def test_create_trigger(hook: AzureDataFactoryHook):
+ hook.create_trigger(NAME, MODEL, RESOURCE_GROUP, FACTORY)
- hook._conn.triggers.create_or_update.assert_called_with(*sdk_args)
+ hook._conn.triggers.create_or_update.assert_called_with(RESOURCE_GROUP,
FACTORY, NAME, MODEL)
-@parametrize(
- explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP,
FACTORY, NAME, MODEL)),
- implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY,
NAME, MODEL)),
-)
-def test_update_trigger(hook: AzureDataFactoryHook, user_args, sdk_args):
+def test_update_trigger(hook: AzureDataFactoryHook):
with patch.object(hook, "_trigger_exists") as mock_trigger_exists:
mock_trigger_exists.return_value = True
- hook.update_trigger(*user_args)
+ hook.update_trigger(NAME, MODEL, RESOURCE_GROUP, FACTORY)
- hook._conn.triggers.create_or_update.assert_called_with(*sdk_args)
+ hook._conn.triggers.create_or_update.assert_called_with(RESOURCE_GROUP,
FACTORY, NAME, MODEL, None)
-@parametrize(
- explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP,
FACTORY, NAME, MODEL)),
- implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY,
NAME, MODEL)),
-)
-def test_update_trigger_non_existent(hook: AzureDataFactoryHook, user_args,
sdk_args):
+def test_update_trigger_non_existent(hook: AzureDataFactoryHook):
with patch.object(hook, "_trigger_exists") as mock_trigger_exists:
mock_trigger_exists.return_value = False
with pytest.raises(AirflowException, match=r"Trigger .+ does not exist"):
- hook.update_trigger(*user_args)
+ hook.update_trigger(NAME, MODEL, RESOURCE_GROUP, FACTORY)
-@parametrize(
- explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP,
FACTORY, NAME)),
- implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY,
NAME)),
-)
-def test_delete_trigger(hook: AzureDataFactoryHook, user_args, sdk_args):
- hook.delete_trigger(*user_args)
+def test_delete_trigger(hook: AzureDataFactoryHook):
+ hook.delete_trigger(NAME, RESOURCE_GROUP, FACTORY)
- hook._conn.triggers.delete.assert_called_with(*sdk_args)
+ hook._conn.triggers.delete.assert_called_with(RESOURCE_GROUP, FACTORY,
NAME)
-@parametrize(
- explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP,
FACTORY, NAME)),
- implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY,
NAME)),
-)
-def test_start_trigger(hook: AzureDataFactoryHook, user_args, sdk_args):
- hook.start_trigger(*user_args)
+def test_start_trigger(hook: AzureDataFactoryHook):
+ hook.start_trigger(NAME, RESOURCE_GROUP, FACTORY)
- hook._conn.triggers.begin_start.assert_called_with(*sdk_args)
+ hook._conn.triggers.begin_start.assert_called_with(RESOURCE_GROUP,
FACTORY, NAME)
-@parametrize(
- explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP,
FACTORY, NAME)),
- implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY,
NAME)),
-)
-def test_stop_trigger(hook: AzureDataFactoryHook, user_args, sdk_args):
- hook.stop_trigger(*user_args)
+def test_stop_trigger(hook: AzureDataFactoryHook):
+ hook.stop_trigger(NAME, RESOURCE_GROUP, FACTORY)
- hook._conn.triggers.begin_stop.assert_called_with(*sdk_args)
+ hook._conn.triggers.begin_stop.assert_called_with(RESOURCE_GROUP, FACTORY,
NAME)
-@parametrize(
- explicit_factory=((NAME, ID, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP,
FACTORY, NAME, ID)),
- implicit_factory=((NAME, ID), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY,
NAME, ID)),
-)
-def test_rerun_trigger(hook: AzureDataFactoryHook, user_args, sdk_args):
- hook.rerun_trigger(*user_args)
+def test_rerun_trigger(hook: AzureDataFactoryHook):
+ hook.rerun_trigger(NAME, ID, RESOURCE_GROUP, FACTORY)
- hook._conn.trigger_runs.rerun.assert_called_with(*sdk_args)
+ hook._conn.trigger_runs.rerun.assert_called_with(RESOURCE_GROUP, FACTORY,
NAME, ID)
-@parametrize(
- explicit_factory=((NAME, ID, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP,
FACTORY, NAME, ID)),
- implicit_factory=((NAME, ID), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY,
NAME, ID)),
-)
-def test_cancel_trigger(hook: AzureDataFactoryHook, user_args, sdk_args):
- hook.cancel_trigger(*user_args)
+def test_cancel_trigger(hook: AzureDataFactoryHook):
+ hook.cancel_trigger(NAME, ID, RESOURCE_GROUP, FACTORY)
- hook._conn.trigger_runs.cancel.assert_called_with(*sdk_args)
+ hook._conn.trigger_runs.cancel.assert_called_with(RESOURCE_GROUP, FACTORY,
NAME, ID)
@pytest.mark.parametrize(
@@ -672,8 +524,8 @@ def test_connection_failure_missing_tenant_id():
def test_provide_targeted_factory_backcompat_prefix_works(mock_connect, uri):
with patch.dict(os.environ, {"AIRFLOW_CONN_MY_CONN": uri}):
hook = AzureDataFactoryHook("my_conn")
- hook.delete_factory()
- mock_connect.return_value.factories.delete.assert_called_with("abc",
"abc")
+ hook.delete_factory(RESOURCE_GROUP, FACTORY)
+
mock_connect.return_value.factories.delete.assert_called_with(RESOURCE_GROUP,
FACTORY)
@pytest.mark.parametrize(
@@ -707,8 +559,8 @@ def test_backcompat_prefix_both_prefers_short(mock_connect):
},
):
hook = AzureDataFactoryHook("my_conn")
- hook.delete_factory(factory_name="n/a")
-
mock_connect.return_value.factories.delete.assert_called_with("non-prefixed",
"n/a")
+ hook.delete_factory(RESOURCE_GROUP, FACTORY)
+
mock_connect.return_value.factories.delete.assert_called_with(RESOURCE_GROUP,
FACTORY)
def test_refresh_conn(hook):
diff --git
a/tests/providers/microsoft/azure/operators/test_azure_data_factory.py
b/tests/providers/microsoft/azure/operators/test_azure_data_factory.py
index 98ded34e1c..599a7b54e9 100644
--- a/tests/providers/microsoft/azure/operators/test_azure_data_factory.py
+++ b/tests/providers/microsoft/azure/operators/test_azure_data_factory.py
@@ -153,9 +153,9 @@ class TestAzureDataFactoryRunPipelineOperator:
)
mock_run_pipeline.assert_called_once_with(
- pipeline_name=self.config["pipeline_name"],
- resource_group_name=self.config["resource_group_name"],
- factory_name=self.config["factory_name"],
+ self.config["pipeline_name"],
+ self.config["resource_group_name"],
+ self.config["factory_name"],
reference_pipeline_run_id=None,
is_recovery=None,
start_activity_name=None,
@@ -165,9 +165,9 @@ class TestAzureDataFactoryRunPipelineOperator:
if pipeline_run_status in
AzureDataFactoryPipelineRunStatus.TERMINAL_STATUSES:
mock_get_pipeline_run.assert_called_once_with(
- run_id=mock_run_pipeline.return_value.run_id,
- factory_name=self.config["factory_name"],
- resource_group_name=self.config["resource_group_name"],
+ mock_run_pipeline.return_value.run_id,
+ self.config["resource_group_name"],
+ self.config["factory_name"],
)
else:
# When the pipeline run status is not in a terminal status or
"Succeeded", the operator will
@@ -177,9 +177,9 @@ class TestAzureDataFactoryRunPipelineOperator:
assert mock_get_pipeline_run.call_count == 4
mock_get_pipeline_run.assert_called_with(
- run_id=mock_run_pipeline.return_value.run_id,
- factory_name=self.config["factory_name"],
- resource_group_name=self.config["resource_group_name"],
+ mock_run_pipeline.return_value.run_id,
+ self.config["resource_group_name"],
+ self.config["factory_name"],
)
@patch.object(AzureDataFactoryHook, "run_pipeline",
return_value=MagicMock(**PIPELINE_RUN_RESPONSE))
@@ -205,9 +205,9 @@ class TestAzureDataFactoryRunPipelineOperator:
)
mock_run_pipeline.assert_called_once_with(
- pipeline_name=self.config["pipeline_name"],
- resource_group_name=self.config["resource_group_name"],
- factory_name=self.config["factory_name"],
+ self.config["pipeline_name"],
+ self.config["resource_group_name"],
+ self.config["factory_name"],
reference_pipeline_run_id=None,
is_recovery=None,
start_activity_name=None,
@@ -268,7 +268,12 @@ class TestAzureDataFactoryRunPipelineOperator:
class TestAzureDataFactoryRunPipelineOperatorWithDeferrable:
OPERATOR = AzureDataFactoryRunPipelineOperator(
- task_id="run_pipeline", pipeline_name="pipeline",
parameters={"myParam": "value"}, deferrable=True
+ task_id="run_pipeline",
+ pipeline_name="pipeline",
+ resource_group_name="resource-group-name",
+ factory_name="factory-name",
+ parameters={"myParam": "value"},
+ deferrable=True,
)
def get_dag_run(self, dag_id: str = "test_dag_id", run_id: str =
"test_dag_id") -> DagRun:
diff --git a/tests/providers/microsoft/azure/sensors/test_azure_data_factory.py
b/tests/providers/microsoft/azure/sensors/test_azure_data_factory.py
index fb489c7ad7..78631f0368 100644
--- a/tests/providers/microsoft/azure/sensors/test_azure_data_factory.py
+++ b/tests/providers/microsoft/azure/sensors/test_azure_data_factory.py
@@ -125,7 +125,11 @@ class TestAzureDataFactoryPipelineRunStatusSensor:
class TestAzureDataFactoryPipelineRunStatusSensorWithAsync:
RUN_ID = "7f8c6c72-c093-11ec-a83d-0242ac120007"
SENSOR = AzureDataFactoryPipelineRunStatusSensor(
- task_id="pipeline_run_sensor_async", run_id=RUN_ID, deferrable=True
+ task_id="pipeline_run_sensor_async",
+ run_id=RUN_ID,
+ resource_group_name="resource-group-name",
+ factory_name="factory-name",
+ deferrable=True,
)
@mock.patch("airflow.providers.microsoft.azure.sensors.data_factory.AzureDataFactoryHook")