This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3d575fed54 Allow passing fully_qualified_namespace and credential to initialize Azure Service Bus Client (#33493)
3d575fed54 is described below

commit 3d575fed540e7521976303cd763a20e090e65d9e
Author: Wei Lee <[email protected]>
AuthorDate: Sat Aug 26 15:20:30 2023 +0800

    Allow passing fully_qualified_namespace and credential to initialize Azure Service Bus Client (#33493)
    
    * feat(providers/azure): allow passing fully_qualified_namespace and credential to initialize Azure Service Bus Client
    
    * style(providers/microsoft): ignore credential mypy warning
    
    As the docstring at https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_client.py#L56-L59 mentions, the client can accept a credential from azure-identity, so the mypy arg-type warning on `credential` is ignored.
    
    * test(providers/microsoft): add test cases checking whether the asb hooks fall back to DefaultAzureCredential when connection_string is not provided
---
 airflow/providers/microsoft/azure/hooks/asb.py     | 60 ++++++++++++++++++++--
 airflow/providers/microsoft/azure/operators/asb.py |  7 ++-
 tests/providers/microsoft/azure/hooks/test_asb.py  | 32 ++++++++++++
 3 files changed, 95 insertions(+), 4 deletions(-)

diff --git a/airflow/providers/microsoft/azure/hooks/asb.py b/airflow/providers/microsoft/azure/hooks/asb.py
index 9eb921a656..c20944e160 100644
--- a/airflow/providers/microsoft/azure/hooks/asb.py
+++ b/airflow/providers/microsoft/azure/hooks/asb.py
@@ -18,10 +18,12 @@ from __future__ import annotations
 
 from typing import Any
 
+from azure.identity import DefaultAzureCredential
 from azure.servicebus import ServiceBusClient, ServiceBusMessage, 
ServiceBusSender
 from azure.servicebus.management import QueueProperties, 
ServiceBusAdministrationClient
 
 from airflow.hooks.base import BaseHook
+from airflow.providers.microsoft.azure.utils import get_field
 
 
 class BaseAzureServiceBusHook(BaseHook):
@@ -37,6 +39,20 @@ class BaseAzureServiceBusHook(BaseHook):
     conn_type = "azure_service_bus"
     hook_name = "Azure Service Bus"
 
+    @staticmethod
+    def get_connection_form_widgets() -> dict[str, Any]:
+        """Returns connection widgets to add to connection form."""
+        from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
+        from flask_babel import lazy_gettext
+        from wtforms import PasswordField, StringField
+
+        return {
+            "fully_qualified_namespace": StringField(
+                lazy_gettext("Fully Qualified Namespace"), 
widget=BS3TextFieldWidget()
+            ),
+            "credential": PasswordField(lazy_gettext("Credential"), 
widget=BS3TextFieldWidget()),
+        }
+
     @staticmethod
     def get_ui_field_behaviour() -> dict[str, Any]:
         """Returns custom field behaviour."""
@@ -44,6 +60,10 @@ class BaseAzureServiceBusHook(BaseHook):
             "hidden_fields": ["port", "host", "extra", "login", "password"],
             "relabeling": {"schema": "Connection String"},
             "placeholders": {
+                "fully_qualified_namespace": (
+                    "<Resource group>.servicebus.windows.net (for Azure AD 
authenticaltion)"
+                ),
+                "credential": "credential",
                 "schema": "Endpoint=sb://<Resource 
group>.servicebus.windows.net/;SharedAccessKeyName=<AccessKeyName>;SharedAccessKey=<SharedAccessKey>",
  # noqa
             },
         }
@@ -55,6 +75,14 @@ class BaseAzureServiceBusHook(BaseHook):
     def get_conn(self):
         raise NotImplementedError
 
+    def _get_field(self, extras: dict, field_name: str) -> str:
+        return get_field(
+            conn_id=self.conn_id,
+            conn_type=self.conn_type,
+            extras=extras,
+            field_name=field_name,
+        )
+
 
 class AdminClientHook(BaseAzureServiceBusHook):
     """Interact with the ServiceBusAdministrationClient.
@@ -70,9 +98,21 @@ class AdminClientHook(BaseAzureServiceBusHook):
         This uses the connection string in connection details.
         """
         conn = self.get_connection(self.conn_id)
-
         connection_string: str = str(conn.schema)
