This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 1cb127b9fd Validate Hive Beeline parameters (#29502)
1cb127b9fd is described below
commit 1cb127b9fd22a7dc8e0b82cab8acb7cd4c317c9c
Author: Jarek Potiuk <[email protected]>
AuthorDate: Wed Feb 15 02:52:12 2023 +0100
Validate Hive Beeline parameters (#29502)
When beeline is used, the Hive connection parameters should be validated
in order to avoid unnecessary JDBC calls when they are invalid.
This PR raises an exception early in cases where the parameters
are not correct.
---
airflow/providers/apache/hive/hooks/hive.py | 17 +++++++++++++
tests/providers/apache/hive/hooks/test_hive.py | 33 +++++++++++++++++++++++++-
2 files changed, 49 insertions(+), 1 deletion(-)
diff --git a/airflow/providers/apache/hive/hooks/hive.py
b/airflow/providers/apache/hive/hooks/hive.py
index d083dba952..137bb17eaa 100644
--- a/airflow/providers/apache/hive/hooks/hive.py
+++ b/airflow/providers/apache/hive/hooks/hive.py
@@ -141,6 +141,7 @@ class HiveCliHook(BaseHook):
if self.use_beeline:
hive_bin = "beeline"
+ self._validate_beeline_parameters(conn)
jdbc_url = f"jdbc:hive2://{conn.host}:{conn.port}/{conn.schema}"
if conf.get("core", "security") == "kerberos":
template = conn.extra_dejson.get("principal",
"hive/[email protected]")
@@ -165,6 +166,22 @@ class HiveCliHook(BaseHook):
return [hive_bin] + cmd_extra + hive_params_list
+ def _validate_beeline_parameters(self, conn):
+ if ":" in conn.host or "/" in conn.host or ";" in conn.host:
+ raise Exception(
+ f"The host used in beeline command ({conn.host}) should not
contain ':/;' characters)"
+ )
+ try:
+ int_port = int(conn.port)
+ if int_port <= 0 or int_port > 65535:
+ raise Exception(f"The port used in beeline command
({conn.port}) should be in range 0-65535)")
+ except (ValueError, TypeError) as e:
+ raise Exception(f"The port used in beeline command ({conn.port})
should be a valid integer: {e})")
+ if ";" in conn.schema:
+ raise Exception(
+ f"The schema used in beeline command ({conn.schema}) should
not contain ';' character)"
+ )
+
@staticmethod
def _prepare_hiveconf(d: dict[Any, Any]) -> list[Any]:
"""
diff --git a/tests/providers/apache/hive/hooks/test_hive.py
b/tests/providers/apache/hive/hooks/test_hive.py
index 0a3439f400..a0d856eb60 100644
--- a/tests/providers/apache/hive/hooks/test_hive.py
+++ b/tests/providers/apache/hive/hooks/test_hive.py
@@ -29,7 +29,7 @@ from hmsclient import HMSClient
from airflow.exceptions import AirflowException
from airflow.models.connection import Connection
from airflow.models.dag import DAG
-from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook,
HiveServer2Hook
+from airflow.providers.apache.hive.hooks.hive import HiveCliHook,
HiveMetastoreHook, HiveServer2Hook
from airflow.secrets.environment_variables import CONN_ENV_PREFIX
from airflow.utils import timezone
from airflow.utils.operator_helpers import AIRFLOW_VAR_NAME_FORMAT_MAPPING
@@ -641,6 +641,37 @@ class TestHiveServer2Hook:
database="default",
)
+ @pytest.mark.parametrize(
+ "host, port, schema, message",
+ [
+ ("localhost", "10000", "default", None),
+ ("localhost:", "10000", "default", "The host used in beeline
command"),
+ (";ocalhost", "10000", "default", "The host used in beeline
command"),
+ (";ocalho/", "10000", "default", "The host used in beeline
command"),
+ ("localhost", "as", "default", "The port used in beeline command"),
+ ("localhost", "0;", "default", "The port used in beeline command"),
+ ("localhost", "10/", "default", "The port used in beeline
command"),
+ ("localhost", ":", "default", "The port used in beeline command"),
+ ("localhost", "-1", "default", "The port used in beeline command"),
+ ("localhost", "655536", "default", "The port used in beeline
command"),
+ ("localhost", "1234", "default;", "The schema used in beeline
command"),
+ ],
+ )
+ def test_get_conn_with_wrong_connection_parameters(self, host, port,
schema, message):
+ connection = Connection(
+ conn_id="test",
+ conn_type="hive",
+ host=host,
+ port=port,
+ schema=schema,
+ )
+ hook = HiveCliHook()
+ if message:
+ with pytest.raises(Exception, match=message):
+ hook._validate_beeline_parameters(connection)
+ else:
+ hook._validate_beeline_parameters(connection)
+
def test_get_records(self):
hook = MockHiveServer2Hook()
query = f"SELECT * FROM {self.table}"