This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new cbe6c2dd24 Add `delete_topic` to `KafkaAdminClientHook` and teardown logic to Kafka integration tests (#40142)
cbe6c2dd24 is described below

commit cbe6c2dd243ee48e0f8bade77a15f49ecef92849
Author: Shahar Epstein <[email protected]>
AuthorDate: Sat Jun 8 23:51:45 2024 +0300

    Add `delete_topic` to `KafkaAdminClientHook` and teardown logic to Kafka integration tests (#40142)
    
    * Add unit tests to Apache Kafka hooks
    
    * Add teardown logic to integration tests of kafka hooks
---
 airflow/providers/apache/kafka/hooks/base.py       |  3 +-
 airflow/providers/apache/kafka/hooks/client.py     | 16 +++++
 tests/always/test_project_structure.py             |  1 -
 .../apache/kafka/hooks/test_admin_client.py        |  1 +
 .../providers/apache/kafka/hooks/test_consumer.py  |  3 +
 .../providers/apache/kafka/hooks/test_producer.py  |  4 +-
 tests/providers/apache/kafka/hooks/test_base.py    | 81 ++++++++++++++++++++++
 tests/providers/apache/kafka/hooks/test_client.py  | 81 +++++++++++++++-------
 tests/providers/apache/kafka/hooks/test_consume.py | 12 +---
 tests/providers/apache/kafka/hooks/test_produce.py | 12 +---
 10 files changed, 168 insertions(+), 46 deletions(-)

diff --git a/airflow/providers/apache/kafka/hooks/base.py b/airflow/providers/apache/kafka/hooks/base.py
index 2f99cb21ea..f45b773c27 100644
--- a/airflow/providers/apache/kafka/hooks/base.py
+++ b/airflow/providers/apache/kafka/hooks/base.py
@@ -40,7 +40,6 @@ class KafkaBaseHook(BaseHook):
         """Initialize our Base."""
         super().__init__()
         self.kafka_config_id = kafka_config_id
-        self.get_conn
 
     @classmethod
     def get_ui_field_behaviour(cls) -> dict[str, Any]:
@@ -74,6 +73,6 @@ class KafkaBaseHook(BaseHook):
             if t:
                 return True, "Connection successful."
         except Exception as e:
-            False, str(e)
+            return False, str(e)
 
         return False, "Failed to establish connection."
diff --git a/airflow/providers/apache/kafka/hooks/client.py b/airflow/providers/apache/kafka/hooks/client.py
index fa4dff1d69..6772cc3d32 100644
--- a/airflow/providers/apache/kafka/hooks/client.py
+++ b/airflow/providers/apache/kafka/hooks/client.py
@@ -61,3 +61,19 @@ class KafkaAdminClientHook(KafkaBaseHook):
                     self.log.warning("The topic %s already exists.", t)
                 else:
                     raise
+
+    def delete_topic(
+        self,
+        topics: Sequence[str],
+    ) -> None:
+        """
+        Delete the given topics, waiting on the future returned for each one.
+
+        :param topics: names of the topics to delete; an error from any deletion propagates.
+        """
+        admin_client = self.get_conn
+        futures = admin_client.delete_topics(topics)
+
+        for t, f in futures.items():
+            f.result()
+            self.log.info("The topic %s has been deleted.", t)
diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py
index 26f609a6f0..9917ccfb03 100644
--- a/tests/always/test_project_structure.py
+++ b/tests/always/test_project_structure.py
@@ -83,7 +83,6 @@ class TestProjectStructure:
             "tests/providers/apache/hdfs/log/test_hdfs_task_handler.py",
             "tests/providers/apache/hdfs/sensors/test_hdfs.py",
             "tests/providers/apache/hive/plugins/test_hive.py",
-            "tests/providers/apache/kafka/hooks/test_base.py",
             "tests/providers/celery/executors/test_celery_executor_utils.py",
             "tests/providers/celery/executors/test_default_celery.py",
             
"tests/providers/cncf/kubernetes/backcompat/test_backwards_compat_converters.py",
diff --git a/tests/integration/providers/apache/kafka/hooks/test_admin_client.py b/tests/integration/providers/apache/kafka/hooks/test_admin_client.py
index 9597456c6f..d542b721eb 100644
--- a/tests/integration/providers/apache/kafka/hooks/test_admin_client.py
+++ b/tests/integration/providers/apache/kafka/hooks/test_admin_client.py
@@ -49,3 +49,4 @@ class TestKafkaAdminClientHook:
         kadmin = hook.get_conn
         t = kadmin.list_topics(timeout=10).topics
         assert t.get("test_2")
+        hook.delete_topic(topics=["test_1", "test_2"])
diff --git a/tests/integration/providers/apache/kafka/hooks/test_consumer.py b/tests/integration/providers/apache/kafka/hooks/test_consumer.py
index 1134f55265..9cd9527631 100644
--- a/tests/integration/providers/apache/kafka/hooks/test_consumer.py
+++ b/tests/integration/providers/apache/kafka/hooks/test_consumer.py
@@ -24,6 +24,7 @@ from confluent_kafka import Producer
 from airflow.models import Connection
 
 # Import Hook
+from airflow.providers.apache.kafka.hooks.client import KafkaAdminClientHook
 from airflow.providers.apache.kafka.hooks.consume import KafkaConsumerHook
 from airflow.utils import db
 
@@ -68,3 +69,5 @@ class TestConsumerHook:
         msg = consumer.consume()
 
         assert msg[0].value() == b"test_message"
+        hook = KafkaAdminClientHook(kafka_config_id="kafka_d")
+        hook.delete_topic(topics=[TOPIC])
diff --git a/tests/integration/providers/apache/kafka/hooks/test_producer.py b/tests/integration/providers/apache/kafka/hooks/test_producer.py
index 663a965eb9..ad2351ef73 100644
--- a/tests/integration/providers/apache/kafka/hooks/test_producer.py
+++ b/tests/integration/providers/apache/kafka/hooks/test_producer.py
@@ -22,6 +22,7 @@ import logging
 import pytest
 
 from airflow.models import Connection
+from airflow.providers.apache.kafka.hooks.client import KafkaAdminClientHook
 from airflow.providers.apache.kafka.hooks.produce import KafkaProducerHook
 from airflow.utils import db
 
@@ -61,7 +62,8 @@ class TestProducerHook:
         p_hook = KafkaProducerHook(kafka_config_id="kafka_default")
 
         producer = p_hook.get_producer()
-
         producer.produce(topic, key="p1", value="p2", on_delivery=acked)
         producer.poll(0)
         producer.flush()
+        hook = KafkaAdminClientHook(kafka_config_id="kafka_default")
+        hook.delete_topic(topics=[topic])
diff --git a/tests/providers/apache/kafka/hooks/test_base.py b/tests/providers/apache/kafka/hooks/test_base.py
new file mode 100644
index 0000000000..c1ca9544b8
--- /dev/null
+++ b/tests/providers/apache/kafka/hooks/test_base.py
@@ -0,0 +1,81 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+from unittest.mock import MagicMock
+
+import pytest
+
+from airflow.providers.apache.kafka.hooks.base import KafkaBaseHook
+
+
+class SomeKafkaHook(KafkaBaseHook):
+    def _get_client(self, config):  # echo the config so tests can inspect it via get_conn
+        return config
+
+
+@pytest.fixture
+def hook():
+    return SomeKafkaHook()
+
+
+TIMEOUT = 10  # timeout the hook is expected to pass to AdminClient
+
+
+class TestKafkaBaseHook:
+    @mock.patch("airflow.hooks.base.BaseHook.get_connection")
+    def test_get_conn(self, mock_get_connection, hook):
+        config = {"bootstrap.servers": MagicMock()}
+        mock_get_connection.return_value.extra_dejson = config
+        assert hook.get_conn == config
+
+    @mock.patch("airflow.hooks.base.BaseHook.get_connection")
+    def test_get_conn_value_error(self, mock_get_connection, hook):
+        mock_get_connection.return_value.extra_dejson = {}
+        with pytest.raises(ValueError, match="must be provided"):
+            hook.get_conn  # get_conn is a property, so the ValueError is raised on access
+
+    @mock.patch("airflow.providers.apache.kafka.hooks.base.AdminClient")
+    @mock.patch("airflow.hooks.base.BaseHook.get_connection")
+    def test_test_connection(self, mock_get_connection, admin_client, hook):
+        config = {"bootstrap.servers": MagicMock()}
+        mock_get_connection.return_value.extra_dejson = config
+        connection = hook.test_connection()
+        admin_client.assert_called_once_with(config, timeout=TIMEOUT)
+        assert connection == (True, "Connection successful.")
+
+    @mock.patch(
+        "airflow.providers.apache.kafka.hooks.base.AdminClient",
+        return_value=MagicMock(list_topics=MagicMock(return_value=[])),
+    )
+    @mock.patch("airflow.hooks.base.BaseHook.get_connection")
+    def test_test_connection_no_topics(self, mock_get_connection, admin_client, hook):
+        config = {"bootstrap.servers": MagicMock()}
+        mock_get_connection.return_value.extra_dejson = config
+        connection = hook.test_connection()
+        admin_client.assert_called_once_with(config, timeout=TIMEOUT)
+        assert connection == (False, "Failed to establish connection.")
+
+    @mock.patch("airflow.providers.apache.kafka.hooks.base.AdminClient")
+    @mock.patch("airflow.hooks.base.BaseHook.get_connection")
+    def test_test_connection_exception(self, mock_get_connection, admin_client, hook):
+        config = {"bootstrap.servers": MagicMock()}
+        mock_get_connection.return_value.extra_dejson = config
+        admin_client.return_value.list_topics.side_effect = [ValueError("some error")]
+        connection = hook.test_connection()
+        assert connection == (False, "some error")
diff --git a/tests/providers/apache/kafka/hooks/test_client.py b/tests/providers/apache/kafka/hooks/test_client.py
index 16ffa5ac4d..0d92287969 100644
--- a/tests/providers/apache/kafka/hooks/test_client.py
+++ b/tests/providers/apache/kafka/hooks/test_client.py
@@ -18,9 +18,11 @@ from __future__ import annotations
 
 import json
 import logging
+from unittest.mock import MagicMock, patch
 
 import pytest
-from confluent_kafka.admin import AdminClient
+from confluent_kafka import KafkaException
+from confluent_kafka.admin import AdminClient, NewTopic
 
 from airflow.models import Connection
 from airflow.providers.apache.kafka.hooks.client import KafkaAdminClientHook
@@ -31,11 +33,7 @@ pytestmark = pytest.mark.db_test
 log = logging.getLogger(__name__)
 
 
-class TestSampleHook:
-    """
-    Test Admin Client Hook.
-    """
-
+class TestKafkaAdminClientHook:
     def setup_method(self):
         db.merge_conn(
             Connection(
@@ -54,23 +52,58 @@ class TestSampleHook:
                 extra=json.dumps({"socket.timeout.ms": 10}),
             )
         )
-
-    def test_init(self):
-        """test initialization of AdminClientHook"""
-
-        # Standard Init
-        KafkaAdminClientHook(kafka_config_id="kafka_d")
-
-        # # Not Enough Args
-        with pytest.raises(ValueError):
-            KafkaAdminClientHook(kafka_config_id="kafka_bad")
+        self.hook = KafkaAdminClientHook(kafka_config_id="kafka_d")
 
     def test_get_conn(self):
-        """test get_conn"""
-
-        # Standard Init
-        k = KafkaAdminClientHook(kafka_config_id="kafka_d")
-
-        c = k.get_conn
-
-        assert isinstance(c, AdminClient)
+        assert isinstance(self.hook.get_conn, AdminClient)
+
+    @patch(
+        "airflow.providers.apache.kafka.hooks.client.AdminClient",
+    )
+    def test_create_topic(self, admin_client):
+        mock_f = MagicMock()
+        admin_client.return_value.create_topics.return_value = {"topic_name": mock_f}
+        self.hook.create_topic(topics=[("topic_name", 0, 1)])
+        admin_client.return_value.create_topics.assert_called_with([NewTopic("topic_name", 0, 1)])
+        mock_f.result.assert_called_once()  # the hook must wait on each creation future
+
+    @patch(
+        "airflow.providers.apache.kafka.hooks.client.AdminClient",
+    )
+    def test_create_topic_error(self, admin_client):
+        mock_f = MagicMock()
+        kafka_exception = KafkaException()
+        mock_arg = MagicMock()
+        # mock_arg.name is left as a Mock (not "TOPIC_ALREADY_EXISTS"), so create_topic must re-raise
+        kafka_exception.args = [mock_arg]
+        mock_f.result.side_effect = [kafka_exception]
+        admin_client.return_value.create_topics.return_value = {"topic_name": mock_f}
+        with pytest.raises(KafkaException):
+            self.hook.create_topic(topics=[("topic_name", 0, 1)])
+
+    @patch(
+        "airflow.providers.apache.kafka.hooks.client.AdminClient",
+    )
+    def test_create_topic_warning(self, admin_client, caplog):
+        mock_f = MagicMock()
+        kafka_exception = KafkaException()
+        mock_arg = MagicMock()
+        mock_arg.name = "TOPIC_ALREADY_EXISTS"  # the error name create_topic downgrades to a warning
+        kafka_exception.args = [mock_arg]
+        mock_f.result.side_effect = [kafka_exception]
+        admin_client.return_value.create_topics.return_value = {"topic_name": mock_f}
+        with caplog.at_level(
+            logging.WARNING, logger="airflow.providers.apache.kafka.hooks.client.KafkaAdminClientHook"
+        ):
+            self.hook.create_topic(topics=[("topic_name", 0, 1)])
+        assert "The topic topic_name already exists" in caplog.text
+
+    @patch(
+        "airflow.providers.apache.kafka.hooks.client.AdminClient",
+    )
+    def test_delete_topic(self, admin_client):
+        mock_f = MagicMock()
+        admin_client.return_value.delete_topics.return_value = {"topic_name": mock_f}
+        self.hook.delete_topic(topics=["topic_name"])
+        admin_client.return_value.delete_topics.assert_called_once_with(["topic_name"])
+        mock_f.result.assert_called_once()  # the hook must wait on each deletion future
diff --git a/tests/providers/apache/kafka/hooks/test_consume.py b/tests/providers/apache/kafka/hooks/test_consume.py
index 852d737448..16c9ee5dba 100644
--- a/tests/providers/apache/kafka/hooks/test_consume.py
+++ b/tests/providers/apache/kafka/hooks/test_consume.py
@@ -52,13 +52,7 @@ class TestConsumerHook:
                 extra=json.dumps({}),
             )
         )
