This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d2ba126404 Adding optional SSL verification for druid operator (#37629)
d2ba126404 is described below

commit d2ba126404ceb7fbbfca64317086c8058c897f68
Author: Daniel Bell <[email protected]>
AuthorDate: Fri Feb 23 19:17:49 2024 +0100

    Adding optional SSL verification for druid operator (#37629)
    
    * Initial changes
    
    * Add tests
    
    ---------
    
    Co-authored-by: Daniel Bell <[email protected]>
---
 airflow/providers/apache/druid/hooks/druid.py        |  9 ++++++++-
 airflow/providers/apache/druid/operators/druid.py    |  6 ++++++
 tests/providers/apache/druid/hooks/test_druid.py     | 19 +++++++++++++++++++
 tests/providers/apache/druid/operators/test_druid.py | 13 ++++++++++++-
 4 files changed, 45 insertions(+), 2 deletions(-)

diff --git a/airflow/providers/apache/druid/hooks/druid.py b/airflow/providers/apache/druid/hooks/druid.py
index 1c0f809247..9ab60a8c08 100644
--- a/airflow/providers/apache/druid/hooks/druid.py
+++ b/airflow/providers/apache/druid/hooks/druid.py
@@ -53,6 +53,9 @@ class DruidHook(BaseHook):
                     the Druid job for the status of the ingestion job.
                     Must be greater than or equal to 1
     :param max_ingestion_time: The maximum ingestion time before assuming the 
job failed
+    :param verify_ssl: Either a boolean, in which case it controls whether we 
verify the server's TLS
+                      certificate, or a string, in which case it must be a 
path to a CA bundle to use.
+                      Defaults to True
     """
 
     def __init__(
@@ -60,12 +63,14 @@ class DruidHook(BaseHook):
         druid_ingest_conn_id: str = "druid_ingest_default",
         timeout: int = 1,
         max_ingestion_time: int | None = None,
+        verify_ssl: bool | str = True,
     ) -> None:
         super().__init__()
         self.druid_ingest_conn_id = druid_ingest_conn_id
         self.timeout = timeout
         self.max_ingestion_time = max_ingestion_time
         self.header = {"content-type": "application/json"}
+        self.verify_ssl = verify_ssl
 
         if self.timeout < 1:
             raise ValueError("Druid timeout should be equal or greater than 1")
@@ -103,7 +108,9 @@ class DruidHook(BaseHook):
         url = self.get_conn_url(ingestion_type)
 
         self.log.info("Druid ingestion spec: %s", json_index_spec)
-        req_index = requests.post(url, data=json_index_spec, 
headers=self.header, auth=self.get_auth())
+        req_index = requests.post(
+            url, data=json_index_spec, headers=self.header, 
auth=self.get_auth(), verify=self.verify_ssl
+        )
 
         code = req_index.status_code
         not_accepted = not (200 <= code < 300)
diff --git a/airflow/providers/apache/druid/operators/druid.py b/airflow/providers/apache/druid/operators/druid.py
index 080287e5ec..9a5a411121 100644
--- a/airflow/providers/apache/druid/operators/druid.py
+++ b/airflow/providers/apache/druid/operators/druid.py
@@ -37,6 +37,9 @@ class DruidOperator(BaseOperator):
         of the ingestion job. Must be greater than or equal to 1
     :param max_ingestion_time: The maximum ingestion time before assuming the 
job failed
     :param ingestion_type: The ingestion type of the job. Could be 
IngestionType.Batch or IngestionType.MSQ
+    :param verify_ssl: Either a boolean, in which case it controls whether we 
verify the server's TLS
+                      certificate, or a string, in which case it must be a 
path to a CA bundle to use.
+                      Defaults to True.
     """
 
     template_fields: Sequence[str] = ("json_index_file",)
@@ -51,6 +54,7 @@ class DruidOperator(BaseOperator):
         timeout: int = 1,
         max_ingestion_time: int | None = None,
         ingestion_type: IngestionType = IngestionType.BATCH,
+        verify_ssl: bool | str = True,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -59,12 +63,14 @@ class DruidOperator(BaseOperator):
         self.timeout = timeout
         self.max_ingestion_time = max_ingestion_time
         self.ingestion_type = ingestion_type
+        self.verify_ssl = verify_ssl
 
     def execute(self, context: Context) -> None:
         hook = DruidHook(
             druid_ingest_conn_id=self.conn_id,
             timeout=self.timeout,
             max_ingestion_time=self.max_ingestion_time,
+            verify_ssl=self.verify_ssl,
         )
         self.log.info("Submitting %s", self.json_index_file)
         hook.submit_indexing_job(self.json_index_file, self.ingestion_type)
diff --git a/tests/providers/apache/druid/hooks/test_druid.py b/tests/providers/apache/druid/hooks/test_druid.py
index 5a389cb710..76f332bb97 100644
--- a/tests/providers/apache/druid/hooks/test_druid.py
+++ b/tests/providers/apache/druid/hooks/test_druid.py
@@ -96,6 +96,25 @@ class TestDruidSubmitHook:
         assert task_post.called_once
         assert status_check.called_once
 
+    def test_submit_with_correct_ssl_arg(self, requests_mock):
+        self.db_hook.verify_ssl = "/path/to/ca.crt"
+        task_post = requests_mock.post(
+            "http://druid-overlord:8081/druid/indexer/v1/task";,
+            text='{"task":"9f8a7359-77d4-4612-b0cd-cc2f6a3c28de"}',
+        )
+        status_check = requests_mock.get(
+            
"http://druid-overlord:8081/druid/indexer/v1/task/9f8a7359-77d4-4612-b0cd-cc2f6a3c28de/status";,
+            text='{"status":{"status": "SUCCESS"}}',
+        )
+
+        self.db_hook.submit_indexing_job("Long json file")
+
+        assert task_post.called_once
+        assert status_check.called_once
+        if task_post.called_once:
+            verify_ssl = task_post.request_history[0].verify
+            assert "/path/to/ca.crt" == verify_ssl
+
     def test_submit_correct_json_body(self, requests_mock):
         task_post = requests_mock.post(
             "http://druid-overlord:8081/druid/indexer/v1/task";,
diff --git a/tests/providers/apache/druid/operators/test_druid.py b/tests/providers/apache/druid/operators/test_druid.py
index f6fba6bffb..28f9632cd1 100644
--- a/tests/providers/apache/druid/operators/test_druid.py
+++ b/tests/providers/apache/druid/operators/test_druid.py
@@ -102,14 +102,22 @@ def test_init_with_timeout_and_max_ingestion_time():
     assert expected_values["max_ingestion_time"] == operator.max_ingestion_time
 
 
-def test_init_default_timeout():
+def test_init_defaults():
     operator = DruidOperator(
         task_id="spark_submit_job",
         json_index_file=JSON_INDEX_STR,
         params={"index_type": "index_hadoop", "datasource": "datasource_prd"},
     )
+    expected_default_druid_ingest_conn_id = "druid_ingest_default"
     expected_default_timeout = 1
+    expected_default_max_ingestion_time = None
+    expected_default_ingestion_type = IngestionType.BATCH
+    expected_default_verify_ssl = True
+    assert expected_default_druid_ingest_conn_id == operator.conn_id
     assert expected_default_timeout == operator.timeout
+    assert expected_default_max_ingestion_time == operator.max_ingestion_time
+    assert expected_default_ingestion_type == operator.ingestion_type
+    assert expected_default_verify_ssl == operator.verify_ssl
 
 
 @patch("airflow.providers.apache.druid.operators.druid.DruidHook")
@@ -120,6 +128,7 @@ def test_execute_calls_druid_hook_with_the_right_parameters(mock_druid_hook):
     druid_ingest_conn_id = "druid_ingest_default"
     max_ingestion_time = 5
     timeout = 5
+    verify_ssl = "/path/to/ca.crt"
     operator = DruidOperator(
         task_id="spark_submit_job",
         json_index_file=json_index_file,
@@ -127,11 +136,13 @@ def test_execute_calls_druid_hook_with_the_right_parameters(mock_druid_hook):
         timeout=timeout,
         ingestion_type=IngestionType.MSQ,
         max_ingestion_time=max_ingestion_time,
+        verify_ssl=verify_ssl,
     )
     operator.execute(context={})
     mock_druid_hook.assert_called_once_with(
         druid_ingest_conn_id=druid_ingest_conn_id,
         timeout=timeout,
         max_ingestion_time=max_ingestion_time,
+        verify_ssl=verify_ssl,
     )
     
mock_druid_hook_instance.submit_indexing_job.assert_called_once_with(json_index_file,
 IngestionType.MSQ)

Reply via email to