This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new c39097c07a merge AzureDataFactoryPipelineRunStatusAsyncSensor to AzureDataFactoryPipelineRunStatusSensor (#30250)
c39097c07a is described below
commit c39097c07a50fdc0baad08573d319627291f5f91
Author: Wei Lee <[email protected]>
AuthorDate: Wed Mar 29 20:04:06 2023 +0800
merge AzureDataFactoryPipelineRunStatusAsyncSensor to AzureDataFactoryPipelineRunStatusSensor (#30250)
* feat(providers/microsoft): move the async execution logic from AzureDataFactoryPipelineRunStatusAsyncSensor to AzureDataFactoryPipelineRunStatusSensor
* test(providers/microsoft): add test cases for AzureDataFactoryPipelineRunStatusSensor when its deferrable attribute is set to True
* docs(providers/microsoft): update the doc for AzureDataFactoryPipelineRunStatusSensor deferrable mode and deprecate AzureDataFactoryPipelineRunStatusAsyncSensor
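For context, the user-facing change boils down to swapping the deprecated class for the merged sensor with `deferrable=True`. A minimal sketch, not taken from the commit itself; the task_id and run_id values are placeholders:

from airflow.providers.microsoft.azure.sensors.data_factory import (
    AzureDataFactoryPipelineRunStatusSensor,
)

# Before this commit a separate async sensor class was needed:
# wait_for_run = AzureDataFactoryPipelineRunStatusAsyncSensor(
#     task_id="wait_for_pipeline_run",  # placeholder task_id
#     run_id="<pipeline-run-id>",       # placeholder run_id
# )

# After this commit the same sensor class covers both modes:
wait_for_run = AzureDataFactoryPipelineRunStatusSensor(
    task_id="wait_for_pipeline_run",  # placeholder task_id
    run_id="<pipeline-run-id>",       # placeholder run_id
    deferrable=True,
)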
---
.../microsoft/azure/sensors/data_factory.py | 79 ++++++++++++----------
.../operators/adf_run_pipeline.rst | 9 +++
.../azure/sensors/test_azure_data_factory.py | 26 +++++++
.../microsoft/azure/example_adf_run_pipeline.py | 8 +++
4 files changed, 87 insertions(+), 35 deletions(-)
diff --git a/airflow/providers/microsoft/azure/sensors/data_factory.py b/airflow/providers/microsoft/azure/sensors/data_factory.py
index e09f7a0a7d..e98bb9caee 100644
--- a/airflow/providers/microsoft/azure/sensors/data_factory.py
+++ b/airflow/providers/microsoft/azure/sensors/data_factory.py
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations
+import warnings
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Sequence
@@ -25,9 +26,7 @@ from airflow.providers.microsoft.azure.hooks.data_factory import (
AzureDataFactoryPipelineRunException,
AzureDataFactoryPipelineRunStatus,
)
-from airflow.providers.microsoft.azure.triggers.data_factory import (
- ADFPipelineRunStatusSensorTrigger,
-)
+from airflow.providers.microsoft.azure.triggers.data_factory import ADFPipelineRunStatusSensorTrigger
from airflow.sensors.base import BaseSensorOperator
if TYPE_CHECKING:
@@ -42,6 +41,7 @@ class AzureDataFactoryPipelineRunStatusSensor(BaseSensorOperator):
:param run_id: The pipeline run identifier.
:param resource_group_name: The resource group name.
:param factory_name: The data factory name.
+ :param deferrable: Run sensor in the deferrable mode.
"""
template_fields: Sequence[str] = (
@@ -60,6 +60,7 @@ class AzureDataFactoryPipelineRunStatusSensor(BaseSensorOperator):
azure_data_factory_conn_id: str = AzureDataFactoryHook.default_conn_name,
resource_group_name: str | None = None,
factory_name: str | None = None,
+ deferrable: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -68,6 +69,8 @@ class AzureDataFactoryPipelineRunStatusSensor(BaseSensorOperator):
self.resource_group_name = resource_group_name
self.factory_name = factory_name
+ self.deferrable = deferrable
+
def poke(self, context: Context) -> bool:
self.hook = AzureDataFactoryHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id)
pipeline_run_status = self.hook.get_pipeline_run_status(
@@ -84,42 +87,24 @@ class AzureDataFactoryPipelineRunStatusSensor(BaseSensorOperator):
return pipeline_run_status == AzureDataFactoryPipelineRunStatus.SUCCEEDED
-
-class AzureDataFactoryPipelineRunStatusAsyncSensor(AzureDataFactoryPipelineRunStatusSensor):
- """
- Checks the status of a pipeline run asynchronously.
-
- :param azure_data_factory_conn_id: The connection identifier for connecting to Azure Data Factory.
- :param run_id: The pipeline run identifier.
- :param resource_group_name: The resource group name.
- :param factory_name: The data factory name.
- :param poke_interval: polling period in seconds to check for the status
- """
-
- def __init__(
- self,
- *,
- poke_interval: float = 60,
- **kwargs: Any,
- ):
- self.poke_interval = poke_interval
- super().__init__(**kwargs)
-
def execute(self, context: Context) -> None:
"""Defers trigger class to poll for state of the job run until
it reaches a failure state or success state
"""
- self.defer(
- timeout=timedelta(seconds=self.timeout),
- trigger=ADFPipelineRunStatusSensorTrigger(
- run_id=self.run_id,
- azure_data_factory_conn_id=self.azure_data_factory_conn_id,
- resource_group_name=self.resource_group_name,
- factory_name=self.factory_name,
- poke_interval=self.poke_interval,
- ),
- method_name="execute_complete",
- )
+ if not self.deferrable:
+ super().execute(context=context)
+ else:
+ self.defer(
+ timeout=timedelta(seconds=self.timeout),
+ trigger=ADFPipelineRunStatusSensorTrigger(
+ run_id=self.run_id,
+ azure_data_factory_conn_id=self.azure_data_factory_conn_id,
+ resource_group_name=self.resource_group_name,
+ factory_name=self.factory_name,
+ poke_interval=self.poke_interval,
+ ),
+ method_name="execute_complete",
+ )
def execute_complete(self, context: Context, event: dict[str, str]) -> None:
"""
@@ -132,3 +117,27 @@ class AzureDataFactoryPipelineRunStatusAsyncSensor(AzureDataFactoryPipelineRunSt
raise AirflowException(event["message"])
self.log.info(event["message"])
return None
+
+
+class AzureDataFactoryPipelineRunStatusAsyncSensor(AzureDataFactoryPipelineRunStatusSensor):
+ """
+ Checks the status of a pipeline run asynchronously.
+
+ :param azure_data_factory_conn_id: The connection identifier for connecting to Azure Data Factory.
+ :param run_id: The pipeline run identifier.
+ :param resource_group_name: The resource group name.
+ :param factory_name: The data factory name.
+ :param poke_interval: polling period in seconds to check for the status
+ :param deferrable: Run sensor in the deferrable mode.
+ """
+
+ def __init__(self, **kwargs: Any) -> None:
+ warnings.warn(
+ "Class `AzureDataFactoryPipelineRunStatusAsyncSensor` is deprecated and "
+ "will be removed in a future release. "
+ "Please use `AzureDataFactoryPipelineRunStatusSensor` and "
+ "set `deferrable` attribute to `True` instead",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ super().__init__(**kwargs, deferrable=True)
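To spell out what the shim above does: instantiating the deprecated class should now emit a DeprecationWarning and return a sensor that already has deferrable=True set. A rough sketch of how that could be checked, not part of the commit; the task_id and run_id values are made up:

import warnings

from airflow.providers.microsoft.azure.sensors.data_factory import (
    AzureDataFactoryPipelineRunStatusAsyncSensor,
)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    sensor = AzureDataFactoryPipelineRunStatusAsyncSensor(
        task_id="legacy_async_sensor",  # placeholder task_id
        run_id="<pipeline-run-id>",     # placeholder run_id
    )

# The shim forwards to the parent class with deferrable=True
assert sensor.deferrable is True
assert any(issubclass(w.category, DeprecationWarning) for w in caught)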
diff --git a/docs/apache-airflow-providers-microsoft-azure/operators/adf_run_pipeline.rst b/docs/apache-airflow-providers-microsoft-azure/operators/adf_run_pipeline.rst
index a666a1871d..024873d597 100644
--- a/docs/apache-airflow-providers-microsoft-azure/operators/adf_run_pipeline.rst
+++ b/docs/apache-airflow-providers-microsoft-azure/operators/adf_run_pipeline.rst
@@ -54,6 +54,15 @@ Here is a different example of using this operator to execute a pipeline but cou
:start-after: [START howto_operator_adf_run_pipeline_async]
:end-before: [END howto_operator_adf_run_pipeline_async]
+Also you can use deferrable mode in :class:`~airflow.providers.microsoft.azure.sensors.data_factory.AzureDataFactoryPipelineRunStatusSensor` if you would like to free up the worker slots while the sensor is running.
+
+  .. exampleinclude:: /../../tests/system/providers/microsoft/azure/example_adf_run_pipeline.py
+ :language: python
+ :dedent: 0
+ :start-after: [START howto_operator_adf_run_pipeline_async]
+ :end-before: [END howto_operator_adf_run_pipeline_async]
+
+
Poll for status of a data factory pipeline run asynchronously
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
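The exampleinclude added in the docs hunk above points at the system test DAG, which is also touched further down in this diff. For readers skimming this email, a stripped-down sketch of the same idea; DAG id, dates, and run_id are placeholders:

from datetime import datetime

from airflow import DAG
from airflow.providers.microsoft.azure.sensors.data_factory import (
    AzureDataFactoryPipelineRunStatusSensor,
)

with DAG(
    dag_id="example_adf_deferrable_sensor",  # placeholder dag_id
    start_date=datetime(2023, 1, 1),
    schedule=None,
    catchup=False,
):
    # While waiting, polling runs on the triggerer instead of holding a worker slot
    AzureDataFactoryPipelineRunStatusSensor(
        task_id="pipeline_run_sensor_deferred",
        run_id="<pipeline-run-id>",  # typically pulled from the run-pipeline task's XCom
        deferrable=True,
    )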
diff --git a/tests/providers/microsoft/azure/sensors/test_azure_data_factory.py b/tests/providers/microsoft/azure/sensors/test_azure_data_factory.py
index 0424bf41c4..21451775bd 100644
--- a/tests/providers/microsoft/azure/sensors/test_azure_data_factory.py
+++ b/tests/providers/microsoft/azure/sensors/test_azure_data_factory.py
@@ -45,6 +45,9 @@ class TestPipelineRunStatusSensor:
"poke_interval": 15,
}
self.sensor = AzureDataFactoryPipelineRunStatusSensor(task_id="pipeline_run_sensor", **self.config)
+ self.defered_sensor = AzureDataFactoryPipelineRunStatusSensor(
+ task_id="pipeline_run_sensor_defered", deferrable=True, **self.config
+ )
def test_init(self):
assert self.sensor.azure_data_factory_conn_id == self.config["azure_data_factory_conn_id"]
@@ -81,6 +84,29 @@ class TestPipelineRunStatusSensor:
with pytest.raises(AzureDataFactoryPipelineRunException, match=error_message):
self.sensor.poke({})
+ def test_adf_pipeline_status_sensor_async(self):
+ """Assert execute method defer for Azure Data factory pipeline run status sensor"""
+
+ with pytest.raises(TaskDeferred) as exc:
+ self.defered_sensor.execute({})
+ assert isinstance(
+ exc.value.trigger, ADFPipelineRunStatusSensorTrigger
+ ), "Trigger is not a ADFPipelineRunStatusSensorTrigger"
+
+ def test_adf_pipeline_status_sensor_execute_complete_success(self):
+ """Assert execute_complete log success message when trigger fire with target status"""
+
+ msg = f"Pipeline run {self.config['run_id']} has been succeeded."
+ with mock.patch.object(self.defered_sensor.log, "info") as mock_log_info:
+ self.defered_sensor.execute_complete(context={}, event={"status": "success", "message": msg})
+ mock_log_info.assert_called_with(msg)
+
+ def test_adf_pipeline_status_sensor_execute_complete_failure(self):
+ """Assert execute_complete method fail"""
+
+ with pytest.raises(AirflowException):
+ self.defered_sensor.execute_complete(context={}, event={"status": "error", "message": ""})
+
class TestAzureDataFactoryPipelineRunStatusAsyncSensor:
RUN_ID = "7f8c6c72-c093-11ec-a83d-0242ac120007"
diff --git a/tests/system/providers/microsoft/azure/example_adf_run_pipeline.py b/tests/system/providers/microsoft/azure/example_adf_run_pipeline.py
index 2db131c93d..dbac6e8a0a 100644
--- a/tests/system/providers/microsoft/azure/example_adf_run_pipeline.py
+++ b/tests/system/providers/microsoft/azure/example_adf_run_pipeline.py
@@ -76,6 +76,14 @@ with DAG(
)
# Performs polling on the Airflow Triggerer thus freeing up resources on Airflow Worker
+ pipeline_run_sensor = AzureDataFactoryPipelineRunStatusSensor(
+ task_id="pipeline_run_sensor_defered",
+ run_id=cast(str, XComArg(run_pipeline2, key="run_id")),
+ deferrable=True,
+ )
+
+ # The following sensor is deprecated.
+ # Please use the AzureDataFactoryPipelineRunStatusSensor and set deferrable to True
pipeline_run_async_sensor = AzureDataFactoryPipelineRunStatusAsyncSensor(
task_id="pipeline_run_async_sensor",
run_id=cast(str, XComArg(run_pipeline2, key="run_id")),