This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d24933c3f4 DruidHook add SQL-based task support (#32795)
d24933c3f4 is described below
commit d24933c3f49a14fcb4e215c78a8f43a568209688
Author: Vioao <[email protected]>
AuthorDate: Sun Aug 6 21:13:13 2023 +0800
DruidHook add SQL-based task support (#32795)
* DruidHook supports SQL-based task
---------
Co-authored-by: vio.ao <[email protected]>
Co-authored-by: Ashish Patel <[email protected]>
---
airflow/providers/apache/druid/hooks/druid.py | 33 ++++++++++++++++----
airflow/providers/apache/druid/operators/druid.py | 7 +++--
.../operators.rst | 1 +
tests/providers/apache/druid/hooks/test_druid.py | 35 ++++++++++++++++++++--
.../providers/apache/druid/operators/test_druid.py | 27 +++++++++++++++++
5 files changed, 93 insertions(+), 10 deletions(-)
diff --git a/airflow/providers/apache/druid/hooks/druid.py b/airflow/providers/apache/druid/hooks/druid.py
index 5b5c814fb5..7708684e60 100644
--- a/airflow/providers/apache/druid/hooks/druid.py
+++ b/airflow/providers/apache/druid/hooks/druid.py
@@ -18,6 +18,7 @@
from __future__ import annotations
import time
+from enum import Enum
from typing import Any, Iterable
import requests
@@ -28,6 +29,17 @@ from airflow.hooks.base import BaseHook
from airflow.providers.common.sql.hooks.sql import DbApiHook
+class IngestionType(Enum):
+ """
+ Druid Ingestion Type. Could be Native batch ingestion or SQL-based ingestion.
+
+ https://druid.apache.org/docs/latest/ingestion/index.html
+ """
+
+ BATCH = 1
+ MSQ = 2
+
+
class DruidHook(BaseHook):
"""
Connection to Druid overlord for ingestion.
@@ -59,13 +71,16 @@ class DruidHook(BaseHook):
if self.timeout < 1:
raise ValueError("Druid timeout should be equal or greater than 1")
- def get_conn_url(self) -> str:
+ def get_conn_url(self, ingestion_type: IngestionType = IngestionType.BATCH) -> str:
"""Get Druid connection url."""
conn = self.get_connection(self.druid_ingest_conn_id)
host = conn.host
port = conn.port
conn_type = conn.conn_type or "http"
- endpoint = conn.extra_dejson.get("endpoint", "")
+ if ingestion_type == IngestionType.BATCH:
+ endpoint = conn.extra_dejson.get("endpoint", "")
+ else:
+ endpoint = conn.extra_dejson.get("msq_endpoint", "")
return f"{conn_type}://{host}:{port}/{endpoint}"
def get_auth(self) -> requests.auth.HTTPBasicAuth | None:
@@ -82,9 +97,11 @@ class DruidHook(BaseHook):
else:
return None
- def submit_indexing_job(self, json_index_spec: dict[str, Any] | str) -> None:
+ def submit_indexing_job(
+ self, json_index_spec: dict[str, Any] | str, ingestion_type: IngestionType = IngestionType.BATCH
+ ) -> None:
"""Submit Druid ingestion job."""
- url = self.get_conn_url()
+ url = self.get_conn_url(ingestion_type)
self.log.info("Druid ingestion spec: %s", json_index_spec)
req_index = requests.post(url, data=json_index_spec, headers=self.header, auth=self.get_auth())
@@ -96,14 +113,18 @@ class DruidHook(BaseHook):
req_json = req_index.json()
# Wait until the job is completed
- druid_task_id = req_json["task"]
+ if ingestion_type == IngestionType.BATCH:
+ druid_task_id = req_json["task"]
+ else:
+ druid_task_id = req_json["taskId"]
+ druid_task_status_url = f"{self.get_conn_url()}/{druid_task_id}/status"
self.log.info("Druid indexing task-id: %s", druid_task_id)
running = True
sec = 0
while running:
- req_status = requests.get(f"{url}/{druid_task_id}/status", auth=self.get_auth())
+ req_status = requests.get(druid_task_status_url, auth=self.get_auth())
self.log.info("Job still running for %s seconds...", sec)
diff --git a/airflow/providers/apache/druid/operators/druid.py b/airflow/providers/apache/druid/operators/druid.py
index 7e1cd60799..080287e5ec 100644
--- a/airflow/providers/apache/druid/operators/druid.py
+++ b/airflow/providers/apache/druid/operators/druid.py
@@ -20,7 +20,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Any, Sequence
from airflow.models import BaseOperator
-from airflow.providers.apache.druid.hooks.druid import DruidHook
+from airflow.providers.apache.druid.hooks.druid import DruidHook, IngestionType
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -36,6 +36,7 @@ class DruidOperator(BaseOperator):
:param timeout: The interval (in seconds) between polling the Druid job for the status
of the ingestion job. Must be greater than or equal to 1
:param max_ingestion_time: The maximum ingestion time before assuming the job failed
+ :param ingestion_type: The ingestion type of the job. Could be IngestionType.BATCH or IngestionType.MSQ
"""
template_fields: Sequence[str] = ("json_index_file",)
@@ -49,6 +50,7 @@ class DruidOperator(BaseOperator):
druid_ingest_conn_id: str = "druid_ingest_default",
timeout: int = 1,
max_ingestion_time: int | None = None,
+ ingestion_type: IngestionType = IngestionType.BATCH,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
@@ -56,6 +58,7 @@ class DruidOperator(BaseOperator):
self.conn_id = druid_ingest_conn_id
self.timeout = timeout
self.max_ingestion_time = max_ingestion_time
+ self.ingestion_type = ingestion_type
def execute(self, context: Context) -> None:
hook = DruidHook(
@@ -64,4 +67,4 @@ class DruidOperator(BaseOperator):
max_ingestion_time=self.max_ingestion_time,
)
self.log.info("Submitting %s", self.json_index_file)
- hook.submit_indexing_job(self.json_index_file)
+ hook.submit_indexing_job(self.json_index_file, self.ingestion_type)
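A hedged usage sketch of the extended operator follows; the dag id, schedule, file path, and max_ingestion_time are placeholders for illustration only:

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.apache.druid.hooks.druid import IngestionType
    from airflow.providers.apache.druid.operators.druid import DruidOperator

    with DAG(dag_id="example_druid_msq", start_date=datetime(2023, 8, 1), schedule=None):
        submit_sql_task = DruidOperator(
            task_id="submit_sql_task",
            json_index_file="path/to/sql_task.json",  # templated field; placeholder path
            druid_ingest_conn_id="druid_ingest_default",
            ingestion_type=IngestionType.MSQ,  # omit to keep the default IngestionType.BATCH
            max_ingestion_time=3600,
        )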
diff --git a/docs/apache-airflow-providers-apache-druid/operators.rst b/docs/apache-airflow-providers-apache-druid/operators.rst
index 1d2f1c022a..6930e7b4d3 100644
--- a/docs/apache-airflow-providers-apache-druid/operators.rst
+++ b/docs/apache-airflow-providers-apache-druid/operators.rst
@@ -29,6 +29,7 @@ DruidOperator
-------------------
To submit a task directly to Druid, you need to provide the filepath to the Druid index specification ``json_index_file``, and the connection id of the Druid overlord that accepts index jobs, ``druid_ingest_conn_id``, defined in Airflow Connections.
+In addition, you can provide the ingestion type ``ingestion_type`` to determine whether the job uses native batch ingestion or SQL-based ingestion.
There is also an example of the Druid Ingestion specification below.
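The endpoints themselves still come from the overlord connection. Here is a minimal sketch of a connection whose Extra carries both endpoints read by ``get_conn_url()``; the conn_id and host are placeholders, and the paths mirror the ones exercised in the tests further down:

    from airflow.models.connection import Connection

    druid_conn = Connection(
        conn_id="druid_ingest_default",
        conn_type="http",
        host="druid-overlord",
        port=8081,
        # "endpoint" is used for IngestionType.BATCH, "msq_endpoint" for IngestionType.MSQ.
        extra='{"endpoint": "druid/indexer/v1/task", "msq_endpoint": "druid/v2/sql/task"}',
    )

    # get_conn_url() then resolves to:
    #   http://druid-overlord:8081/druid/indexer/v1/task  (BATCH, the default)
    #   http://druid-overlord:8081/druid/v2/sql/task      (MSQ)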
diff --git a/tests/providers/apache/druid/hooks/test_druid.py b/tests/providers/apache/druid/hooks/test_druid.py
index 7d97f857b4..38f26f06cc 100644
--- a/tests/providers/apache/druid/hooks/test_druid.py
+++ b/tests/providers/apache/druid/hooks/test_druid.py
@@ -23,7 +23,7 @@ import pytest
import requests
from airflow.exceptions import AirflowException
-from airflow.providers.apache.druid.hooks.druid import DruidDbApiHook, DruidHook
+from airflow.providers.apache.druid.hooks.druid import DruidDbApiHook, DruidHook, IngestionType
class TestDruidHook:
@@ -35,7 +35,11 @@ class TestDruidHook:
session.mount("mock", adapter)
class TestDRuidhook(DruidHook):
- def get_conn_url(self):
+ self.is_sql_based_ingestion = False
+
+ def get_conn_url(self, ingestion_type: IngestionType = IngestionType.BATCH):
+ if ingestion_type == IngestionType.MSQ:
+ return "http://druid-overlord:8081/druid/v2/sql/task"
return "http://druid-overlord:8081/druid/indexer/v1/task"
self.db_hook = TestDRuidhook()
@@ -73,6 +77,22 @@ class TestDruidHook:
assert task_post.called_once
assert status_check.called_once
+ def test_submit_sql_based_ingestion_ok(self, requests_mock):
+ task_post = requests_mock.post(
+ "http://druid-overlord:8081/druid/v2/sql/task",
+ text='{"taskId":"9f8a7359-77d4-4612-b0cd-cc2f6a3c28de"}',
+ )
+ status_check = requests_mock.get(
+ "http://druid-overlord:8081/druid/indexer/v1/task/9f8a7359-77d4-4612-b0cd-cc2f6a3c28de/status",
+ text='{"status":{"status": "SUCCESS"}}',
+ )
+
+ # Exists just as it should
+ self.db_hook.submit_indexing_job("Long json file", IngestionType.MSQ)
+
+ assert task_post.called_once
+ assert status_check.called_once
+
def test_submit_correct_json_body(self, requests_mock):
task_post = requests_mock.post(
"http://druid-overlord:8081/druid/indexer/v1/task",
@@ -149,6 +169,17 @@ class TestDruidHook:
hook = DruidHook(timeout=1, max_ingestion_time=5)
assert hook.get_conn_url() == "https://test_host:1/ingest"
+ @patch("airflow.providers.apache.druid.hooks.druid.DruidHook.get_connection")
+ def test_get_conn_url_with_ingestion_type(self, mock_get_connection):
+ get_conn_value = MagicMock()
+ get_conn_value.host = "test_host"
+ get_conn_value.conn_type = "https"
+ get_conn_value.port = "1"
+ get_conn_value.extra_dejson = {"endpoint": "ingest", "msq_endpoint": "sql_ingest"}
+ mock_get_connection.return_value = get_conn_value
+ hook = DruidHook(timeout=1, max_ingestion_time=5)
+ assert hook.get_conn_url(IngestionType.MSQ) == "https://test_host:1/sql_ingest"
+
@patch("airflow.providers.apache.druid.hooks.druid.DruidHook.get_connection")
def test_get_auth(self, mock_get_connection):
get_conn_value = MagicMock()
diff --git a/tests/providers/apache/druid/operators/test_druid.py b/tests/providers/apache/druid/operators/test_druid.py
index 30d14f4f1a..86f4650f70 100644
--- a/tests/providers/apache/druid/operators/test_druid.py
+++ b/tests/providers/apache/druid/operators/test_druid.py
@@ -18,7 +18,9 @@
from __future__ import annotations
import json
+from unittest.mock import MagicMock, patch
+from airflow.providers.apache.druid.hooks.druid import IngestionType
from airflow.providers.apache.druid.operators.druid import DruidOperator
from airflow.utils import timezone
from airflow.utils.types import DagRunType
@@ -104,3 +106,28 @@ def test_init_default_timeout():
)
expected_default_timeout = 1
assert expected_default_timeout == operator.timeout
+
+
+@patch("airflow.providers.apache.druid.operators.druid.DruidHook")
+def test_execute_calls_druid_hook_with_the_right_parameters(mock_druid_hook):
+ mock_druid_hook_instance = MagicMock()
+ mock_druid_hook.return_value = mock_druid_hook_instance
+ json_index_file = "sql.json"
+ druid_ingest_conn_id = "druid_ingest_default"
+ max_ingestion_time = 5
+ timeout = 5
+ operator = DruidOperator(
+ task_id="spark_submit_job",
+ json_index_file=json_index_file,
+ druid_ingest_conn_id=druid_ingest_conn_id,
+ timeout=timeout,
+ ingestion_type=IngestionType.MSQ,
+ max_ingestion_time=max_ingestion_time,
+ )
+ operator.execute(context={})
+ mock_druid_hook.assert_called_once_with(
+ druid_ingest_conn_id=druid_ingest_conn_id,
+ timeout=timeout,
+ max_ingestion_time=max_ingestion_time,
+ )
+ mock_druid_hook_instance.submit_indexing_job.assert_called_once_with(json_index_file, IngestionType.MSQ)