josh-fell commented on code in PR #29801:
URL: https://github.com/apache/airflow/pull/29801#discussion_r1124899383


##########
airflow/providers/microsoft/azure/hooks/data_factory.py:
##########
@@ -1039,3 +1048,120 @@ def test_connection(self) -> tuple[bool, str]:
             return success
         except Exception as e:
             return False, str(e)
+
+
+def provide_targeted_factory_async(func: T) -> T:
+    """
+    Provide the targeted factory to the async decorated function in case it 
isn't specified.
+
+    If ``resource_group_name`` or ``factory_name`` is not provided it defaults 
to the value specified in
+    the connection extras.
+    """
+    signature = inspect.signature(func)
+
+    @wraps(func)
+    async def wrapper(*args: Any, **kwargs: Any) -> Any:
+        bound_args = signature.bind(*args, **kwargs)
+
+        async def bind_argument(arg: Any, default_key: str) -> None:
+            # Check if arg was not included in the function signature or, if 
it is, the value is not provided.
+            if arg not in bound_args.arguments or bound_args.arguments[arg] is 
None:
+                self = args[0]
+                conn = await sync_to_async(self.get_connection)(self.conn_id)
+                default_value = conn.extra_dejson.get(default_key)

Review Comment:
   Can you update this to match the sync version? Prefixing connection extras 
with `extra__...` is no longer needed, but we still keep the check for 
backwards compatibility.



##########
airflow/providers/microsoft/azure/hooks/data_factory.py:
##########
@@ -1039,3 +1048,120 @@ def test_connection(self) -> tuple[bool, str]:
             return success
         except Exception as e:
             return False, str(e)
+
+
+def provide_targeted_factory_async(func: T) -> T:
+    """
+    Provide the targeted factory to the async decorated function in case it 
isn't specified.
+
+    If ``resource_group_name`` or ``factory_name`` is not provided it defaults 
to the value specified in
+    the connection extras.
+    """
+    signature = inspect.signature(func)
+
+    @wraps(func)
+    async def wrapper(*args: Any, **kwargs: Any) -> Any:
+        bound_args = signature.bind(*args, **kwargs)
+
+        async def bind_argument(arg: Any, default_key: str) -> None:
+            # Check if arg was not included in the function signature or, if 
it is, the value is not provided.
+            if arg not in bound_args.arguments or bound_args.arguments[arg] is 
None:
+                self = args[0]
+                conn = await sync_to_async(self.get_connection)(self.conn_id)
+                default_value = conn.extra_dejson.get(default_key)
+                if not default_value:
+                    raise AirflowException("Could not determine the targeted 
data factory.")
+
+                bound_args.arguments[arg] = conn.extra_dejson[default_key]
+
+        await bind_argument("resource_group_name", 
"extra__azure_data_factory__resource_group_name")
+        await bind_argument("factory_name", 
"extra__azure_data_factory__factory_name")
+
+        return await func(*bound_args.args, **bound_args.kwargs)
+
+    return cast(T, wrapper)
+
+
+class AzureDataFactoryAsyncHook(AzureDataFactoryHook):
+    """
+    An Async Hook that connects to Azure DataFactory to perform pipeline 
operations
+
+    :param azure_data_factory_conn_id: The :ref:`Azure Data Factory connection 
id<howto/connection:adf>`.
+    """
+
+    def __init__(self, azure_data_factory_conn_id: str):
+        self._async_conn: AsyncDataFactoryManagementClient = None
+        self.conn_id = azure_data_factory_conn_id
+        super().__init__(azure_data_factory_conn_id=azure_data_factory_conn_id)
+
+    async def get_async_conn(self) -> AsyncDataFactoryManagementClient:
+        """Get async connection and connect to azure data factory"""
+        if self._conn is not None:
+            return self._conn
+
+        conn = await sync_to_async(self.get_connection)(self.conn_id)
+        tenant = conn.extra_dejson.get("extra__azure_data_factory__tenantId")

Review Comment:
   Same idea here: this should use the `get_field()` function to retrieve the extras.



##########
tests/providers/microsoft/azure/hooks/test_azure_data_factory.py:
##########
@@ -708,3 +723,214 @@ def 
test_backcompat_prefix_both_prefers_short(mock_connect):
         hook = AzureDataFactoryHook("my_conn")
         hook.delete_factory(factory_name="n/a")
         
mock_connect.return_value.factories.delete.assert_called_with("non-prefixed", 
"n/a")
+
+
+class TestAzureDataFactoryAsyncHook:
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(

Review Comment:
   Why parametrize if each status is its own explicit test? Instead, you can 
inject test case `ids`, like `pytest.mark.parametrize(..., ids=["test_status1", 
"test_status2"])`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to