phanikumv commented on code in PR #29801:
URL: https://github.com/apache/airflow/pull/29801#discussion_r1133483001
##########
airflow/providers/microsoft/azure/hooks/data_factory.py:
##########
@@ -1039,3 +1048,124 @@ def test_connection(self) -> tuple[bool, str]:
return success
except Exception as e:
return False, str(e)
+
+
def provide_targeted_factory_async(func: T) -> T:
    """
    Provide the targeted factory to the async decorated function in case it isn't specified.

    If ``resource_group_name`` or ``factory_name`` is not provided (or is ``None``), it defaults to the
    value specified in the connection extras, under either the plain key or the
    ``extra__azure_data_factory__``-prefixed key.

    :raises AirflowException: if a missing value cannot be found in the connection extras either.
    """
    signature = inspect.signature(func)

    @wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        bound_args = signature.bind(*args, **kwargs)
        # Connection extras are loaded lazily and at most once per call, so a call that omits
        # both arguments does not look up the same connection twice.
        extras: dict[str, Any] | None = None

        async def bind_argument(arg: str, default_key: str) -> None:
            nonlocal extras
            # Nothing to fill in when the caller supplied a non-None value for this argument.
            if arg in bound_args.arguments and bound_args.arguments[arg] is not None:
                return

            if extras is None:
                self = args[0]
                conn = await sync_to_async(self.get_connection)(self.conn_id)
                extras = conn.extra_dejson

            default_value = extras.get(default_key) or extras.get(
                f"extra__azure_data_factory__{default_key}"
            )
            if not default_value:
                raise AirflowException("Could not determine the targeted data factory.")

            bound_args.arguments[arg] = default_value

        await bind_argument("resource_group_name", "resource_group_name")
        await bind_argument("factory_name", "factory_name")

        return await func(*bound_args.args, **bound_args.kwargs)

    return cast(T, wrapper)
+
+
+class AzureDataFactoryAsyncHook(AzureDataFactoryHook):
+ """
+ An Async Hook that connects to Azure DataFactory to perform pipeline
operations
+
+ :param azure_data_factory_conn_id: The :ref:`Azure Data Factory connection
id<howto/connection:adf>`.
+ """
+
+ def __init__(self, azure_data_factory_conn_id: str):
+ self._async_conn: AsyncDataFactoryManagementClient = None
+ self.conn_id = azure_data_factory_conn_id
+ super().__init__(azure_data_factory_conn_id=azure_data_factory_conn_id)
+
+ async def get_async_conn(self) -> AsyncDataFactoryManagementClient:
+ """Get async connection and connect to azure data factory"""
+ if self._conn is not None:
Review Comment:
This is done in
[264e90d](https://github.com/apache/airflow/pull/29801/commits/264e90dd62f4f6be0498b7df0fb03594b5936153)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]