-        return 
ServiceBusAdministrationClient.from_connection_string(connection_string)
+        if connection_string:
+            client = 
ServiceBusAdministrationClient.from_connection_string(connection_string)
+        else:
+            extras = conn.extra_dejson
+            credential: str | DefaultAzureCredential = 
self._get_field(extras=extras, field_name="credential")
+            fully_qualified_namespace = self._get_field(extras=extras, 
field_name="fully_qualified_namespace")
+            if not credential:
+                credential = DefaultAzureCredential()
+            client = ServiceBusAdministrationClient(
+                fully_qualified_namespace=fully_qualified_namespace,
+                credential=credential,  # type: ignore[arg-type]
+            )
+        self.log.info("Create and returns ServiceBusAdministrationClient")
+        return client
 
     def create_queue(
         self,
@@ -143,9 +183,21 @@ class MessageHook(BaseAzureServiceBusHook):
         """Create and returns ServiceBusClient by using the connection string 
in connection details."""
         conn = self.get_connection(self.conn_id)
         connection_string: str = str(conn.schema)
+        if connection_string:
+            client = 
ServiceBusClient.from_connection_string(connection_string, logging_enable=True)
+        else:
+            extras = conn.extra_dejson
+            credential: str | DefaultAzureCredential = 
self._get_field(extras=extras, field_name="credential")
+            fully_qualified_namespace = self._get_field(extras=extras, 
field_name="fully_qualified_namespace")
+            if not credential:
+                credential = DefaultAzureCredential()
+            client = ServiceBusClient(
+                fully_qualified_namespace=fully_qualified_namespace,
+                credential=credential,  # type: ignore[arg-type]
+            )
 
         self.log.info("Create and returns ServiceBusClient")
-        return 
ServiceBusClient.from_connection_string(conn_str=connection_string, 
logging_enable=True)
+        return client
 
     def send_message(self, queue_name: str, messages: str | list[str], 
batch_message_flag: bool = False):
         """Use ServiceBusClient Send to send message(s) to a Service Bus Queue.
@@ -249,3 +301,5 @@ class MessageHook(BaseAzureServiceBusHook):
                 for msg in received_msgs:
                     self.log.info(msg)
                     subscription_receiver.complete_message(msg)
+                    self.log.info(msg)
+                    subscription_receiver.complete_message(msg)
diff --git a/airflow/providers/microsoft/azure/operators/asb.py b/airflow/providers/microsoft/azure/operators/asb.py
index d9a460b77d..39a6602ec2 100644
--- a/airflow/providers/microsoft/azure/operators/asb.py
+++ b/airflow/providers/microsoft/azure/operators/asb.py
@@ -19,6 +19,8 @@ from __future__ import annotations
 import datetime
 from typing import TYPE_CHECKING, Any, Sequence
 
+from azure.core.exceptions import ResourceNotFoundError
+
 from airflow.models import BaseOperator
 from airflow.providers.microsoft.azure.hooks.asb import AdminClientHook, 
MessageHook
 
@@ -293,7 +295,10 @@ class AzureServiceBusTopicCreateOperator(BaseOperator):
         hook = 
AdminClientHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id)
 
         with hook.get_conn() as service_mgmt_conn:
-            topic_properties = service_mgmt_conn.get_topic(self.topic_name)
+            try:
+                topic_properties = service_mgmt_conn.get_topic(self.topic_name)
+            except ResourceNotFoundError:
+                topic_properties = None
             if topic_properties and topic_properties.name == self.topic_name:
                 self.log.info("Topic name already exists")
                 return topic_properties.name
diff --git a/tests/providers/microsoft/azure/hooks/test_asb.py b/tests/providers/microsoft/azure/hooks/test_asb.py
index 5f626d6c29..8364bc9dcb 100644
--- a/tests/providers/microsoft/azure/hooks/test_asb.py
+++ b/tests/providers/microsoft/azure/hooks/test_asb.py
@@ -49,11 +49,27 @@ class TestAdminClientHook:
                 schema=self.connection_string,
             )
         )
+        self.mock_conn_without_schema = Connection(
+            conn_id="azure_service_bus_default",
+            conn_type="azure_service_bus",
+            schema="",
+            extra={"fully_qualified_namespace": "fully_qualified_namespace"},
+        )
 
     def test_get_conn(self):
         hook = AdminClientHook(azure_service_bus_conn_id=self.conn_id)
         assert isinstance(hook.get_conn(), ServiceBusAdministrationClient)
 
+    
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.DefaultAzureCredential")
+    
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.AdminClientHook.get_connection")
+    def 
test_get_conn_fallback_to_default_azure_credential_when_schema_is_not_provided(
+        self, mock_connection, mock_default_azure_credential
+    ):
+        mock_connection.return_value = self.mock_conn_without_schema
+        hook = AdminClientHook(azure_service_bus_conn_id=self.conn_id)
+        assert isinstance(hook.get_conn(), ServiceBusAdministrationClient)
+        mock_default_azure_credential.assert_called_once()
+
     @mock.patch("azure.servicebus.management.QueueProperties")
     
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.AdminClientHook.get_conn")
     def test_create_queue(self, mock_sb_admin_client, mock_queue_properties):
@@ -140,6 +156,12 @@ class TestMessageHook:
                 schema=self.connection_string,
             )
         )
+        self.mock_conn_without_schema = Connection(
+            conn_id="azure_service_bus_default",
+            conn_type="azure_service_bus",
+            schema="",
+            extra={"fully_qualified_namespace": "fully_qualified_namespace"},
+        )
 
     def test_get_service_bus_message_conn(self):
         """
@@ -149,6 +171,16 @@ class TestMessageHook:
         hook = MessageHook(azure_service_bus_conn_id=self.conn_id)
         assert isinstance(hook.get_conn(), ServiceBusClient)
 
+    
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.DefaultAzureCredential")
+    
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.MessageHook.get_connection")
+    def 
test_get_conn_fallback_to_default_azure_credential_when_schema_is_not_provided(
+        self, mock_connection, mock_default_azure_credential
+    ):
+        mock_connection.return_value = self.mock_conn_without_schema
+        hook = MessageHook(azure_service_bus_conn_id=self.conn_id)
+        assert isinstance(hook.get_conn(), ServiceBusClient)
+        mock_default_azure_credential.assert_called_once()
+
     @pytest.mark.parametrize(
         "mock_message, mock_batch_flag",
         [

Reply via email to