This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 3d575fed54 Allow passing fully_qualified_namespace and credential to
initialize Azure Service Bus Client (#33493)
3d575fed54 is described below
commit 3d575fed540e7521976303cd763a20e090e65d9e
Author: Wei Lee <[email protected]>
AuthorDate: Sat Aug 26 15:20:30 2023 +0800
Allow passing fully_qualified_namespace and credential to initialize Azure
Service Bus Client (#33493)
* feat(providers/azure): allow passing fully_qualified_namespace and
credential to initialize Azure Service Bus Client
* style(providers/microsoft): ignore credential mypy warning
as the docstring at
https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_client.py#L56-L59
mentions, the client can accept a credential from azure-identity
* test(providers/microsoft): add test cases for checking whether asb hooks
fall back to DefaultAzureCredential when connection_string is not provided
---
airflow/providers/microsoft/azure/hooks/asb.py | 60 ++++++++++++++++++++--
airflow/providers/microsoft/azure/operators/asb.py | 7 ++-
tests/providers/microsoft/azure/hooks/test_asb.py | 32 ++++++++++++
3 files changed, 95 insertions(+), 4 deletions(-)
diff --git a/airflow/providers/microsoft/azure/hooks/asb.py
b/airflow/providers/microsoft/azure/hooks/asb.py
index 9eb921a656..c20944e160 100644
--- a/airflow/providers/microsoft/azure/hooks/asb.py
+++ b/airflow/providers/microsoft/azure/hooks/asb.py
@@ -18,10 +18,12 @@ from __future__ import annotations
from typing import Any
+from azure.identity import DefaultAzureCredential
from azure.servicebus import ServiceBusClient, ServiceBusMessage,
ServiceBusSender
from azure.servicebus.management import QueueProperties,
ServiceBusAdministrationClient
from airflow.hooks.base import BaseHook
+from airflow.providers.microsoft.azure.utils import get_field
class BaseAzureServiceBusHook(BaseHook):
@@ -37,6 +39,20 @@ class BaseAzureServiceBusHook(BaseHook):
conn_type = "azure_service_bus"
hook_name = "Azure Service Bus"
+ @staticmethod
+ def get_connection_form_widgets() -> dict[str, Any]:
+ """Returns connection widgets to add to connection form."""
+ from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
+ from flask_babel import lazy_gettext
+ from wtforms import PasswordField, StringField
+
+ return {
+ "fully_qualified_namespace": StringField(
+ lazy_gettext("Fully Qualified Namespace"),
widget=BS3TextFieldWidget()
+ ),
+ "credential": PasswordField(lazy_gettext("Credential"),
widget=BS3TextFieldWidget()),
+ }
+
@staticmethod
def get_ui_field_behaviour() -> dict[str, Any]:
"""Returns custom field behaviour."""
@@ -44,6 +60,10 @@ class BaseAzureServiceBusHook(BaseHook):
"hidden_fields": ["port", "host", "extra", "login", "password"],
"relabeling": {"schema": "Connection String"},
"placeholders": {
+ "fully_qualified_namespace": (
+ "<Resource group>.servicebus.windows.net (for Azure AD
authenticaltion)"
+ ),
+ "credential": "credential",
"schema": "Endpoint=sb://<Resource
group>.servicebus.windows.net/;SharedAccessKeyName=<AccessKeyName>;SharedAccessKey=<SharedAccessKey>",
# noqa
},
}
@@ -55,6 +75,14 @@ class BaseAzureServiceBusHook(BaseHook):
def get_conn(self):
raise NotImplementedError
+ def _get_field(self, extras: dict, field_name: str) -> str:
+ return get_field(
+ conn_id=self.conn_id,
+ conn_type=self.conn_type,
+ extras=extras,
+ field_name=field_name,
+ )
+
class AdminClientHook(BaseAzureServiceBusHook):
"""Interact with the ServiceBusAdministrationClient.
@@ -70,9 +98,21 @@ class AdminClientHook(BaseAzureServiceBusHook):
This uses the connection string in connection details.
"""
conn = self.get_connection(self.conn_id)
-
connection_string: str = str(conn.schema)
- return
ServiceBusAdministrationClient.from_connection_string(connection_string)
+ if connection_string:
+ client =
ServiceBusAdministrationClient.from_connection_string(connection_string)
+ else:
+ extras = conn.extra_dejson
+ credential: str | DefaultAzureCredential =
self._get_field(extras=extras, field_name="credential")
+ fully_qualified_namespace = self._get_field(extras=extras,
field_name="fully_qualified_namespace")
+ if not credential:
+ credential = DefaultAzureCredential()
+ client = ServiceBusAdministrationClient(
+ fully_qualified_namespace=fully_qualified_namespace,
+ credential=credential, # type: ignore[arg-type]
+ )
+ self.log.info("Create and returns ServiceBusAdministrationClient")
+ return client
def create_queue(
self,
@@ -143,9 +183,21 @@ class MessageHook(BaseAzureServiceBusHook):
"""Create and returns ServiceBusClient by using the connection string
in connection details."""
conn = self.get_connection(self.conn_id)
connection_string: str = str(conn.schema)
+ if connection_string:
+ client =
ServiceBusClient.from_connection_string(connection_string, logging_enable=True)
+ else:
+ extras = conn.extra_dejson
+ credential: str | DefaultAzureCredential =
self._get_field(extras=extras, field_name="credential")
+ fully_qualified_namespace = self._get_field(extras=extras,
field_name="fully_qualified_namespace")
+ if not credential:
+ credential = DefaultAzureCredential()
+ client = ServiceBusClient(
+ fully_qualified_namespace=fully_qualified_namespace,
+ credential=credential, # type: ignore[arg-type]
+ )
self.log.info("Create and returns ServiceBusClient")
- return
ServiceBusClient.from_connection_string(conn_str=connection_string,
logging_enable=True)
+ return client
def send_message(self, queue_name: str, messages: str | list[str],
batch_message_flag: bool = False):
"""Use ServiceBusClient Send to send message(s) to a Service Bus Queue.
@@ -249,3 +301,5 @@ class MessageHook(BaseAzureServiceBusHook):
for msg in received_msgs:
self.log.info(msg)
subscription_receiver.complete_message(msg)
+ self.log.info(msg)
+ subscription_receiver.complete_message(msg)
diff --git a/airflow/providers/microsoft/azure/operators/asb.py
b/airflow/providers/microsoft/azure/operators/asb.py
index d9a460b77d..39a6602ec2 100644
--- a/airflow/providers/microsoft/azure/operators/asb.py
+++ b/airflow/providers/microsoft/azure/operators/asb.py
@@ -19,6 +19,8 @@ from __future__ import annotations
import datetime
from typing import TYPE_CHECKING, Any, Sequence
+from azure.core.exceptions import ResourceNotFoundError
+
from airflow.models import BaseOperator
from airflow.providers.microsoft.azure.hooks.asb import AdminClientHook,
MessageHook
@@ -293,7 +295,10 @@ class AzureServiceBusTopicCreateOperator(BaseOperator):
hook =
AdminClientHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id)
with hook.get_conn() as service_mgmt_conn:
- topic_properties = service_mgmt_conn.get_topic(self.topic_name)
+ try:
+ topic_properties = service_mgmt_conn.get_topic(self.topic_name)
+ except ResourceNotFoundError:
+ topic_properties = None
if topic_properties and topic_properties.name == self.topic_name:
self.log.info("Topic name already exists")
return topic_properties.name
diff --git a/tests/providers/microsoft/azure/hooks/test_asb.py
b/tests/providers/microsoft/azure/hooks/test_asb.py
index 5f626d6c29..8364bc9dcb 100644
--- a/tests/providers/microsoft/azure/hooks/test_asb.py
+++ b/tests/providers/microsoft/azure/hooks/test_asb.py
@@ -49,11 +49,27 @@ class TestAdminClientHook:
schema=self.connection_string,
)
)
+ self.mock_conn_without_schema = Connection(
+ conn_id="azure_service_bus_default",
+ conn_type="azure_service_bus",
+ schema="",
+ extra={"fully_qualified_namespace": "fully_qualified_namespace"},
+ )
def test_get_conn(self):
hook = AdminClientHook(azure_service_bus_conn_id=self.conn_id)
assert isinstance(hook.get_conn(), ServiceBusAdministrationClient)
+
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.DefaultAzureCredential")
+
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.AdminClientHook.get_connection")
+ def
test_get_conn_fallback_to_default_azure_credential_when_schema_is_not_provided(
+ self, mock_connection, mock_default_azure_credential
+ ):
+ mock_connection.return_value = self.mock_conn_without_schema
+ hook = AdminClientHook(azure_service_bus_conn_id=self.conn_id)
+ assert isinstance(hook.get_conn(), ServiceBusAdministrationClient)
+ mock_default_azure_credential.assert_called_once()
+
@mock.patch("azure.servicebus.management.QueueProperties")
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.AdminClientHook.get_conn")
def test_create_queue(self, mock_sb_admin_client, mock_queue_properties):
@@ -140,6 +156,12 @@ class TestMessageHook:
schema=self.connection_string,
)
)
+ self.mock_conn_without_schema = Connection(
+ conn_id="azure_service_bus_default",
+ conn_type="azure_service_bus",
+ schema="",
+ extra={"fully_qualified_namespace": "fully_qualified_namespace"},
+ )
def test_get_service_bus_message_conn(self):
"""
@@ -149,6 +171,16 @@ class TestMessageHook:
hook = MessageHook(azure_service_bus_conn_id=self.conn_id)
assert isinstance(hook.get_conn(), ServiceBusClient)
+
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.DefaultAzureCredential")
+
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.MessageHook.get_connection")
+ def
test_get_conn_fallback_to_default_azure_credential_when_schema_is_not_provided(
+ self, mock_connection, mock_default_azure_credential
+ ):
+ mock_connection.return_value = self.mock_conn_without_schema
+ hook = MessageHook(azure_service_bus_conn_id=self.conn_id)
+ assert isinstance(hook.get_conn(), ServiceBusClient)
+ mock_default_azure_credential.assert_called_once()
+
@pytest.mark.parametrize(
"mock_message, mock_batch_flag",
[