This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1d6f6131df refresh connection if an exception is caught in 
"AzureDataFactory" (#32323)
1d6f6131df is described below

commit 1d6f6131df7e420b9e9dd1535ea7cd1a29e3c548
Author: Wei Lee <[email protected]>
AuthorDate: Wed Jul 5 03:54:28 2023 +0800

    refresh connection if an exception is caught in "AzureDataFactory" (#32323)
---
 .../microsoft/azure/hooks/data_factory.py          |  43 ++++---
 .../microsoft/azure/triggers/data_factory.py       | 126 ++++++++++++++-------
 .../azure/hooks/test_azure_data_factory.py         |  36 +++---
 .../azure/triggers/test_azure_data_factory.py      |   6 +-
 4 files changed, 131 insertions(+), 80 deletions(-)

diff --git a/airflow/providers/microsoft/azure/hooks/data_factory.py 
b/airflow/providers/microsoft/azure/hooks/data_factory.py
index b9e7feea0a..590f2af92b 100644
--- a/airflow/providers/microsoft/azure/hooks/data_factory.py
+++ b/airflow/providers/microsoft/azure/hooks/data_factory.py
@@ -36,6 +36,7 @@ from functools import wraps
 from typing import Any, Callable, TypeVar, Union, cast
 
 from asgiref.sync import sync_to_async
+from azure.core.exceptions import ServiceRequestError
 from azure.core.polling import LROPoller
 from azure.identity import ClientSecretCredential, DefaultAzureCredential
 from azure.identity.aio import (
@@ -214,6 +215,10 @@ class AzureDataFactoryHook(BaseHook):
 
         return self._conn
 
+    def refresh_conn(self) -> DataFactoryManagementClient:
+        self._conn = None
+        return self.get_conn()
+
     @provide_targeted_factory
     def get_factory(
         self, resource_group_name: str | None = None, factory_name: str | None 
= None, **config: Any
@@ -812,6 +817,7 @@ class AzureDataFactoryHook(BaseHook):
             resource_group_name=resource_group_name,
         )
         pipeline_run_status = self.get_pipeline_run_status(**pipeline_run_info)
+        executed_after_token_refresh = True
 
         start_time = time.monotonic()
 
@@ -828,7 +834,14 @@ class AzureDataFactoryHook(BaseHook):
             # Wait to check the status of the pipeline run based on the 
``check_interval`` configured.
             time.sleep(check_interval)
 
-            pipeline_run_status = 
self.get_pipeline_run_status(**pipeline_run_info)
+            try:
+                pipeline_run_status = 
self.get_pipeline_run_status(**pipeline_run_info)
+                executed_after_token_refresh = True
+            except ServiceRequestError:
+                if executed_after_token_refresh:
+                    self.refresh_conn()
+                    continue
+                raise
 
         return pipeline_run_status in expected_statuses
 
@@ -1132,6 +1145,10 @@ class AzureDataFactoryAsyncHook(AzureDataFactoryHook):
 
         return self._async_conn
 
+    async def refresh_conn(self) -> AsyncDataFactoryManagementClient:
+        self._conn = None
+        return await self.get_async_conn()
+
     @provide_targeted_factory_async
     async def get_pipeline_run(
         self,
@@ -1149,11 +1166,8 @@ class AzureDataFactoryAsyncHook(AzureDataFactoryHook):
         :param config: Extra parameters for the ADF client.
         """
         client = await self.get_async_conn()
-        try:
-            pipeline_run = await client.pipeline_runs.get(resource_group_name, 
factory_name, run_id)
-            return pipeline_run
-        except Exception as e:
-            raise AirflowException(e)
+        pipeline_run = await client.pipeline_runs.get(resource_group_name, 
factory_name, run_id)
+        return pipeline_run
 
     async def get_adf_pipeline_run_status(
         self, run_id: str, resource_group_name: str | None = None, 
factory_name: str | None = None
@@ -1165,16 +1179,13 @@ class AzureDataFactoryAsyncHook(AzureDataFactoryHook):
         :param resource_group_name: The resource group name.
         :param factory_name: The factory name.
         """
-        try:
-            pipeline_run = await self.get_pipeline_run(
-                run_id=run_id,
-                factory_name=factory_name,
-                resource_group_name=resource_group_name,
-            )
-            status: str = pipeline_run.status
-            return status
-        except Exception as e:
-            raise AirflowException(e)
+        pipeline_run = await self.get_pipeline_run(
+            run_id=run_id,
+            factory_name=factory_name,
+            resource_group_name=resource_group_name,
+        )
+        status: str = pipeline_run.status
+        return status
 
     @provide_targeted_factory_async
     async def cancel_pipeline_run(
diff --git a/airflow/providers/microsoft/azure/triggers/data_factory.py 
b/airflow/providers/microsoft/azure/triggers/data_factory.py
index 40b9555940..e3dd38ad66 100644
--- a/airflow/providers/microsoft/azure/triggers/data_factory.py
+++ b/airflow/providers/microsoft/azure/triggers/data_factory.py
@@ -20,6 +20,8 @@ import asyncio
 import time
 from typing import Any, AsyncIterator
 
+from azure.core.exceptions import ServiceRequestError
+
 from airflow.providers.microsoft.azure.hooks.data_factory import (
     AzureDataFactoryAsyncHook,
     AzureDataFactoryPipelineRunStatus,
@@ -68,24 +70,41 @@ class ADFPipelineRunStatusSensorTrigger(BaseTrigger):
     async def run(self) -> AsyncIterator[TriggerEvent]:
         """Make async connection to Azure Data Factory, polls for the pipeline 
run status."""
         hook = 
AzureDataFactoryAsyncHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id)
+        executed_after_token_refresh = False
         try:
             while True:
-                pipeline_status = await hook.get_adf_pipeline_run_status(
-                    run_id=self.run_id,
-                    resource_group_name=self.resource_group_name,
-                    factory_name=self.factory_name,
-                )
-                if pipeline_status == AzureDataFactoryPipelineRunStatus.FAILED:
-                    yield TriggerEvent(
-                        {"status": "error", "message": f"Pipeline run 
{self.run_id} has Failed."}
+                try:
+                    pipeline_status = await hook.get_adf_pipeline_run_status(
+                        run_id=self.run_id,
+                        resource_group_name=self.resource_group_name,
+                        factory_name=self.factory_name,
                     )
-                elif pipeline_status == 
AzureDataFactoryPipelineRunStatus.CANCELLED:
-                    msg = f"Pipeline run {self.run_id} has been Cancelled."
-                    yield TriggerEvent({"status": "error", "message": msg})
-                elif pipeline_status == 
AzureDataFactoryPipelineRunStatus.SUCCEEDED:
-                    msg = f"Pipeline run {self.run_id} has been Succeeded."
-                    yield TriggerEvent({"status": "success", "message": msg})
-                await asyncio.sleep(self.poke_interval)
+                    executed_after_token_refresh = False
+                    if pipeline_status == 
AzureDataFactoryPipelineRunStatus.FAILED:
+                        yield TriggerEvent(
+                            {"status": "error", "message": f"Pipeline run 
{self.run_id} has Failed."}
+                        )
+                        return
+                    elif pipeline_status == 
AzureDataFactoryPipelineRunStatus.CANCELLED:
+                        msg = f"Pipeline run {self.run_id} has been Cancelled."
+                        yield TriggerEvent({"status": "error", "message": msg})
+                        return
+                    elif pipeline_status == 
AzureDataFactoryPipelineRunStatus.SUCCEEDED:
+                        msg = f"Pipeline run {self.run_id} has been Succeeded."
+                        yield TriggerEvent({"status": "success", "message": 
msg})
+                        return
+                    await asyncio.sleep(self.poke_interval)
+                except ServiceRequestError:
+                    # conn might expire during long running pipeline.
+                    # If an exception is caught, it tries to refresh connection 
once.
+                    # If it still doesn't fix the issue,
+                    # then executed_after_token_refresh would still be False
+                    # and an exception will be raised
+                    if executed_after_token_refresh:
+                        await hook.refresh_conn()
+                        executed_after_token_refresh = False
+                        continue
+                    raise
         except Exception as e:
             yield TriggerEvent({"status": "error", "message": str(e)})
 
@@ -147,33 +166,49 @@ class AzureDataFactoryTrigger(BaseTrigger):
                 resource_group_name=self.resource_group_name,
                 factory_name=self.factory_name,
             )
+            executed_after_token_refresh = True
             if self.wait_for_termination:
                 while self.end_time > time.time():
-                    pipeline_status = await hook.get_adf_pipeline_run_status(
-                        run_id=self.run_id,
-                        resource_group_name=self.resource_group_name,
-                        factory_name=self.factory_name,
-                    )
-                    if pipeline_status in 
AzureDataFactoryPipelineRunStatus.FAILURE_STATES:
-                        yield TriggerEvent(
-                            {
-                                "status": "error",
-                                "message": f"The pipeline run {self.run_id} 
has {pipeline_status}.",
-                                "run_id": self.run_id,
-                            }
+                    try:
+                        pipeline_status = await 
hook.get_adf_pipeline_run_status(
+                            run_id=self.run_id,
+                            resource_group_name=self.resource_group_name,
+                            factory_name=self.factory_name,
                         )
-                    elif pipeline_status == 
AzureDataFactoryPipelineRunStatus.SUCCEEDED:
-                        yield TriggerEvent(
-                            {
-                                "status": "success",
-                                "message": f"The pipeline run {self.run_id} 
has {pipeline_status}.",
-                                "run_id": self.run_id,
-                            }
+                        executed_after_token_refresh = True
+                        if pipeline_status in 
AzureDataFactoryPipelineRunStatus.FAILURE_STATES:
+                            yield TriggerEvent(
+                                {
+                                    "status": "error",
+                                    "message": f"The pipeline run 
{self.run_id} has {pipeline_status}.",
+                                    "run_id": self.run_id,
+                                }
+                            )
+                            return
+                        elif pipeline_status == 
AzureDataFactoryPipelineRunStatus.SUCCEEDED:
+                            yield TriggerEvent(
+                                {
+                                    "status": "success",
+                                    "message": f"The pipeline run 
{self.run_id} has {pipeline_status}.",
+                                    "run_id": self.run_id,
+                                }
+                            )
+                            return
+                        self.log.info(
+                            "Sleeping for %s. The pipeline state is %s.", 
self.check_interval, pipeline_status
                         )
-                    self.log.info(
-                        "Sleeping for %s. The pipeline state is %s.", 
self.check_interval, pipeline_status
-                    )
-                    await asyncio.sleep(self.check_interval)
+                        await asyncio.sleep(self.check_interval)
+                    except ServiceRequestError:
+                        # conn might expire during long running pipeline.
+                        # If an exception is caught, it tries to refresh 
connection once.
+                        # If it still doesn't fix the issue,
+                        # then executed_after_token_refresh would still be 
False
+                        # and an exception will be raised
+                        if executed_after_token_refresh:
+                            await hook.refresh_conn()
+                            executed_after_token_refresh = False
+                            continue
+                        raise
 
                 yield TriggerEvent(
                     {
@@ -192,10 +227,13 @@ class AzureDataFactoryTrigger(BaseTrigger):
                 )
         except Exception as e:
             if self.run_id:
-                await hook.cancel_pipeline_run(
-                    run_id=self.run_id,
-                    resource_group_name=self.resource_group_name,
-                    factory_name=self.factory_name,
-                )
-                self.log.info("Unexpected error %s caught. Cancel pipeline run 
%s", str(e), self.run_id)
+                try:
+                    await hook.cancel_pipeline_run(
+                        run_id=self.run_id,
+                        resource_group_name=self.resource_group_name,
+                        factory_name=self.factory_name,
+                    )
+                    self.log.info("Unexpected error %s caught. Cancel pipeline 
run %s", str(e), self.run_id)
+                except Exception as err:
+                    yield TriggerEvent({"status": "error", "message": 
str(err), "run_id": self.run_id})
             yield TriggerEvent({"status": "error", "message": str(e), 
"run_id": self.run_id})
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py 
b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
index c7b256e9cb..57f7bc6178 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
@@ -720,6 +720,14 @@ def 
test_backcompat_prefix_both_prefers_short(mock_connect):
         
mock_connect.return_value.factories.delete.assert_called_with("non-prefixed", 
"n/a")
 
 
+def test_refresh_conn(hook):
+    """Test refresh_conn method _conn is reset and get_conn is called"""
+    with patch.object(hook, "get_conn") as mock_get_conn:
+        hook.refresh_conn()
+        assert not hook._conn
+        assert mock_get_conn.called
+
+
 class TestAzureDataFactoryAsyncHook:
     @pytest.mark.asyncio
     
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
@@ -780,16 +788,6 @@ class TestAzureDataFactoryAsyncHook:
         response = await hook.get_adf_pipeline_run_status(RUN_ID, 
RESOURCE_GROUP_NAME, DATAFACTORY_NAME)
         assert response == mock_status
 
-    @pytest.mark.asyncio
-    
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
-    
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_pipeline_run")
-    async def test_get_adf_pipeline_run_status_exception(self, 
mock_get_pipeline_run, mock_conn):
-        """Test get_adf_pipeline_run_status function with exception"""
-        mock_get_pipeline_run.side_effect = Exception("Test exception")
-        hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
-        with pytest.raises(AirflowException):
-            await hook.get_adf_pipeline_run_status(RUN_ID, 
RESOURCE_GROUP_NAME, DATAFACTORY_NAME)
-
     @pytest.mark.asyncio
     @mock.patch("azure.mgmt.datafactory.models._models_py3.PipelineRun")
     
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_connection")
@@ -810,15 +808,6 @@ class TestAzureDataFactoryAsyncHook:
         with pytest.raises(AirflowException):
             await hook.get_pipeline_run(RUN_ID, None, DATAFACTORY_NAME)
 
-    @pytest.mark.asyncio
-    
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
-    async def test_get_pipeline_run_exception(self, mock_conn):
-        """Test get_pipeline_run function with exception"""
-        mock_conn.return_value.pipeline_runs.get.side_effect = Exception("Test 
exception")
-        hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
-        with pytest.raises(AirflowException):
-            await hook.get_pipeline_run(RUN_ID, RESOURCE_GROUP_NAME, 
DATAFACTORY_NAME)
-
     @pytest.mark.asyncio
     
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_connection")
     async def test_get_async_conn(self, mock_connection):
@@ -958,3 +947,12 @@ class TestAzureDataFactoryAsyncHook:
         assert get_field(extras, "factory_name", strict=True) == 
DATAFACTORY_NAME
         with pytest.raises(KeyError):
             get_field(extras, "non-existent-field", strict=True)
+
+    @pytest.mark.asyncio
+    
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
+    async def test_refresh_conn(self, mock_get_async_conn):
+        """Test refresh_conn method _conn is reset and get_async_conn is 
called"""
+        hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
+        await hook.refresh_conn()
+        assert not hook._conn
+        assert mock_get_async_conn.called
diff --git 
a/tests/providers/microsoft/azure/triggers/test_azure_data_factory.py 
b/tests/providers/microsoft/azure/triggers/test_azure_data_factory.py
index 9df51bae71..fd1a206554 100644
--- a/tests/providers/microsoft/azure/triggers/test_azure_data_factory.py
+++ b/tests/providers/microsoft/azure/triggers/test_azure_data_factory.py
@@ -163,10 +163,14 @@ class TestADFPipelineRunStatusSensorTrigger:
         assert TriggerEvent({"status": "error", "message": mock_message}) == 
actual
 
     @pytest.mark.asyncio
+    
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.refresh_conn")
     
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_adf_pipeline_run_status")
-    async def test_adf_pipeline_run_status_sensors_trigger_exception(self, 
mock_data_factory):
+    async def test_adf_pipeline_run_status_sensors_trigger_exception(
+        self, mock_data_factory, mock_refresh_token
+    ):
        """Test ADF pipeline run status sensor trigger with an exception raised"""
         mock_data_factory.side_effect = Exception("Test exception")
+        mock_refresh_token.side_effect = Exception("Test exception")
 
         task = [i async for i in self.TRIGGER.run()]
         assert len(task) == 1

Reply via email to