This is an automated email from the ASF dual-hosted git repository.

taragolis pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 89df63b7ce Allow passing context to DruidDbApiHook (#34603)
89df63b7ce is described below

commit 89df63b7ce2cde4a7b3c0cd0583ed9d2bb9d0ece
Author: Saulius Grigaliunas <[email protected]>
AuthorDate: Tue Oct 3 17:24:48 2023 +0300

    Allow passing context to DruidDbApiHook (#34603)
    
    Druid's SQL API endpoint accepts a `context` parameter that enables
    various query features, as [described in the documentation](https://druid.apache.org/docs/latest/querying/sql-query-context/).
    This change enables passing context when using `DruidDbApiHook`.
---
 airflow/providers/apache/druid/hooks/druid.py    |  9 +++++++
 tests/providers/apache/druid/hooks/test_druid.py | 32 ++++++++++++++++++++++++
 2 files changed, 41 insertions(+)

diff --git a/airflow/providers/apache/druid/hooks/druid.py b/airflow/providers/apache/druid/hooks/druid.py
index 7708684e60..fb219fdebe 100644
--- a/airflow/providers/apache/druid/hooks/druid.py
+++ b/airflow/providers/apache/druid/hooks/druid.py
@@ -156,6 +156,10 @@ class DruidDbApiHook(DbApiHook):
 
     This hook is purely for users to query druid broker.
     For ingestion, please use druidHook.
+
+    :param context: Optional query context parameters to pass to the SQL 
endpoint.
+        Example: ``{"sqlFinalizeOuterSketches": True}``
+        See: https://druid.apache.org/docs/latest/querying/sql-query-context/
     """
 
     conn_name_attr = "druid_broker_conn_id"
@@ -164,6 +168,10 @@ class DruidDbApiHook(DbApiHook):
     hook_name = "Druid"
     supports_autocommit = False
 
+    def __init__(self, context: dict | None = None, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.context = context or {}
+
     def get_conn(self) -> connect:
         """Establish a connection to druid broker."""
         conn = self.get_connection(getattr(self, self.conn_name_attr))
@@ -174,6 +182,7 @@ class DruidDbApiHook(DbApiHook):
             scheme=conn.extra_dejson.get("schema", "http"),
             user=conn.login,
             password=conn.password,
+            context=self.context,
         )
         self.log.info("Get the connection to druid broker on %s using user 
%s", conn.host, conn.login)
         return druid_broker_conn
diff --git a/tests/providers/apache/druid/hooks/test_druid.py b/tests/providers/apache/druid/hooks/test_druid.py
index 38f26f06cc..04e68f765a 100644
--- a/tests/providers/apache/druid/hooks/test_druid.py
+++ b/tests/providers/apache/druid/hooks/test_druid.py
@@ -233,6 +233,38 @@ class TestDruidDbApiHook:
 
         self.db_hook = TestDruidDBApiHook
 
+    
@patch("airflow.providers.apache.druid.hooks.druid.DruidDbApiHook.get_connection")
+    @patch("airflow.providers.apache.druid.hooks.druid.connect")
+    @pytest.mark.parametrize(
+        ("specified_context", "passed_context"),
+        [
+            (None, {}),
+            ({"query_origin": "airflow"}, {"query_origin": "airflow"}),
+        ],
+    )
+    def test_get_conn_with_context(
+        self, mock_connect, mock_get_connection, specified_context, 
passed_context
+    ):
+        get_conn_value = MagicMock()
+        get_conn_value.host = "test_host"
+        get_conn_value.conn_type = "https"
+        get_conn_value.login = "test_login"
+        get_conn_value.password = "test_password"
+        get_conn_value.port = 10000
+        get_conn_value.extra_dejson = {"endpoint": "/test/endpoint", "schema": 
"https"}
+        mock_get_connection.return_value = get_conn_value
+        hook = DruidDbApiHook(context=specified_context)
+        hook.get_conn()
+        mock_connect.assert_called_with(
+            host="test_host",
+            port=10000,
+            path="/test/endpoint",
+            scheme="https",
+            user="test_login",
+            password="test_password",
+            context=passed_context,
+        )
+
     def test_get_uri(self):
         db_hook = self.db_hook()
         assert "druid://host:1000/druid/v2/sql" == db_hook.get_uri()

Reply via email to