phanikumv commented on code in PR #29801:
URL: https://github.com/apache/airflow/pull/29801#discussion_r1132621981
##########
airflow/providers/microsoft/azure/hooks/data_factory.py:
##########
@@ -1039,3 +1048,124 @@ def test_connection(self) -> tuple[bool, str]:
return success
except Exception as e:
return False, str(e)
+
+
def provide_targeted_factory_async(func: T) -> T:
    """
    Provide the targeted factory to the async decorated function in case it isn't specified.

    If ``resource_group_name`` or ``factory_name`` is not provided (absent or ``None``) it defaults
    to the value specified in the connection extras, looked up under the plain key or under the
    ``extra__azure_data_factory__``-prefixed key.

    The decorated coroutine's first positional argument must be the hook instance; its
    ``get_connection`` method and ``conn_id`` attribute are used to read the extras.

    :param func: The async hook method to wrap.
    :return: The wrapped coroutine function.
    :raises AirflowException: If a missing argument cannot be resolved from the connection extras.
    """
    signature = inspect.signature(func)

    @wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        bound_args = signature.bind(*args, **kwargs)
        # Extras are loaded lazily and at most once per call. A call that supplies both
        # arguments does no connection lookup, and a call missing both does only one.
        extras: dict[str, Any] | None = None

        async def bind_argument(arg: str, default_key: str) -> None:
            nonlocal extras
            # Check if arg was not included in the function signature or, if it is, the value is not provided.
            if arg not in bound_args.arguments or bound_args.arguments[arg] is None:
                if extras is None:
                    hook = args[0]
                    conn = await sync_to_async(hook.get_connection)(hook.conn_id)
                    extras = conn.extra_dejson
                default_value = extras.get(default_key) or extras.get(
                    f"extra__azure_data_factory__{default_key}"
                )
                if not default_value:
                    raise AirflowException("Could not determine the targeted data factory.")

                bound_args.arguments[arg] = default_value

        await bind_argument("resource_group_name", "resource_group_name")
        await bind_argument("factory_name", "factory_name")

        return await func(*bound_args.args, **bound_args.kwargs)

    return cast(T, wrapper)
+
+
+class AzureDataFactoryAsyncHook(AzureDataFactoryHook):
+ """
+ An Async Hook that connects to Azure DataFactory to perform pipeline
operations
+
+ :param azure_data_factory_conn_id: The :ref:`Azure Data Factory connection
id<howto/connection:adf>`.
+ """
+
+ def __init__(self, azure_data_factory_conn_id: str):
+ self._async_conn: AsyncDataFactoryManagementClient = None
+ self.conn_id = azure_data_factory_conn_id
+ super().__init__(azure_data_factory_conn_id=azure_data_factory_conn_id)
+
+ async def get_async_conn(self) -> AsyncDataFactoryManagementClient:
+ """Get async connection and connect to azure data factory"""
+ if self._conn is not None:
+ return self._conn
+
+ conn = await sync_to_async(self.get_connection)(self.conn_id)
+ extras = conn.extra_dejson
+ tenant = get_field(extras, "tenantId")
+
+ try:
+ subscription_id = get_field(extras, "subscriptionId", strict=True)
+ except KeyError:
+ raise ValueError("A Subscription ID is required to connect to
Azure Data Factory.")
+
+ credential: AsyncCredentials
+ if conn.login is not None and conn.password is not None:
+ if not tenant:
+ raise ValueError("A Tenant ID is required when authenticating
with Client ID and Secret.")
+
+ credential = AsyncClientSecretCredential(
+ client_id=conn.login, client_secret=conn.password,
tenant_id=tenant
+ )
+ else:
+ credential = AsyncDefaultAzureCredential()
+
+ return AsyncDataFactoryManagementClient(
+ credential=credential,
+ subscription_id=subscription_id,
+ )
+
+ @provide_targeted_factory_async
+ async def get_pipeline_run(
+ self,
+ run_id: str,
+ resource_group_name: str | None = None,
+ factory_name: str | None = None,
+ **config: Any,
+ ) -> PipelineRun:
+ """
+ Connect to Azure Data Factory asynchronously to get the pipeline run
details by run id
+
+ :param run_id: The pipeline run identifier.
+ :param resource_group_name: The resource group name.
+ :param factory_name: The factory name.
Review Comment:
done
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]