+        self.hook = KafkaConsumerHook(["test_1"], kafka_config_id="kafka_d")
 
-    def test_init(self):
-        """test initialization of AdminClientHook"""
-
-        # Standard Init
-        KafkaConsumerHook(["test_1"], kafka_config_id="kafka_d")
-
-        # Not Enough Args
-        with pytest.raises(ValueError):
-            KafkaConsumerHook(["test_1"], kafka_config_id="kafka_bad")
+    def test_get_consumer(self):
+        assert self.hook.get_consumer() == self.hook.get_conn  # same client as get_conn
diff --git a/tests/providers/apache/kafka/hooks/test_produce.py b/tests/providers/apache/kafka/hooks/test_produce.py
index 0f5ed0e186..3bcdd010ca 100644
--- a/tests/providers/apache/kafka/hooks/test_produce.py
+++ b/tests/providers/apache/kafka/hooks/test_produce.py
@@ -54,13 +54,7 @@ class TestProducerHook:
                 extra=json.dumps({}),
             )
         )
+        self.hook = KafkaProducerHook(kafka_config_id="kafka_d")
 
-    def test_init(self):
-        """test initialization of AdminClientHook"""
-
-        # Standard Init
-        KafkaProducerHook(kafka_config_id="kafka_d")
-
-        # Not Enough Args
-        with pytest.raises(ValueError):
-            KafkaProducerHook(kafka_config_id="kafka_bad")
+    def test_get_producer(self):
+        assert self.hook.get_producer() == self.hook.get_conn  # same client as get_conn

Reply via email to