josh-fell commented on code in PR #29801:
URL: https://github.com/apache/airflow/pull/29801#discussion_r1132491038


##########
airflow/providers/microsoft/azure/hooks/data_factory.py:
##########
@@ -1039,3 +1048,124 @@ def test_connection(self) -> tuple[bool, str]:
             return success
         except Exception as e:
             return False, str(e)
+
+
+def provide_targeted_factory_async(func: T) -> T:
+    """
+    Provide the targeted factory to the async decorated function in case it 
isn't specified.
+
+    If ``resource_group_name`` or ``factory_name`` is not provided it defaults 
to the value specified in
+    the connection extras.
+    """
+    signature = inspect.signature(func)
+
+    @wraps(func)
+    async def wrapper(*args: Any, **kwargs: Any) -> Any:
+        bound_args = signature.bind(*args, **kwargs)
+
+        async def bind_argument(arg: Any, default_key: str) -> None:
+            # Check if arg was not included in the function signature or, if 
it is, the value is not provided.
+            if arg not in bound_args.arguments or bound_args.arguments[arg] is 
None:
+                self = args[0]
+                conn = await sync_to_async(self.get_connection)(self.conn_id)
+                extras = conn.extra_dejson
+                default_value = extras.get(default_key) or extras.get(
+                    f"extra__azure_data_factory__{default_key}"
+                )
+                if not default_value:
+                    raise AirflowException("Could not determine the targeted 
data factory.")
+
+                bound_args.arguments[arg] = default_value
+
+        await bind_argument("resource_group_name", "resource_group_name")
+        await bind_argument("factory_name", "factory_name")
+
+        return await func(*bound_args.args, **bound_args.kwargs)
+
+    return cast(T, wrapper)
+
+
+class AzureDataFactoryAsyncHook(AzureDataFactoryHook):
+    """
+    An Async Hook that connects to Azure DataFactory to perform pipeline 
operations
+
+    :param azure_data_factory_conn_id: The :ref:`Azure Data Factory connection 
id<howto/connection:adf>`.
+    """
+
+    def __init__(self, azure_data_factory_conn_id: str):

Review Comment:
   WDYT about making `azure_data_factory_conn_id` optional like it is in 
AzureDataFactoryHook?



##########
airflow/providers/microsoft/azure/hooks/data_factory.py:
##########
@@ -1039,3 +1048,124 @@ def test_connection(self) -> tuple[bool, str]:
             return success
         except Exception as e:
             return False, str(e)
+
+
+def provide_targeted_factory_async(func: T) -> T:
+    """
+    Provide the targeted factory to the async decorated function in case it 
isn't specified.
+
+    If ``resource_group_name`` or ``factory_name`` is not provided it defaults 
to the value specified in
+    the connection extras.
+    """
+    signature = inspect.signature(func)
+
+    @wraps(func)
+    async def wrapper(*args: Any, **kwargs: Any) -> Any:
+        bound_args = signature.bind(*args, **kwargs)
+
+        async def bind_argument(arg: Any, default_key: str) -> None:
+            # Check if arg was not included in the function signature or, if 
it is, the value is not provided.
+            if arg not in bound_args.arguments or bound_args.arguments[arg] is 
None:
+                self = args[0]
+                conn = await sync_to_async(self.get_connection)(self.conn_id)
+                extras = conn.extra_dejson
+                default_value = extras.get(default_key) or extras.get(
+                    f"extra__azure_data_factory__{default_key}"
+                )
+                if not default_value:
+                    raise AirflowException("Could not determine the targeted 
data factory.")
+
+                bound_args.arguments[arg] = default_value
+
+        await bind_argument("resource_group_name", "resource_group_name")
+        await bind_argument("factory_name", "factory_name")
+
+        return await func(*bound_args.args, **bound_args.kwargs)
+
+    return cast(T, wrapper)
+
+
+class AzureDataFactoryAsyncHook(AzureDataFactoryHook):
+    """
+    An Async Hook that connects to Azure DataFactory to perform pipeline 
operations
+
+    :param azure_data_factory_conn_id: The :ref:`Azure Data Factory connection 
id<howto/connection:adf>`.
+    """
+
+    def __init__(self, azure_data_factory_conn_id: str):
+        self._async_conn: AsyncDataFactoryManagementClient = None
+        self.conn_id = azure_data_factory_conn_id
+        super().__init__(azure_data_factory_conn_id=azure_data_factory_conn_id)
+
+    async def get_async_conn(self) -> AsyncDataFactoryManagementClient:
+        """Get async connection and connect to azure data factory"""
+        if self._conn is not None:

Review Comment:
   I think this function should use `self._async_conn` here instead of `self._conn`?



##########
airflow/providers/microsoft/azure/hooks/data_factory.py:
##########
@@ -1039,3 +1048,124 @@ def test_connection(self) -> tuple[bool, str]:
             return success
         except Exception as e:
             return False, str(e)
+
+
+def provide_targeted_factory_async(func: T) -> T:
+    """
+    Provide the targeted factory to the async decorated function in case it 
isn't specified.
+
+    If ``resource_group_name`` or ``factory_name`` is not provided it defaults 
to the value specified in
+    the connection extras.
+    """
+    signature = inspect.signature(func)
+
+    @wraps(func)
+    async def wrapper(*args: Any, **kwargs: Any) -> Any:
+        bound_args = signature.bind(*args, **kwargs)
+
+        async def bind_argument(arg: Any, default_key: str) -> None:
+            # Check if arg was not included in the function signature or, if 
it is, the value is not provided.
+            if arg not in bound_args.arguments or bound_args.arguments[arg] is 
None:
+                self = args[0]
+                conn = await sync_to_async(self.get_connection)(self.conn_id)
+                extras = conn.extra_dejson
+                default_value = extras.get(default_key) or extras.get(
+                    f"extra__azure_data_factory__{default_key}"
+                )
+                if not default_value:
+                    raise AirflowException("Could not determine the targeted 
data factory.")
+
+                bound_args.arguments[arg] = default_value
+
+        await bind_argument("resource_group_name", "resource_group_name")
+        await bind_argument("factory_name", "factory_name")
+
+        return await func(*bound_args.args, **bound_args.kwargs)
+
+    return cast(T, wrapper)
+
+
+class AzureDataFactoryAsyncHook(AzureDataFactoryHook):
+    """
+    An Async Hook that connects to Azure DataFactory to perform pipeline 
operations
+
+    :param azure_data_factory_conn_id: The :ref:`Azure Data Factory connection 
id<howto/connection:adf>`.
+    """
+
+    def __init__(self, azure_data_factory_conn_id: str):
+        self._async_conn: AsyncDataFactoryManagementClient = None
+        self.conn_id = azure_data_factory_conn_id
+        super().__init__(azure_data_factory_conn_id=azure_data_factory_conn_id)
+
+    async def get_async_conn(self) -> AsyncDataFactoryManagementClient:
+        """Get async connection and connect to azure data factory"""
+        if self._conn is not None:
+            return self._conn
+
+        conn = await sync_to_async(self.get_connection)(self.conn_id)
+        extras = conn.extra_dejson
+        tenant = get_field(extras, "tenantId")
+
+        try:
+            subscription_id = get_field(extras, "subscriptionId", strict=True)
+        except KeyError:
+            raise ValueError("A Subscription ID is required to connect to 
Azure Data Factory.")
+
+        credential: AsyncCredentials
+        if conn.login is not None and conn.password is not None:
+            if not tenant:
+                raise ValueError("A Tenant ID is required when authenticating 
with Client ID and Secret.")
+
+            credential = AsyncClientSecretCredential(
+                client_id=conn.login, client_secret=conn.password, 
tenant_id=tenant
+            )
+        else:
+            credential = AsyncDefaultAzureCredential()
+
+        return AsyncDataFactoryManagementClient(
+            credential=credential,
+            subscription_id=subscription_id,
+        )
+
+    @provide_targeted_factory_async
+    async def get_pipeline_run(
+        self,
+        run_id: str,
+        resource_group_name: str | None = None,
+        factory_name: str | None = None,
+        **config: Any,
+    ) -> PipelineRun:
+        """
+        Connect to Azure Data Factory asynchronously to get the pipeline run 
details by run id
+
+        :param run_id: The pipeline run identifier.
+        :param resource_group_name: The resource group name.
+        :param factory_name: The factory name.

Review Comment:
   ```suggestion
           :param factory_name: The factory name.
           :param config: Extra parameters for the ADF client.
   ```
   To be consistent with the other sync functions and to let users know they 
can pass in additional configuration should they choose to use this function 
directly.



##########
airflow/providers/microsoft/azure/hooks/data_factory.py:
##########
@@ -1039,3 +1048,124 @@ def test_connection(self) -> tuple[bool, str]:
             return success
         except Exception as e:
             return False, str(e)
+
+
+def provide_targeted_factory_async(func: T) -> T:
+    """
+    Provide the targeted factory to the async decorated function in case it 
isn't specified.
+
+    If ``resource_group_name`` or ``factory_name`` is not provided it defaults 
to the value specified in
+    the connection extras.
+    """
+    signature = inspect.signature(func)
+
+    @wraps(func)
+    async def wrapper(*args: Any, **kwargs: Any) -> Any:
+        bound_args = signature.bind(*args, **kwargs)
+
+        async def bind_argument(arg: Any, default_key: str) -> None:
+            # Check if arg was not included in the function signature or, if 
it is, the value is not provided.
+            if arg not in bound_args.arguments or bound_args.arguments[arg] is 
None:
+                self = args[0]
+                conn = await sync_to_async(self.get_connection)(self.conn_id)
+                extras = conn.extra_dejson
+                default_value = extras.get(default_key) or extras.get(
+                    f"extra__azure_data_factory__{default_key}"
+                )
+                if not default_value:
+                    raise AirflowException("Could not determine the targeted 
data factory.")
+
+                bound_args.arguments[arg] = default_value
+
+        await bind_argument("resource_group_name", "resource_group_name")
+        await bind_argument("factory_name", "factory_name")
+
+        return await func(*bound_args.args, **bound_args.kwargs)
+
+    return cast(T, wrapper)
+
+
+class AzureDataFactoryAsyncHook(AzureDataFactoryHook):
+    """
+    An Async Hook that connects to Azure DataFactory to perform pipeline 
operations
+
+    :param azure_data_factory_conn_id: The :ref:`Azure Data Factory connection 
id<howto/connection:adf>`.
+    """
+
+    def __init__(self, azure_data_factory_conn_id: str):
+        self._async_conn: AsyncDataFactoryManagementClient = None
+        self.conn_id = azure_data_factory_conn_id
+        super().__init__(azure_data_factory_conn_id=azure_data_factory_conn_id)
+
+    async def get_async_conn(self) -> AsyncDataFactoryManagementClient:
+        """Get async connection and connect to azure data factory"""
+        if self._conn is not None:

Review Comment:
   And `self._async_conn` should be returned too so the initial `is not None` 
check acts as a form of "caching".



##########
airflow/providers/microsoft/azure/sensors/data_factory.py:
##########
@@ -78,3 +83,52 @@ def poke(self, context: Context) -> bool:
             raise AzureDataFactoryPipelineRunException(f"Pipeline run 
{self.run_id} has been cancelled.")
 
         return pipeline_run_status == 
AzureDataFactoryPipelineRunStatus.SUCCEEDED
+
+
+class 
AzureDataFactoryPipelineRunStatusAsyncSensor(AzureDataFactoryPipelineRunStatusSensor):
+    """
+    Checks the status of a pipeline run asynchronously.
+
+    :param azure_data_factory_conn_id: The connection identifier for 
connecting to Azure Data Factory.
+    :param run_id: The pipeline run identifier.
+    :param resource_group_name: The resource group name.
+    :param factory_name: The data factory name.
+    :param poll_interval: polling period in seconds to check for the status

Review Comment:
   ```suggestion
       :param poke_interval: polling period in seconds to check for the status
   ```



##########
tests/providers/microsoft/azure/hooks/test_azure_data_factory.py:
##########
@@ -708,3 +724,245 @@ def 
test_backcompat_prefix_both_prefers_short(mock_connect):
         hook = AzureDataFactoryHook("my_conn")
         hook.delete_factory(factory_name="n/a")
         
mock_connect.return_value.factories.delete.assert_called_with("non-prefixed", 
"n/a")
+
+
+class TestAzureDataFactoryAsyncHook:
+    @pytest.mark.asyncio
+    
@async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
+    
@async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_pipeline_run")
+    async def test_get_adf_pipeline_run_status_queued(self, 
mock_get_pipeline_run, mock_conn):
+        """Test get_adf_pipeline_run_status function with mocked status"""
+        mock_status = "Queued"
+        mock_get_pipeline_run.return_value.status = mock_status
+        hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
+        response = await hook.get_adf_pipeline_run_status(RUN_ID, 
RESOURCE_GROUP_NAME, DATAFACTORY_NAME)
+        assert response == mock_status
+
+    @pytest.mark.asyncio
+    
@async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
+    
@async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_pipeline_run")
+    async def test_get_adf_pipeline_run_status_inprogress(
+        self,
+        mock_get_pipeline_run,
+        mock_conn,
+    ):
+        """Test get_adf_pipeline_run_status function with mocked status"""
+        mock_status = "InProgress"
+        mock_get_pipeline_run.return_value.status = mock_status
+        hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
+        response = await hook.get_adf_pipeline_run_status(RUN_ID, 
RESOURCE_GROUP_NAME, DATAFACTORY_NAME)
+        assert response == mock_status
+
+    @pytest.mark.asyncio
+    
@async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
+    
@async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_pipeline_run")
+    async def test_get_adf_pipeline_run_status_success(self, 
mock_get_pipeline_run, mock_conn):
+        """Test get_adf_pipeline_run_status function with mocked status"""
+        mock_status = "Succeeded"
+        mock_get_pipeline_run.return_value.status = mock_status
+        hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
+        response = await hook.get_adf_pipeline_run_status(RUN_ID, 
RESOURCE_GROUP_NAME, DATAFACTORY_NAME)
+        assert response == mock_status
+
+    @pytest.mark.asyncio
+    
@async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
+    
@async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_pipeline_run")
+    async def test_get_adf_pipeline_run_status_failed(self, 
mock_get_pipeline_run, mock_conn):
+        """Test get_adf_pipeline_run_status function with mocked status"""
+        mock_status = "Failed"
+        mock_get_pipeline_run.return_value.status = mock_status
+        hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
+        response = await hook.get_adf_pipeline_run_status(RUN_ID, 
RESOURCE_GROUP_NAME, DATAFACTORY_NAME)
+        assert response == mock_status
+
+    @pytest.mark.asyncio
+    
@async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
+    
@async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_pipeline_run")
+    async def test_get_adf_pipeline_run_status_cancelled(self, 
mock_get_pipeline_run, mock_conn):
+        """Test get_adf_pipeline_run_status function with mocked status"""
+        mock_status = "Cancelled"
+        mock_get_pipeline_run.return_value.status = mock_status
+        hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
+        response = await hook.get_adf_pipeline_run_status(RUN_ID, 
RESOURCE_GROUP_NAME, DATAFACTORY_NAME)
+        assert response == mock_status
+
+    @pytest.mark.asyncio
+    
@async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
+    
@async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_pipeline_run")
+    async def test_get_adf_pipeline_run_status_exception(self, 
mock_get_pipeline_run, mock_conn):
+        """Test get_adf_pipeline_run_status function with exception"""
+        mock_get_pipeline_run.side_effect = Exception("Test exception")
+        hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
+        with pytest.raises(AirflowException):
+            await hook.get_adf_pipeline_run_status(RUN_ID, 
RESOURCE_GROUP_NAME, DATAFACTORY_NAME)
+
+    @pytest.mark.asyncio
+    @async_mock.patch("azure.mgmt.datafactory.models._models_py3.PipelineRun")
+    
@async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_connection")
+    
@async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
+    async def test_get_pipeline_run_exception_without_resource(
+        self, mock_conn, mock_get_connection, mock_pipeline_run
+    ):
+        """
+        Test get_pipeline_run function without passing the resource name to 
check the decorator function and
+        raise exception
+        """
+        mock_connection = Connection(
+            extra=json.dumps({"extra__azure_data_factory__factory_name": 
DATAFACTORY_NAME})
+        )
+        mock_get_connection.return_value = mock_connection
+        
mock_conn.return_value.__aenter__.return_value.pipeline_runs.get.return_value = 
mock_pipeline_run
+        hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
+        with pytest.raises(AirflowException):
+            await hook.get_pipeline_run(RUN_ID, None, DATAFACTORY_NAME)
+
+    @pytest.mark.asyncio
+    
@async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
+    async def test_get_pipeline_run_exception(self, mock_conn):
+        """Test get_pipeline_run function with exception"""
+        
mock_conn.return_value.__aenter__.return_value.pipeline_runs.get.side_effect = 
Exception(
+            "Test exception"
+        )
+        hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
+        with pytest.raises(AirflowException):
+            await hook.get_pipeline_run(RUN_ID, RESOURCE_GROUP_NAME, 
DATAFACTORY_NAME)
+
+    @pytest.mark.asyncio
+    
@async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_connection")
+    async def test_get_async_conn(self, mock_connection):
+        """"""
+        mock_conn = Connection(
+            conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
+            conn_type="azure_data_factory",
+            login="clientId",
+            password="clientSecret",
+            extra=json.dumps(
+                {
+                    "extra__azure_data_factory__tenantId": "tenantId",
+                    "extra__azure_data_factory__subscriptionId": 
"subscriptionId",
+                    "extra__azure_data_factory__resource_group_name": 
RESOURCE_GROUP_NAME,
+                    "extra__azure_data_factory__factory_name": 
DATAFACTORY_NAME,
+                }
+            ),
+        )
+        mock_connection.return_value = mock_conn
+        hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
+        response = await hook.get_async_conn()
+        assert isinstance(response, DataFactoryManagementClient)
+
+    @pytest.mark.asyncio
+    
@async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_connection")
+    async def test_get_async_conn_without_login_id(self, mock_connection):
+        """Test get_async_conn function without login id"""
+        mock_conn = Connection(
+            conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
+            conn_type="azure_data_factory",
+            extra=json.dumps(
+                {
+                    "extra__azure_data_factory__tenantId": "tenantId",
+                    "extra__azure_data_factory__subscriptionId": 
"subscriptionId",
+                    "extra__azure_data_factory__resource_group_name": 
RESOURCE_GROUP_NAME,
+                    "extra__azure_data_factory__factory_name": 
DATAFACTORY_NAME,
+                }
+            ),
+        )
+        mock_connection.return_value = mock_conn
+        hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
+        response = await hook.get_async_conn()
+        assert isinstance(response, DataFactoryManagementClient)
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "mock_connection_params",
+        [
+            {
+                "extra__azure_data_factory__tenantId": "tenantId",
+                "extra__azure_data_factory__resource_group_name": 
RESOURCE_GROUP_NAME,
+                "extra__azure_data_factory__factory_name": DATAFACTORY_NAME,
+            }
+        ],
+    )
+    
@async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_connection")
+    async def test_get_async_conn_key_error_tenantId(self, mock_connection, 
mock_connection_params):
+        """Test get_async_conn function with raising key error"""

Review Comment:
   Would you mind updating the docstrings for this function and 
`test_get_async_conn_key_error_subscriptionId`? The two docstrings are currently identical.
   
   Also, I think the function names should be swapped: 
`test_get_async_conn_key_error_tenantId` actually tests that a missing subscription ID 
raises an error, while `test_get_async_conn_key_error_subscriptionId` tests that a tenant 
ID is required when authenticating with login/password.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to