This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 479ec87166 Fix SSL verification in druid operator (#37673)
479ec87166 is described below

commit 479ec87166bfb7059ed1763580feb80c75ce7cd8
Author: Daniel Bell <[email protected]>
AuthorDate: Sun Mar 3 08:46:22 2024 +0100

    Fix SSL verification in druid operator (#37673)
    
    * Use cached_property for get_connection to avoid calling >1
    
    * Get ca bundle path from connection config if not verifying ssl
    
    * Add type checking check
    
    * Add tests
    
    * Fix lint and add documentation
    
    * Use default
    
    * Final tidy
    
    * Add log for when using CA bundle
    
    * Add test_conn_property
    
    ---------
    
    Co-authored-by: Daniel Bell <[email protected]>
---
 airflow/providers/apache/druid/hooks/druid.py      | 43 ++++++++++++++--------
 airflow/providers/apache/druid/operators/druid.py  |  7 ++--
 tests/providers/apache/druid/hooks/test_druid.py   | 32 +++++++++++++++-
 .../providers/apache/druid/operators/test_druid.py |  2 +-
 4 files changed, 62 insertions(+), 22 deletions(-)

diff --git a/airflow/providers/apache/druid/hooks/druid.py b/airflow/providers/apache/druid/hooks/druid.py
index 9ab60a8c08..da678d0153 100644
--- a/airflow/providers/apache/druid/hooks/druid.py
+++ b/airflow/providers/apache/druid/hooks/druid.py
@@ -19,7 +19,8 @@ from __future__ import annotations
 
 import time
 from enum import Enum
-from typing import Any, Iterable
+from functools import cached_property
+from typing import TYPE_CHECKING, Any, Iterable
 
 import requests
 from pydruid.db import connect
@@ -28,6 +29,9 @@ from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
 from airflow.providers.common.sql.hooks.sql import DbApiHook
 
+if TYPE_CHECKING:
+    from airflow.models import Connection
+
 
 class IngestionType(Enum):
     """
@@ -53,9 +57,8 @@ class DruidHook(BaseHook):
                     the Druid job for the status of the ingestion job.
                     Must be greater than or equal to 1
     :param max_ingestion_time: The maximum ingestion time before assuming the job failed
-    :param verify_ssl: Either a boolean, in which case it controls whether we verify the server's TLS
-                      certificate, or a string, in which case it must be a path to a CA bundle to use.
-                      Defaults to True
+    :param verify_ssl: Whether to use SSL encryption to submit indexing job. If set to False then checks
+                       connection information for path to a CA bundle to use. Defaults to True
     """
 
     def __init__(
@@ -63,7 +66,7 @@ class DruidHook(BaseHook):
         druid_ingest_conn_id: str = "druid_ingest_default",
         timeout: int = 1,
         max_ingestion_time: int | None = None,
-        verify_ssl: bool | str = True,
+        verify_ssl: bool = True,
     ) -> None:
         super().__init__()
         self.druid_ingest_conn_id = druid_ingest_conn_id
@@ -75,16 +78,19 @@ class DruidHook(BaseHook):
         if self.timeout < 1:
             raise ValueError("Druid timeout should be equal or greater than 1")
 
+    @cached_property
+    def conn(self) -> Connection:
+        return self.get_connection(self.druid_ingest_conn_id)
+
     def get_conn_url(self, ingestion_type: IngestionType = IngestionType.BATCH) -> str:
         """Get Druid connection url."""
-        conn = self.get_connection(self.druid_ingest_conn_id)
-        host = conn.host
-        port = conn.port
-        conn_type = conn.conn_type or "http"
+        host = self.conn.host
+        port = self.conn.port
+        conn_type = self.conn.conn_type or "http"
         if ingestion_type == IngestionType.BATCH:
-            endpoint = conn.extra_dejson.get("endpoint", "")
+            endpoint = self.conn.extra_dejson.get("endpoint", "")
         else:
-            endpoint = conn.extra_dejson.get("msq_endpoint", "")
+            endpoint = self.conn.extra_dejson.get("msq_endpoint", "")
         return f"{conn_type}://{host}:{port}/{endpoint}"
 
     def get_auth(self) -> requests.auth.HTTPBasicAuth | None:
@@ -93,14 +99,21 @@ class DruidHook(BaseHook):
 
         If these details have not been set then returns None.
         """
-        conn = self.get_connection(self.druid_ingest_conn_id)
-        user = conn.login
-        password = conn.password
+        user = self.conn.login
+        password = self.conn.password
         if user is not None and password is not None:
             return requests.auth.HTTPBasicAuth(user, password)
         else:
             return None
 
+    def get_verify(self) -> bool | str:
+        ca_bundle_path: str | None = self.conn.extra_dejson.get("ca_bundle_path", None)
+        if not self.verify_ssl and ca_bundle_path:
+            self.log.info("Using CA bundle to verify connection")
+            return ca_bundle_path
+
+        return self.verify_ssl
+
     def submit_indexing_job(
         self, json_index_spec: dict[str, Any] | str, ingestion_type: IngestionType = IngestionType.BATCH
     ) -> None:
@@ -109,7 +122,7 @@ class DruidHook(BaseHook):
 
         self.log.info("Druid ingestion spec: %s", json_index_spec)
         req_index = requests.post(
-            url, data=json_index_spec, headers=self.header, auth=self.get_auth(), verify=self.verify_ssl
+            url, data=json_index_spec, headers=self.header, auth=self.get_auth(), verify=self.get_verify()
         )
 
         code = req_index.status_code
diff --git a/airflow/providers/apache/druid/operators/druid.py b/airflow/providers/apache/druid/operators/druid.py
index 9a5a411121..71ad409cb9 100644
--- a/airflow/providers/apache/druid/operators/druid.py
+++ b/airflow/providers/apache/druid/operators/druid.py
@@ -37,9 +37,8 @@ class DruidOperator(BaseOperator):
         of the ingestion job. Must be greater than or equal to 1
     :param max_ingestion_time: The maximum ingestion time before assuming the job failed
     :param ingestion_type: The ingestion type of the job. Could be IngestionType.Batch or IngestionType.MSQ
-    :param verify_ssl: Either a boolean, in which case it controls whether we verify the server's TLS
-                      certificate, or a string, in which case it must be a path to a CA bundle to use.
-                      Defaults to True.
+    :param verify_ssl: Whether to use SSL encryption to submit indexing job. If set to False then checks
+                       connection information for path to a CA bundle to use. Defaults to True
     """
 
     template_fields: Sequence[str] = ("json_index_file",)
@@ -54,7 +53,7 @@ class DruidOperator(BaseOperator):
         timeout: int = 1,
         max_ingestion_time: int | None = None,
         ingestion_type: IngestionType = IngestionType.BATCH,
-        verify_ssl: bool | str = True,
+        verify_ssl: bool = True,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
diff --git a/tests/providers/apache/druid/hooks/test_druid.py b/tests/providers/apache/druid/hooks/test_druid.py
index 76f332bb97..9b3ccfd474 100644
--- a/tests/providers/apache/druid/hooks/test_druid.py
+++ b/tests/providers/apache/druid/hooks/test_druid.py
@@ -97,7 +97,7 @@ class TestDruidSubmitHook:
         assert status_check.called_once
 
     def test_submit_with_correct_ssl_arg(self, requests_mock):
-        self.db_hook.verify_ssl = "/path/to/ca.crt"
+        self.db_hook.verify_ssl = False
         task_post = requests_mock.post(
             "http://druid-overlord:8081/druid/indexer/v1/task",
             text='{"task":"9f8a7359-77d4-4612-b0cd-cc2f6a3c28de"}',
@@ -113,7 +113,7 @@ class TestDruidSubmitHook:
         assert status_check.called_once
         if task_post.called_once:
             verify_ssl = task_post.request_history[0].verify
-            assert "/path/to/ca.crt" == verify_ssl
+            assert False is verify_ssl
 
     def test_submit_correct_json_body(self, requests_mock):
         task_post = requests_mock.post(
@@ -199,6 +199,17 @@ class TestDruidHook:
 
         self.db_hook = TestDRuidhook()
 
+    @patch("airflow.providers.apache.druid.hooks.druid.DruidHook.get_connection")
+    def test_conn_property(self, mock_get_connection):
+        get_conn_value = MagicMock()
+        get_conn_value.host = "test_host"
+        get_conn_value.conn_type = "https"
+        get_conn_value.port = "1"
+        get_conn_value.extra_dejson = {"endpoint": "ingest"}
+        mock_get_connection.return_value = get_conn_value
+        hook = DruidHook()
+        assert hook.conn == get_conn_value
+
     @patch("airflow.providers.apache.druid.hooks.druid.DruidHook.get_connection")
     def test_get_conn_url(self, mock_get_connection):
         get_conn_value = MagicMock()
@@ -254,6 +265,23 @@ class TestDruidHook:
         mock_get_connection.return_value = get_conn_value
         assert self.db_hook.get_auth() is None
 
+    @pytest.mark.parametrize(
+        "verify_ssl_arg, ca_bundle_path, expected_return_value",
+        [
+            (False, None, False),
+            (True, None, True),
+            (False, "path/to/ca_bundle", "path/to/ca_bundle"),
+            (True, "path/to/ca_bundle", True),
+        ],
+    )
+    @patch("airflow.providers.apache.druid.hooks.druid.DruidHook.get_connection")
+    def test_get_verify(self, mock_get_connection, verify_ssl_arg, ca_bundle_path, expected_return_value):
+        get_conn_value = MagicMock()
+        get_conn_value.extra_dejson = {"ca_bundle_path": ca_bundle_path}
+        mock_get_connection.return_value = get_conn_value
+        hook = DruidHook(verify_ssl=verify_ssl_arg)
+        assert hook.get_verify() == expected_return_value
+
 
 class TestDruidDbApiHook:
     def setup_method(self):
diff --git a/tests/providers/apache/druid/operators/test_druid.py b/tests/providers/apache/druid/operators/test_druid.py
index 28f9632cd1..286cdd3916 100644
--- a/tests/providers/apache/druid/operators/test_druid.py
+++ b/tests/providers/apache/druid/operators/test_druid.py
@@ -128,7 +128,7 @@ def test_execute_calls_druid_hook_with_the_right_parameters(mock_druid_hook):
     druid_ingest_conn_id = "druid_ingest_default"
     max_ingestion_time = 5
     timeout = 5
-    verify_ssl = "/path/to/ca.crt"
+    verify_ssl = False
     operator = DruidOperator(
         task_id="spark_submit_job",
         json_index_file=json_index_file,

Reply via email to