This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 479ec87166 Fix SSL verification in druid operator (#37673)
479ec87166 is described below
commit 479ec87166bfb7059ed1763580feb80c75ce7cd8
Author: Daniel Bell <[email protected]>
AuthorDate: Sun Mar 3 08:46:22 2024 +0100
Fix SSL verification in druid operator (#37673)
* Use cached_property for the connection so get_connection is not called more than once
* Get the CA bundle path from the connection extras if not verifying SSL
* Add a TYPE_CHECKING guard for the Connection import
* Add tests
* Fix lint and add documentation
* Use default
* Final tidy
* Add log for when using CA bundle
* Add test_conn_property
---------
Co-authored-by: Daniel Bell <[email protected]>
---
airflow/providers/apache/druid/hooks/druid.py | 43 ++++++++++++++--------
airflow/providers/apache/druid/operators/druid.py | 7 ++--
tests/providers/apache/druid/hooks/test_druid.py | 32 +++++++++++++++-
.../providers/apache/druid/operators/test_druid.py | 2 +-
4 files changed, 62 insertions(+), 22 deletions(-)
diff --git a/airflow/providers/apache/druid/hooks/druid.py
b/airflow/providers/apache/druid/hooks/druid.py
index 9ab60a8c08..da678d0153 100644
--- a/airflow/providers/apache/druid/hooks/druid.py
+++ b/airflow/providers/apache/druid/hooks/druid.py
@@ -19,7 +19,8 @@ from __future__ import annotations
import time
from enum import Enum
-from typing import Any, Iterable
+from functools import cached_property
+from typing import TYPE_CHECKING, Any, Iterable
import requests
from pydruid.db import connect
@@ -28,6 +29,9 @@ from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.providers.common.sql.hooks.sql import DbApiHook
+if TYPE_CHECKING:
+ from airflow.models import Connection
+
class IngestionType(Enum):
"""
@@ -53,9 +57,8 @@ class DruidHook(BaseHook):
the Druid job for the status of the ingestion job.
Must be greater than or equal to 1
:param max_ingestion_time: The maximum ingestion time before assuming the
job failed
- :param verify_ssl: Either a boolean, in which case it controls whether we
verify the server's TLS
- certificate, or a string, in which case it must be a
path to a CA bundle to use.
- Defaults to True
+ :param verify_ssl: Whether to use SSL encryption to submit indexing job.
If set to False then checks
+ connection information for path to a CA bundle to use.
Defaults to True
"""
def __init__(
@@ -63,7 +66,7 @@ class DruidHook(BaseHook):
druid_ingest_conn_id: str = "druid_ingest_default",
timeout: int = 1,
max_ingestion_time: int | None = None,
- verify_ssl: bool | str = True,
+ verify_ssl: bool = True,
) -> None:
super().__init__()
self.druid_ingest_conn_id = druid_ingest_conn_id
@@ -75,16 +78,19 @@ class DruidHook(BaseHook):
if self.timeout < 1:
raise ValueError("Druid timeout should be equal or greater than 1")
+ @cached_property
+ def conn(self) -> Connection:
+ return self.get_connection(self.druid_ingest_conn_id)
+
def get_conn_url(self, ingestion_type: IngestionType =
IngestionType.BATCH) -> str:
"""Get Druid connection url."""
- conn = self.get_connection(self.druid_ingest_conn_id)
- host = conn.host
- port = conn.port
- conn_type = conn.conn_type or "http"
+ host = self.conn.host
+ port = self.conn.port
+ conn_type = self.conn.conn_type or "http"
if ingestion_type == IngestionType.BATCH:
- endpoint = conn.extra_dejson.get("endpoint", "")
+ endpoint = self.conn.extra_dejson.get("endpoint", "")
else:
- endpoint = conn.extra_dejson.get("msq_endpoint", "")
+ endpoint = self.conn.extra_dejson.get("msq_endpoint", "")
return f"{conn_type}://{host}:{port}/{endpoint}"
def get_auth(self) -> requests.auth.HTTPBasicAuth | None:
@@ -93,14 +99,21 @@ class DruidHook(BaseHook):
If these details have not been set then returns None.
"""
- conn = self.get_connection(self.druid_ingest_conn_id)
- user = conn.login
- password = conn.password
+ user = self.conn.login
+ password = self.conn.password
if user is not None and password is not None:
return requests.auth.HTTPBasicAuth(user, password)
else:
return None
+ def get_verify(self) -> bool | str:
+ ca_bundle_path: str | None =
self.conn.extra_dejson.get("ca_bundle_path", None)
+ if not self.verify_ssl and ca_bundle_path:
+ self.log.info("Using CA bundle to verify connection")
+ return ca_bundle_path
+
+ return self.verify_ssl
+
def submit_indexing_job(
self, json_index_spec: dict[str, Any] | str, ingestion_type:
IngestionType = IngestionType.BATCH
) -> None:
@@ -109,7 +122,7 @@ class DruidHook(BaseHook):
self.log.info("Druid ingestion spec: %s", json_index_spec)
req_index = requests.post(
- url, data=json_index_spec, headers=self.header,
auth=self.get_auth(), verify=self.verify_ssl
+ url, data=json_index_spec, headers=self.header,
auth=self.get_auth(), verify=self.get_verify()
)
code = req_index.status_code
diff --git a/airflow/providers/apache/druid/operators/druid.py
b/airflow/providers/apache/druid/operators/druid.py
index 9a5a411121..71ad409cb9 100644
--- a/airflow/providers/apache/druid/operators/druid.py
+++ b/airflow/providers/apache/druid/operators/druid.py
@@ -37,9 +37,8 @@ class DruidOperator(BaseOperator):
of the ingestion job. Must be greater than or equal to 1
:param max_ingestion_time: The maximum ingestion time before assuming the
job failed
:param ingestion_type: The ingestion type of the job. Could be
IngestionType.Batch or IngestionType.MSQ
- :param verify_ssl: Either a boolean, in which case it controls whether we
verify the server's TLS
- certificate, or a string, in which case it must be a
path to a CA bundle to use.
- Defaults to True.
+ :param verify_ssl: Whether to use SSL encryption to submit indexing job.
If set to False then checks
+ connection information for path to a CA bundle to use.
Defaults to True
"""
template_fields: Sequence[str] = ("json_index_file",)
@@ -54,7 +53,7 @@ class DruidOperator(BaseOperator):
timeout: int = 1,
max_ingestion_time: int | None = None,
ingestion_type: IngestionType = IngestionType.BATCH,
- verify_ssl: bool | str = True,
+ verify_ssl: bool = True,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
diff --git a/tests/providers/apache/druid/hooks/test_druid.py
b/tests/providers/apache/druid/hooks/test_druid.py
index 76f332bb97..9b3ccfd474 100644
--- a/tests/providers/apache/druid/hooks/test_druid.py
+++ b/tests/providers/apache/druid/hooks/test_druid.py
@@ -97,7 +97,7 @@ class TestDruidSubmitHook:
assert status_check.called_once
def test_submit_with_correct_ssl_arg(self, requests_mock):
- self.db_hook.verify_ssl = "/path/to/ca.crt"
+ self.db_hook.verify_ssl = False
task_post = requests_mock.post(
"http://druid-overlord:8081/druid/indexer/v1/task",
text='{"task":"9f8a7359-77d4-4612-b0cd-cc2f6a3c28de"}',
@@ -113,7 +113,7 @@ class TestDruidSubmitHook:
assert status_check.called_once
if task_post.called_once:
verify_ssl = task_post.request_history[0].verify
- assert "/path/to/ca.crt" == verify_ssl
+ assert False is verify_ssl
def test_submit_correct_json_body(self, requests_mock):
task_post = requests_mock.post(
@@ -199,6 +199,17 @@ class TestDruidHook:
self.db_hook = TestDRuidhook()
+
@patch("airflow.providers.apache.druid.hooks.druid.DruidHook.get_connection")
+ def test_conn_property(self, mock_get_connection):
+ get_conn_value = MagicMock()
+ get_conn_value.host = "test_host"
+ get_conn_value.conn_type = "https"
+ get_conn_value.port = "1"
+ get_conn_value.extra_dejson = {"endpoint": "ingest"}
+ mock_get_connection.return_value = get_conn_value
+ hook = DruidHook()
+ assert hook.conn == get_conn_value
+
@patch("airflow.providers.apache.druid.hooks.druid.DruidHook.get_connection")
def test_get_conn_url(self, mock_get_connection):
get_conn_value = MagicMock()
@@ -254,6 +265,23 @@ class TestDruidHook:
mock_get_connection.return_value = get_conn_value
assert self.db_hook.get_auth() is None
+ @pytest.mark.parametrize(
+ "verify_ssl_arg, ca_bundle_path, expected_return_value",
+ [
+ (False, None, False),
+ (True, None, True),
+ (False, "path/to/ca_bundle", "path/to/ca_bundle"),
+ (True, "path/to/ca_bundle", True),
+ ],
+ )
+
@patch("airflow.providers.apache.druid.hooks.druid.DruidHook.get_connection")
+ def test_get_verify(self, mock_get_connection, verify_ssl_arg,
ca_bundle_path, expected_return_value):
+ get_conn_value = MagicMock()
+ get_conn_value.extra_dejson = {"ca_bundle_path": ca_bundle_path}
+ mock_get_connection.return_value = get_conn_value
+ hook = DruidHook(verify_ssl=verify_ssl_arg)
+ assert hook.get_verify() == expected_return_value
+
class TestDruidDbApiHook:
def setup_method(self):
diff --git a/tests/providers/apache/druid/operators/test_druid.py
b/tests/providers/apache/druid/operators/test_druid.py
index 28f9632cd1..286cdd3916 100644
--- a/tests/providers/apache/druid/operators/test_druid.py
+++ b/tests/providers/apache/druid/operators/test_druid.py
@@ -128,7 +128,7 @@ def
test_execute_calls_druid_hook_with_the_right_parameters(mock_druid_hook):
druid_ingest_conn_id = "druid_ingest_default"
max_ingestion_time = 5
timeout = 5
- verify_ssl = "/path/to/ca.crt"
+ verify_ssl = False
operator = DruidOperator(
task_id="spark_submit_job",
json_index_file=json_index_file,