This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d24933c3f4 DruidHook add SQL-based task support (#32795)
d24933c3f4 is described below

commit d24933c3f49a14fcb4e215c78a8f43a568209688
Author: Vioao <[email protected]>
AuthorDate: Sun Aug 6 21:13:13 2023 +0800

    DruidHook add SQL-based task support (#32795)
    
    * DruidHook supports SQL-based task
    ---------
    
    Co-authored-by: vio.ao <[email protected]>
    Co-authored-by: Ashish Patel <[email protected]>
---
 airflow/providers/apache/druid/hooks/druid.py      | 33 ++++++++++++++++----
 airflow/providers/apache/druid/operators/druid.py  |  7 +++--
 .../operators.rst                                  |  1 +
 tests/providers/apache/druid/hooks/test_druid.py   | 35 ++++++++++++++++++++--
 .../providers/apache/druid/operators/test_druid.py | 27 +++++++++++++++++
 5 files changed, 93 insertions(+), 10 deletions(-)

diff --git a/airflow/providers/apache/druid/hooks/druid.py b/airflow/providers/apache/druid/hooks/druid.py
index 5b5c814fb5..7708684e60 100644
--- a/airflow/providers/apache/druid/hooks/druid.py
+++ b/airflow/providers/apache/druid/hooks/druid.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 import time
+from enum import Enum
 from typing import Any, Iterable
 
 import requests
@@ -28,6 +29,17 @@ from airflow.hooks.base import BaseHook
 from airflow.providers.common.sql.hooks.sql import DbApiHook
 
 
+class IngestionType(Enum):
+    """
+    Druid Ingestion Type. Could be Native batch ingestion or SQL-based ingestion.
+
+    https://druid.apache.org/docs/latest/ingestion/index.html
+    """
+
+    BATCH = 1
+    MSQ = 2
+
+
 class DruidHook(BaseHook):
     """
     Connection to Druid overlord for ingestion.
@@ -59,13 +71,16 @@ class DruidHook(BaseHook):
         if self.timeout < 1:
             raise ValueError("Druid timeout should be equal or greater than 1")
 
-    def get_conn_url(self) -> str:
+    def get_conn_url(self, ingestion_type: IngestionType = IngestionType.BATCH) -> str:
         """Get Druid connection url."""
         conn = self.get_connection(self.druid_ingest_conn_id)
         host = conn.host
         port = conn.port
         conn_type = conn.conn_type or "http"
-        endpoint = conn.extra_dejson.get("endpoint", "")
+        if ingestion_type == IngestionType.BATCH:
+            endpoint = conn.extra_dejson.get("endpoint", "")
+        else:
+            endpoint = conn.extra_dejson.get("msq_endpoint", "")
         return f"{conn_type}://{host}:{port}/{endpoint}"
 
     def get_auth(self) -> requests.auth.HTTPBasicAuth | None:
@@ -82,9 +97,11 @@ class DruidHook(BaseHook):
         else:
             return None
 
-    def submit_indexing_job(self, json_index_spec: dict[str, Any] | str) -> None:
+    def submit_indexing_job(
+        self, json_index_spec: dict[str, Any] | str, ingestion_type: IngestionType = IngestionType.BATCH
+    ) -> None:
         """Submit Druid ingestion job."""
-        url = self.get_conn_url()
+        url = self.get_conn_url(ingestion_type)
 
         self.log.info("Druid ingestion spec: %s", json_index_spec)
         req_index = requests.post(url, data=json_index_spec, headers=self.header, auth=self.get_auth())
@@ -96,14 +113,18 @@ class DruidHook(BaseHook):
 
         req_json = req_index.json()
         # Wait until the job is completed
-        druid_task_id = req_json["task"]
+        if ingestion_type == IngestionType.BATCH:
+            druid_task_id = req_json["task"]
+        else:
+            druid_task_id = req_json["taskId"]
+        druid_task_status_url = f"{self.get_conn_url()}/{druid_task_id}/status"
         self.log.info("Druid indexing task-id: %s", druid_task_id)
 
         running = True
 
         sec = 0
         while running:
-            req_status = requests.get(f"{url}/{druid_task_id}/status", auth=self.get_auth())
+            req_status = requests.get(druid_task_status_url, auth=self.get_auth())
 
             self.log.info("Job still running for %s seconds...", sec)
 
diff --git a/airflow/providers/apache/druid/operators/druid.py b/airflow/providers/apache/druid/operators/druid.py
index 7e1cd60799..080287e5ec 100644
--- a/airflow/providers/apache/druid/operators/druid.py
+++ b/airflow/providers/apache/druid/operators/druid.py
@@ -20,7 +20,7 @@ from __future__ import annotations
 from typing import TYPE_CHECKING, Any, Sequence
 
 from airflow.models import BaseOperator
-from airflow.providers.apache.druid.hooks.druid import DruidHook
+from airflow.providers.apache.druid.hooks.druid import DruidHook, IngestionType
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -36,6 +36,7 @@ class DruidOperator(BaseOperator):
     :param timeout: The interval (in seconds) between polling the Druid job for the status
         of the ingestion job. Must be greater than or equal to 1
     :param max_ingestion_time: The maximum ingestion time before assuming the job failed
+    :param ingestion_type: The ingestion type of the job. Could be IngestionType.Batch or IngestionType.MSQ
     """
 
     template_fields: Sequence[str] = ("json_index_file",)
@@ -49,6 +50,7 @@ class DruidOperator(BaseOperator):
         druid_ingest_conn_id: str = "druid_ingest_default",
         timeout: int = 1,
         max_ingestion_time: int | None = None,
+        ingestion_type: IngestionType = IngestionType.BATCH,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -56,6 +58,7 @@ class DruidOperator(BaseOperator):
         self.conn_id = druid_ingest_conn_id
         self.timeout = timeout
         self.max_ingestion_time = max_ingestion_time
+        self.ingestion_type = ingestion_type
 
     def execute(self, context: Context) -> None:
         hook = DruidHook(
@@ -64,4 +67,4 @@ class DruidOperator(BaseOperator):
             max_ingestion_time=self.max_ingestion_time,
         )
         self.log.info("Submitting %s", self.json_index_file)
-        hook.submit_indexing_job(self.json_index_file)
+        hook.submit_indexing_job(self.json_index_file, self.ingestion_type)
diff --git a/docs/apache-airflow-providers-apache-druid/operators.rst b/docs/apache-airflow-providers-apache-druid/operators.rst
index 1d2f1c022a..6930e7b4d3 100644
--- a/docs/apache-airflow-providers-apache-druid/operators.rst
+++ b/docs/apache-airflow-providers-apache-druid/operators.rst
@@ -29,6 +29,7 @@ DruidOperator
 -------------------
 
 Submit a task directly to Druid, you need to provide the filepath to the Druid index specification ``json_index_file``, and the connection id of the Druid overlord ``druid_ingest_conn_id`` which accepts index jobs in Airflow Connections.
+In addition, you can provide the ingestion type ``ingestion_type`` to determine whether the job is Batch Ingestion or SQL-based ingestion.
 
 There is also a example content of the Druid Ingestion specification below.
 
diff --git a/tests/providers/apache/druid/hooks/test_druid.py b/tests/providers/apache/druid/hooks/test_druid.py
index 7d97f857b4..38f26f06cc 100644
--- a/tests/providers/apache/druid/hooks/test_druid.py
+++ b/tests/providers/apache/druid/hooks/test_druid.py
@@ -23,7 +23,7 @@ import pytest
 import requests
 
 from airflow.exceptions import AirflowException
-from airflow.providers.apache.druid.hooks.druid import DruidDbApiHook, DruidHook
+from airflow.providers.apache.druid.hooks.druid import DruidDbApiHook, DruidHook, IngestionType
 
 
 class TestDruidHook:
@@ -35,7 +35,11 @@ class TestDruidHook:
         session.mount("mock", adapter)
 
         class TestDRuidhook(DruidHook):
-            def get_conn_url(self):
+            self.is_sql_based_ingestion = False
+
+            def get_conn_url(self, ingestion_type: IngestionType = IngestionType.BATCH):
+                if ingestion_type == IngestionType.MSQ:
+                    return "http://druid-overlord:8081/druid/v2/sql/task"
                 return "http://druid-overlord:8081/druid/indexer/v1/task"
 
         self.db_hook = TestDRuidhook()
@@ -73,6 +77,22 @@ class TestDruidHook:
         assert task_post.called_once
         assert status_check.called_once
 
+    def test_submit_sql_based_ingestion_ok(self, requests_mock):
+        task_post = requests_mock.post(
+            "http://druid-overlord:8081/druid/v2/sql/task",
+            text='{"taskId":"9f8a7359-77d4-4612-b0cd-cc2f6a3c28de"}',
+        )
+        status_check = requests_mock.get(
+            "http://druid-overlord:8081/druid/indexer/v1/task/9f8a7359-77d4-4612-b0cd-cc2f6a3c28de/status",
+            text='{"status":{"status": "SUCCESS"}}',
+        )
+
+        # Exists just as it should
+        self.db_hook.submit_indexing_job("Long json file", IngestionType.MSQ)
+
+        assert task_post.called_once
+        assert status_check.called_once
+
     def test_submit_correct_json_body(self, requests_mock):
         task_post = requests_mock.post(
             "http://druid-overlord:8081/druid/indexer/v1/task",
@@ -149,6 +169,17 @@ class TestDruidHook:
         hook = DruidHook(timeout=1, max_ingestion_time=5)
         assert hook.get_conn_url() == "https://test_host:1/ingest"
 
+    @patch("airflow.providers.apache.druid.hooks.druid.DruidHook.get_connection")
+    def test_get_conn_url_with_ingestion_type(self, mock_get_connection):
+        get_conn_value = MagicMock()
+        get_conn_value.host = "test_host"
+        get_conn_value.conn_type = "https"
+        get_conn_value.port = "1"
+        get_conn_value.extra_dejson = {"endpoint": "ingest", "msq_endpoint": "sql_ingest"}
+        mock_get_connection.return_value = get_conn_value
+        hook = DruidHook(timeout=1, max_ingestion_time=5)
+        assert hook.get_conn_url(IngestionType.MSQ) == "https://test_host:1/sql_ingest"
+
     @patch("airflow.providers.apache.druid.hooks.druid.DruidHook.get_connection")
     def test_get_auth(self, mock_get_connection):
         get_conn_value = MagicMock()
diff --git a/tests/providers/apache/druid/operators/test_druid.py b/tests/providers/apache/druid/operators/test_druid.py
index 30d14f4f1a..86f4650f70 100644
--- a/tests/providers/apache/druid/operators/test_druid.py
+++ b/tests/providers/apache/druid/operators/test_druid.py
@@ -18,7 +18,9 @@
 from __future__ import annotations
 
 import json
+from unittest.mock import MagicMock, patch
 
+from airflow.providers.apache.druid.hooks.druid import IngestionType
 from airflow.providers.apache.druid.operators.druid import DruidOperator
 from airflow.utils import timezone
 from airflow.utils.types import DagRunType
@@ -104,3 +106,28 @@ def test_init_default_timeout():
     )
     expected_default_timeout = 1
     assert expected_default_timeout == operator.timeout
+
+
+@patch("airflow.providers.apache.druid.operators.druid.DruidHook")
+def test_execute_calls_druid_hook_with_the_right_parameters(mock_druid_hook):
+    mock_druid_hook_instance = MagicMock()
+    mock_druid_hook.return_value = mock_druid_hook_instance
+    json_index_file = "sql.json"
+    druid_ingest_conn_id = "druid_ingest_default"
+    max_ingestion_time = 5
+    timeout = 5
+    operator = DruidOperator(
+        task_id="spark_submit_job",
+        json_index_file=json_index_file,
+        druid_ingest_conn_id=druid_ingest_conn_id,
+        timeout=timeout,
+        ingestion_type=IngestionType.MSQ,
+        max_ingestion_time=max_ingestion_time,
+    )
+    operator.execute(context={})
+    mock_druid_hook.assert_called_once_with(
+        druid_ingest_conn_id=druid_ingest_conn_id,
+        timeout=timeout,
+        max_ingestion_time=max_ingestion_time,
+    )
+    mock_druid_hook_instance.submit_indexing_job.assert_called_once_with(json_index_file, IngestionType.MSQ)

Reply via email to