phanikumv commented on code in PR #29801:
URL: https://github.com/apache/airflow/pull/29801#discussion_r1132621625
##########
tests/providers/microsoft/azure/hooks/test_azure_data_factory.py:
##########
@@ -708,3 +724,245 @@ def test_backcompat_prefix_both_prefers_short(mock_connect):
hook = AzureDataFactoryHook("my_conn")
hook.delete_factory(factory_name="n/a")
mock_connect.return_value.factories.delete.assert_called_with("non-prefixed", "n/a")
+
+
+class TestAzureDataFactoryAsyncHook:
+ @pytest.mark.asyncio
+ @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
+ @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_pipeline_run")
+ async def test_get_adf_pipeline_run_status_queued(self, mock_get_pipeline_run, mock_conn):
+ """Test get_adf_pipeline_run_status function with mocked status"""
+ mock_status = "Queued"
+ mock_get_pipeline_run.return_value.status = mock_status
+ hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
+ response = await hook.get_adf_pipeline_run_status(RUN_ID, RESOURCE_GROUP_NAME, DATAFACTORY_NAME)
+ assert response == mock_status
+
+ @pytest.mark.asyncio
+ @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
+ @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_pipeline_run")
+ async def test_get_adf_pipeline_run_status_inprogress(
+ self,
+ mock_get_pipeline_run,
+ mock_conn,
+ ):
+ """Test get_adf_pipeline_run_status function with mocked status"""
+ mock_status = "InProgress"
+ mock_get_pipeline_run.return_value.status = mock_status
+ hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
+ response = await hook.get_adf_pipeline_run_status(RUN_ID, RESOURCE_GROUP_NAME, DATAFACTORY_NAME)
+ assert response == mock_status
+
+ @pytest.mark.asyncio
+ @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
+ @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_pipeline_run")
+ async def test_get_adf_pipeline_run_status_success(self, mock_get_pipeline_run, mock_conn):
+ """Test get_adf_pipeline_run_status function with mocked status"""
+ mock_status = "Succeeded"
+ mock_get_pipeline_run.return_value.status = mock_status
+ hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
+ response = await hook.get_adf_pipeline_run_status(RUN_ID, RESOURCE_GROUP_NAME, DATAFACTORY_NAME)
+ assert response == mock_status
+
+ @pytest.mark.asyncio
+ @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
+ @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_pipeline_run")
+ async def test_get_adf_pipeline_run_status_failed(self, mock_get_pipeline_run, mock_conn):
+ """Test get_adf_pipeline_run_status function with mocked status"""
+ mock_status = "Failed"
+ mock_get_pipeline_run.return_value.status = mock_status
+ hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
+ response = await hook.get_adf_pipeline_run_status(RUN_ID, RESOURCE_GROUP_NAME, DATAFACTORY_NAME)
+ assert response == mock_status
+
+ @pytest.mark.asyncio
+ @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
+ @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_pipeline_run")
+ async def test_get_adf_pipeline_run_status_cancelled(self, mock_get_pipeline_run, mock_conn):
+ """Test get_adf_pipeline_run_status function with mocked status"""
+ mock_status = "Cancelled"
+ mock_get_pipeline_run.return_value.status = mock_status
+ hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
+ response = await hook.get_adf_pipeline_run_status(RUN_ID, RESOURCE_GROUP_NAME, DATAFACTORY_NAME)
+ assert response == mock_status
+
+ @pytest.mark.asyncio
+ @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
+ @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_pipeline_run")
+ async def test_get_adf_pipeline_run_status_exception(self, mock_get_pipeline_run, mock_conn):
+ """Test get_adf_pipeline_run_status function with exception"""
+ mock_get_pipeline_run.side_effect = Exception("Test exception")
+ hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
+ with pytest.raises(AirflowException):
+ await hook.get_adf_pipeline_run_status(RUN_ID, RESOURCE_GROUP_NAME, DATAFACTORY_NAME)
+
+ @pytest.mark.asyncio
+ @async_mock.patch("azure.mgmt.datafactory.models._models_py3.PipelineRun")
+ @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_connection")
+ @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
+ async def test_get_pipeline_run_exception_without_resource(
+ self, mock_conn, mock_get_connection, mock_pipeline_run
+ ):
+ """
+ Test get_pipeline_run function without passing the resource name to check the decorator function and
+ raise exception
+ """
+ mock_connection = Connection(
+ extra=json.dumps({"extra__azure_data_factory__factory_name": DATAFACTORY_NAME})
+ )
+ mock_get_connection.return_value = mock_connection
+ mock_conn.return_value.__aenter__.return_value.pipeline_runs.get.return_value = mock_pipeline_run
+ hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
+ with pytest.raises(AirflowException):
+ await hook.get_pipeline_run(RUN_ID, None, DATAFACTORY_NAME)
+
+ @pytest.mark.asyncio
+ @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
+ async def test_get_pipeline_run_exception(self, mock_conn):
+ """Test get_pipeline_run function with exception"""
+ mock_conn.return_value.__aenter__.return_value.pipeline_runs.get.side_effect = Exception(
+ "Test exception"
+ )
+ hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
+ with pytest.raises(AirflowException):
+ await hook.get_pipeline_run(RUN_ID, RESOURCE_GROUP_NAME, DATAFACTORY_NAME)
+
+ @pytest.mark.asyncio
+ @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_connection")
+ async def test_get_async_conn(self, mock_connection):
+ """Test get_async_conn function with login id and password"""
+ mock_conn = Connection(
+ conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
+ conn_type="azure_data_factory",
+ login="clientId",
+ password="clientSecret",
+ extra=json.dumps(
+ {
+ "extra__azure_data_factory__tenantId": "tenantId",
+ "extra__azure_data_factory__subscriptionId": "subscriptionId",
+ "extra__azure_data_factory__resource_group_name": RESOURCE_GROUP_NAME,
+ "extra__azure_data_factory__factory_name": DATAFACTORY_NAME,
+ }
+ ),
+ )
+ mock_connection.return_value = mock_conn
+ hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
+ response = await hook.get_async_conn()
+ assert isinstance(response, DataFactoryManagementClient)
+
+ @pytest.mark.asyncio
+ @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_connection")
+ async def test_get_async_conn_without_login_id(self, mock_connection):
+ """Test get_async_conn function without login id"""
+ mock_conn = Connection(
+ conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
+ conn_type="azure_data_factory",
+ extra=json.dumps(
+ {
+ "extra__azure_data_factory__tenantId": "tenantId",
+ "extra__azure_data_factory__subscriptionId": "subscriptionId",
+ "extra__azure_data_factory__resource_group_name": RESOURCE_GROUP_NAME,
+ "extra__azure_data_factory__factory_name": DATAFACTORY_NAME,
+ }
+ ),
+ )
+ mock_connection.return_value = mock_conn
+ hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
+ response = await hook.get_async_conn()
+ assert isinstance(response, DataFactoryManagementClient)
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "mock_connection_params",
+ [
+ {
+ "extra__azure_data_factory__tenantId": "tenantId",
+ "extra__azure_data_factory__resource_group_name": RESOURCE_GROUP_NAME,
+ "extra__azure_data_factory__factory_name": DATAFACTORY_NAME,
+ }
+ ],
+ )
+ @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_connection")
+ async def test_get_async_conn_key_error_tenantId(self, mock_connection, mock_connection_params):
+ """Test get_async_conn function with raising key error"""
Review Comment:
done
##########
airflow/providers/microsoft/azure/sensors/data_factory.py:
##########
@@ -78,3 +83,52 @@ def poke(self, context: Context) -> bool:
raise AzureDataFactoryPipelineRunException(f"Pipeline run {self.run_id} has been cancelled.")
return pipeline_run_status == AzureDataFactoryPipelineRunStatus.SUCCEEDED
+
+
+class AzureDataFactoryPipelineRunStatusAsyncSensor(AzureDataFactoryPipelineRunStatusSensor):
+ """
+ Checks the status of a pipeline run asynchronously.
+
+ :param azure_data_factory_conn_id: The connection identifier for connecting to Azure Data Factory.
+ :param run_id: The pipeline run identifier.
+ :param resource_group_name: The resource group name.
+ :param factory_name: The data factory name.
+ :param poll_interval: polling period in seconds to check for the status
Review Comment:
fixed
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]