This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1cb127b9fd Validate Hive Beeline parameters (#29502)
1cb127b9fd is described below

commit 1cb127b9fd22a7dc8e0b82cab8acb7cd4c317c9c
Author: Jarek Potiuk <[email protected]>
AuthorDate: Wed Feb 15 02:52:12 2023 +0100

    Validate Hive Beeline parameters (#29502)
    
    The parameters for Hive when beeline is used should be validated
    in order to avoid unnecessary jdbc calls when they are invalid.
    
    This PR raises an exception early when the parameters
    are not correct.
---
 airflow/providers/apache/hive/hooks/hive.py    | 17 +++++++++++++
 tests/providers/apache/hive/hooks/test_hive.py | 33 +++++++++++++++++++++++++-
 2 files changed, 49 insertions(+), 1 deletion(-)

diff --git a/airflow/providers/apache/hive/hooks/hive.py b/airflow/providers/apache/hive/hooks/hive.py
index d083dba952..137bb17eaa 100644
--- a/airflow/providers/apache/hive/hooks/hive.py
+++ b/airflow/providers/apache/hive/hooks/hive.py
@@ -141,6 +141,7 @@ class HiveCliHook(BaseHook):
 
         if self.use_beeline:
             hive_bin = "beeline"
+            self._validate_beeline_parameters(conn)
             jdbc_url = f"jdbc:hive2://{conn.host}:{conn.port}/{conn.schema}"
             if conf.get("core", "security") == "kerberos":
                 template = conn.extra_dejson.get("principal", "hive/_HOST@EXAMPLE.COM")
@@ -165,6 +166,22 @@ class HiveCliHook(BaseHook):
 
         return [hive_bin] + cmd_extra + hive_params_list
 
+    def _validate_beeline_parameters(self, conn):
+        if ":" in conn.host or "/" in conn.host or ";" in conn.host:
+            raise Exception(
+                f"The host used in beeline command ({conn.host}) should not contain ':/;' characters)"
+            )
+        try:
+            int_port = int(conn.port)
+            if int_port <= 0 or int_port > 65535:
+                raise Exception(f"The port used in beeline command ({conn.port}) should be in range 0-65535)")
+        except (ValueError, TypeError) as e:
+            raise Exception(f"The port used in beeline command ({conn.port}) should be a valid integer: {e})")
+        if ";" in conn.schema:
+            raise Exception(
+                f"The schema used in beeline command ({conn.schema}) should not contain ';' character)"
+            )
+
     @staticmethod
     def _prepare_hiveconf(d: dict[Any, Any]) -> list[Any]:
         """
diff --git a/tests/providers/apache/hive/hooks/test_hive.py b/tests/providers/apache/hive/hooks/test_hive.py
index 0a3439f400..a0d856eb60 100644
--- a/tests/providers/apache/hive/hooks/test_hive.py
+++ b/tests/providers/apache/hive/hooks/test_hive.py
@@ -29,7 +29,7 @@ from hmsclient import HMSClient
 from airflow.exceptions import AirflowException
 from airflow.models.connection import Connection
 from airflow.models.dag import DAG
-from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook, HiveServer2Hook
+from airflow.providers.apache.hive.hooks.hive import HiveCliHook, HiveMetastoreHook, HiveServer2Hook
 from airflow.secrets.environment_variables import CONN_ENV_PREFIX
 from airflow.utils import timezone
 from airflow.utils.operator_helpers import AIRFLOW_VAR_NAME_FORMAT_MAPPING
@@ -641,6 +641,37 @@ class TestHiveServer2Hook:
                 database="default",
             )
 
+    @pytest.mark.parametrize(
+        "host, port, schema, message",
+        [
+            ("localhost", "10000", "default", None),
+            ("localhost:", "10000", "default", "The host used in beeline command"),
+            (";ocalhost", "10000", "default", "The host used in beeline command"),
+            (";ocalho/", "10000", "default", "The host used in beeline command"),
+            ("localhost", "as", "default", "The port used in beeline command"),
+            ("localhost", "0;", "default", "The port used in beeline command"),
+            ("localhost", "10/", "default", "The port used in beeline command"),
+            ("localhost", ":", "default", "The port used in beeline command"),
+            ("localhost", "-1", "default", "The port used in beeline command"),
+            ("localhost", "655536", "default", "The port used in beeline command"),
+            ("localhost", "1234", "default;", "The schema used in beeline command"),
+        ],
+    )
+    def test_get_conn_with_wrong_connection_parameters(self, host, port, schema, message):
+        connection = Connection(
+            conn_id="test",
+            conn_type="hive",
+            host=host,
+            port=port,
+            schema=schema,
+        )
+        hook = HiveCliHook()
+        if message:
+            with pytest.raises(Exception, match=message):
+                hook._validate_beeline_parameters(connection)
+        else:
+            hook._validate_beeline_parameters(connection)
+
     def test_get_records(self):
         hook = MockHiveServer2Hook()
         query = f"SELECT * FROM {self.table}"

Reply via email to