This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 1d6f6131df refresh connection if an exception is caught in
"AzureDataFactory" (#32323)
1d6f6131df is described below
commit 1d6f6131df7e420b9e9dd1535ea7cd1a29e3c548
Author: Wei Lee <[email protected]>
AuthorDate: Wed Jul 5 03:54:28 2023 +0800
refresh connection if an exception is caught in "AzureDataFactory" (#32323)
---
.../microsoft/azure/hooks/data_factory.py | 43 ++++---
.../microsoft/azure/triggers/data_factory.py | 126 ++++++++++++++-------
.../azure/hooks/test_azure_data_factory.py | 36 +++---
.../azure/triggers/test_azure_data_factory.py | 6 +-
4 files changed, 131 insertions(+), 80 deletions(-)
diff --git a/airflow/providers/microsoft/azure/hooks/data_factory.py
b/airflow/providers/microsoft/azure/hooks/data_factory.py
index b9e7feea0a..590f2af92b 100644
--- a/airflow/providers/microsoft/azure/hooks/data_factory.py
+++ b/airflow/providers/microsoft/azure/hooks/data_factory.py
@@ -36,6 +36,7 @@ from functools import wraps
from typing import Any, Callable, TypeVar, Union, cast
from asgiref.sync import sync_to_async
+from azure.core.exceptions import ServiceRequestError
from azure.core.polling import LROPoller
from azure.identity import ClientSecretCredential, DefaultAzureCredential
from azure.identity.aio import (
@@ -214,6 +215,10 @@ class AzureDataFactoryHook(BaseHook):
return self._conn
+ def refresh_conn(self) -> DataFactoryManagementClient:
+ self._conn = None
+ return self.get_conn()
+
@provide_targeted_factory
def get_factory(
self, resource_group_name: str | None = None, factory_name: str | None
= None, **config: Any
@@ -812,6 +817,7 @@ class AzureDataFactoryHook(BaseHook):
resource_group_name=resource_group_name,
)
pipeline_run_status = self.get_pipeline_run_status(**pipeline_run_info)
+ executed_after_token_refresh = True
start_time = time.monotonic()
@@ -828,7 +834,14 @@ class AzureDataFactoryHook(BaseHook):
# Wait to check the status of the pipeline run based on the
``check_interval`` configured.
time.sleep(check_interval)
- pipeline_run_status =
self.get_pipeline_run_status(**pipeline_run_info)
+ try:
+ pipeline_run_status =
self.get_pipeline_run_status(**pipeline_run_info)
+ executed_after_token_refresh = True
+ except ServiceRequestError:
+ if executed_after_token_refresh:
+ self.refresh_conn()
+ continue
+ raise
return pipeline_run_status in expected_statuses
@@ -1132,6 +1145,10 @@ class AzureDataFactoryAsyncHook(AzureDataFactoryHook):
return self._async_conn
+ async def refresh_conn(self) -> AsyncDataFactoryManagementClient:
+ self._conn = None
+ return await self.get_async_conn()
+
@provide_targeted_factory_async
async def get_pipeline_run(
self,
@@ -1149,11 +1166,8 @@ class AzureDataFactoryAsyncHook(AzureDataFactoryHook):
:param config: Extra parameters for the ADF client.
"""
client = await self.get_async_conn()
- try:
- pipeline_run = await client.pipeline_runs.get(resource_group_name,
factory_name, run_id)
- return pipeline_run
- except Exception as e:
- raise AirflowException(e)
+ pipeline_run = await client.pipeline_runs.get(resource_group_name,
factory_name, run_id)
+ return pipeline_run
async def get_adf_pipeline_run_status(
self, run_id: str, resource_group_name: str | None = None,
factory_name: str | None = None
@@ -1165,16 +1179,13 @@ class AzureDataFactoryAsyncHook(AzureDataFactoryHook):
:param resource_group_name: The resource group name.
:param factory_name: The factory name.
"""
- try:
- pipeline_run = await self.get_pipeline_run(
- run_id=run_id,
- factory_name=factory_name,
- resource_group_name=resource_group_name,
- )
- status: str = pipeline_run.status
- return status
- except Exception as e:
- raise AirflowException(e)
+ pipeline_run = await self.get_pipeline_run(
+ run_id=run_id,
+ factory_name=factory_name,
+ resource_group_name=resource_group_name,
+ )
+ status: str = pipeline_run.status
+ return status
@provide_targeted_factory_async
async def cancel_pipeline_run(
diff --git a/airflow/providers/microsoft/azure/triggers/data_factory.py
b/airflow/providers/microsoft/azure/triggers/data_factory.py
index 40b9555940..e3dd38ad66 100644
--- a/airflow/providers/microsoft/azure/triggers/data_factory.py
+++ b/airflow/providers/microsoft/azure/triggers/data_factory.py
@@ -20,6 +20,8 @@ import asyncio
import time
from typing import Any, AsyncIterator
+from azure.core.exceptions import ServiceRequestError
+
from airflow.providers.microsoft.azure.hooks.data_factory import (
AzureDataFactoryAsyncHook,
AzureDataFactoryPipelineRunStatus,
@@ -68,24 +70,41 @@ class ADFPipelineRunStatusSensorTrigger(BaseTrigger):
async def run(self) -> AsyncIterator[TriggerEvent]:
"""Make async connection to Azure Data Factory, polls for the pipeline
run status."""
hook =
AzureDataFactoryAsyncHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id)
+ executed_after_token_refresh = False
try:
while True:
- pipeline_status = await hook.get_adf_pipeline_run_status(
- run_id=self.run_id,
- resource_group_name=self.resource_group_name,
- factory_name=self.factory_name,
- )
- if pipeline_status == AzureDataFactoryPipelineRunStatus.FAILED:
- yield TriggerEvent(
- {"status": "error", "message": f"Pipeline run
{self.run_id} has Failed."}
+ try:
+ pipeline_status = await hook.get_adf_pipeline_run_status(
+ run_id=self.run_id,
+ resource_group_name=self.resource_group_name,
+ factory_name=self.factory_name,
)
- elif pipeline_status ==
AzureDataFactoryPipelineRunStatus.CANCELLED:
- msg = f"Pipeline run {self.run_id} has been Cancelled."
- yield TriggerEvent({"status": "error", "message": msg})
- elif pipeline_status ==
AzureDataFactoryPipelineRunStatus.SUCCEEDED:
- msg = f"Pipeline run {self.run_id} has been Succeeded."
- yield TriggerEvent({"status": "success", "message": msg})
- await asyncio.sleep(self.poke_interval)
+ executed_after_token_refresh = False
+ if pipeline_status ==
AzureDataFactoryPipelineRunStatus.FAILED:
+ yield TriggerEvent(
+ {"status": "error", "message": f"Pipeline run
{self.run_id} has Failed."}
+ )
+ return
+ elif pipeline_status ==
AzureDataFactoryPipelineRunStatus.CANCELLED:
+ msg = f"Pipeline run {self.run_id} has been Cancelled."
+ yield TriggerEvent({"status": "error", "message": msg})
+ return
+ elif pipeline_status ==
AzureDataFactoryPipelineRunStatus.SUCCEEDED:
+ msg = f"Pipeline run {self.run_id} has been Succeeded."
+ yield TriggerEvent({"status": "success", "message":
msg})
+ return
+ await asyncio.sleep(self.poke_interval)
+ except ServiceRequestError:
+ # conn might expire during long running pipeline.
+                # If an exception is caught, it tries to refresh connection
once.
+                # If it still doesn't fix the issue,
+                # then the executed_after_token_refresh would still be False
+ # and an exception will be raised
+ if executed_after_token_refresh:
+ await hook.refresh_conn()
+ executed_after_token_refresh = False
+ continue
+ raise
except Exception as e:
yield TriggerEvent({"status": "error", "message": str(e)})
@@ -147,33 +166,49 @@ class AzureDataFactoryTrigger(BaseTrigger):
resource_group_name=self.resource_group_name,
factory_name=self.factory_name,
)
+ executed_after_token_refresh = True
if self.wait_for_termination:
while self.end_time > time.time():
- pipeline_status = await hook.get_adf_pipeline_run_status(
- run_id=self.run_id,
- resource_group_name=self.resource_group_name,
- factory_name=self.factory_name,
- )
- if pipeline_status in
AzureDataFactoryPipelineRunStatus.FAILURE_STATES:
- yield TriggerEvent(
- {
- "status": "error",
- "message": f"The pipeline run {self.run_id}
has {pipeline_status}.",
- "run_id": self.run_id,
- }
+ try:
+ pipeline_status = await
hook.get_adf_pipeline_run_status(
+ run_id=self.run_id,
+ resource_group_name=self.resource_group_name,
+ factory_name=self.factory_name,
)
- elif pipeline_status ==
AzureDataFactoryPipelineRunStatus.SUCCEEDED:
- yield TriggerEvent(
- {
- "status": "success",
- "message": f"The pipeline run {self.run_id}
has {pipeline_status}.",
- "run_id": self.run_id,
- }
+ executed_after_token_refresh = True
+ if pipeline_status in
AzureDataFactoryPipelineRunStatus.FAILURE_STATES:
+ yield TriggerEvent(
+ {
+ "status": "error",
+ "message": f"The pipeline run
{self.run_id} has {pipeline_status}.",
+ "run_id": self.run_id,
+ }
+ )
+ return
+ elif pipeline_status ==
AzureDataFactoryPipelineRunStatus.SUCCEEDED:
+ yield TriggerEvent(
+ {
+ "status": "success",
+ "message": f"The pipeline run
{self.run_id} has {pipeline_status}.",
+ "run_id": self.run_id,
+ }
+ )
+ return
+ self.log.info(
+ "Sleeping for %s. The pipeline state is %s.",
self.check_interval, pipeline_status
)
- self.log.info(
- "Sleeping for %s. The pipeline state is %s.",
self.check_interval, pipeline_status
- )
- await asyncio.sleep(self.check_interval)
+ await asyncio.sleep(self.check_interval)
+ except ServiceRequestError:
+ # conn might expire during long running pipeline.
+                    # If an exception is caught, it tries to refresh
connection once.
+                    # If it still doesn't fix the issue,
+                    # then the executed_after_token_refresh would still be
False
+ # and an exception will be raised
+ if executed_after_token_refresh:
+ await hook.refresh_conn()
+ executed_after_token_refresh = False
+ continue
+ raise
yield TriggerEvent(
{
@@ -192,10 +227,13 @@ class AzureDataFactoryTrigger(BaseTrigger):
)
except Exception as e:
if self.run_id:
- await hook.cancel_pipeline_run(
- run_id=self.run_id,
- resource_group_name=self.resource_group_name,
- factory_name=self.factory_name,
- )
- self.log.info("Unexpected error %s caught. Cancel pipeline run
%s", str(e), self.run_id)
+ try:
+ await hook.cancel_pipeline_run(
+ run_id=self.run_id,
+ resource_group_name=self.resource_group_name,
+ factory_name=self.factory_name,
+ )
+ self.log.info("Unexpected error %s caught. Cancel pipeline
run %s", str(e), self.run_id)
+ except Exception as err:
+ yield TriggerEvent({"status": "error", "message":
str(err), "run_id": self.run_id})
yield TriggerEvent({"status": "error", "message": str(e),
"run_id": self.run_id})
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
index c7b256e9cb..57f7bc6178 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
@@ -720,6 +720,14 @@ def
test_backcompat_prefix_both_prefers_short(mock_connect):
mock_connect.return_value.factories.delete.assert_called_with("non-prefixed",
"n/a")
+def test_refresh_conn(hook):
+ """Test refresh_conn method _conn is reset and get_conn is called"""
+ with patch.object(hook, "get_conn") as mock_get_conn:
+ hook.refresh_conn()
+ assert not hook._conn
+ assert mock_get_conn.called
+
+
class TestAzureDataFactoryAsyncHook:
@pytest.mark.asyncio
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
@@ -780,16 +788,6 @@ class TestAzureDataFactoryAsyncHook:
response = await hook.get_adf_pipeline_run_status(RUN_ID,
RESOURCE_GROUP_NAME, DATAFACTORY_NAME)
assert response == mock_status
- @pytest.mark.asyncio
-
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
-
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_pipeline_run")
- async def test_get_adf_pipeline_run_status_exception(self,
mock_get_pipeline_run, mock_conn):
- """Test get_adf_pipeline_run_status function with exception"""
- mock_get_pipeline_run.side_effect = Exception("Test exception")
- hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
- with pytest.raises(AirflowException):
- await hook.get_adf_pipeline_run_status(RUN_ID,
RESOURCE_GROUP_NAME, DATAFACTORY_NAME)
-
@pytest.mark.asyncio
@mock.patch("azure.mgmt.datafactory.models._models_py3.PipelineRun")
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_connection")
@@ -810,15 +808,6 @@ class TestAzureDataFactoryAsyncHook:
with pytest.raises(AirflowException):
await hook.get_pipeline_run(RUN_ID, None, DATAFACTORY_NAME)
- @pytest.mark.asyncio
-
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
- async def test_get_pipeline_run_exception(self, mock_conn):
- """Test get_pipeline_run function with exception"""
- mock_conn.return_value.pipeline_runs.get.side_effect = Exception("Test
exception")
- hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
- with pytest.raises(AirflowException):
- await hook.get_pipeline_run(RUN_ID, RESOURCE_GROUP_NAME,
DATAFACTORY_NAME)
-
@pytest.mark.asyncio
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_connection")
async def test_get_async_conn(self, mock_connection):
@@ -958,3 +947,12 @@ class TestAzureDataFactoryAsyncHook:
assert get_field(extras, "factory_name", strict=True) ==
DATAFACTORY_NAME
with pytest.raises(KeyError):
get_field(extras, "non-existent-field", strict=True)
+
+ @pytest.mark.asyncio
+
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
+ async def test_refresh_conn(self, mock_get_async_conn):
+ """Test refresh_conn method _conn is reset and get_async_conn is
called"""
+ hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
+ await hook.refresh_conn()
+ assert not hook._conn
+ assert mock_get_async_conn.called
diff --git
a/tests/providers/microsoft/azure/triggers/test_azure_data_factory.py
b/tests/providers/microsoft/azure/triggers/test_azure_data_factory.py
index 9df51bae71..fd1a206554 100644
--- a/tests/providers/microsoft/azure/triggers/test_azure_data_factory.py
+++ b/tests/providers/microsoft/azure/triggers/test_azure_data_factory.py
@@ -163,10 +163,14 @@ class TestADFPipelineRunStatusSensorTrigger:
assert TriggerEvent({"status": "error", "message": mock_message}) ==
actual
@pytest.mark.asyncio
+
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.refresh_conn")
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_adf_pipeline_run_status")
- async def test_adf_pipeline_run_status_sensors_trigger_exception(self,
mock_data_factory):
+ async def test_adf_pipeline_run_status_sensors_trigger_exception(
+ self, mock_data_factory, mock_refresh_token
+ ):
        """Test ADF pipeline run status sensor trigger with raised exception"""
mock_data_factory.side_effect = Exception("Test exception")
+ mock_refresh_token.side_effect = Exception("Test exception")
task = [i async for i in self.TRIGGER.run()]
assert len(task) == 1