This is an automated email from the ASF dual-hosted git repository.

rom pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 509428b1c1 Changed conf property from str to dict in SparkSqlOperator (#42835)
509428b1c1 is described below

commit 509428b1c1f3f0a639d79f0c9b02036b53d5e63c
Author: Aleksandar Milosevic <[email protected]>
AuthorDate: Tue Oct 22 07:38:51 2024 +0200

    Changed conf property from str to dict in SparkSqlOperator (#42835)
---
 .../providers/apache/spark/hooks/spark_sql.py      | 11 +++++--
 .../providers/apache/spark/operators/spark_sql.py  |  2 +-
 .../tests/apache/spark/hooks/test_spark_sql.py     | 36 +++++++++++++++++++++-
 3 files changed, 44 insertions(+), 5 deletions(-)

diff --git a/providers/src/airflow/providers/apache/spark/hooks/spark_sql.py b/providers/src/airflow/providers/apache/spark/hooks/spark_sql.py
index 4d5da04567..f31dcbaaa2 100644
--- a/providers/src/airflow/providers/apache/spark/hooks/spark_sql.py
+++ b/providers/src/airflow/providers/apache/spark/hooks/spark_sql.py
@@ -82,7 +82,7 @@ class SparkSqlHook(BaseHook):
     def __init__(
         self,
         sql: str,
-        conf: str | None = None,
+        conf: dict[str, Any] | str | None = None,
         conn_id: str = default_conn_name,
         total_executor_cores: int | None = None,
         executor_cores: int | None = None,
@@ -143,8 +143,13 @@ class SparkSqlHook(BaseHook):
         """
         connection_cmd = ["spark-sql"]
         if self._conf:
-            for conf_el in self._conf.split(","):
-                connection_cmd += ["--conf", conf_el]
+            conf = self._conf
+            if isinstance(conf, dict):
+                for key, value in conf.items():
+                    connection_cmd += ["--conf", f"{key}={value}"]
+            elif isinstance(conf, str):
+                for conf_el in conf.split(","):
+                    connection_cmd += ["--conf", conf_el]
         if self._total_executor_cores:
             connection_cmd += ["--total-executor-cores", 
str(self._total_executor_cores)]
         if self._executor_cores:
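
For illustration only (not part of the commit): a minimal standalone sketch of
the dict/str handling added above. build_conf_args is a hypothetical helper
name, and the conf values are made up.

    # Hypothetical helper mirroring the new branch in _prepare_command.
    def build_conf_args(conf):
        """Expand a conf given as a dict or a comma-separated str into --conf flags."""
        args = []
        if isinstance(conf, dict):
            # New form: each key/value pair becomes one "--conf key=value" flag.
            for key, value in conf.items():
                args += ["--conf", f"{key}={value}"]
        elif isinstance(conf, str):
            # Legacy form: comma-separated "key=value" pairs, passed through as-is.
            for conf_el in conf.split(","):
                args += ["--conf", conf_el]
        return args

    # Both forms yield the same flags (illustrative values):
    assert build_conf_args({"spark.sql.shuffle.partitions": "4"}) == build_conf_args(
        "spark.sql.shuffle.partitions=4"
    )
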
diff --git a/providers/src/airflow/providers/apache/spark/operators/spark_sql.py b/providers/src/airflow/providers/apache/spark/operators/spark_sql.py
index ba0227f453..d4e6fbc98c 100644
--- a/providers/src/airflow/providers/apache/spark/operators/spark_sql.py
+++ b/providers/src/airflow/providers/apache/spark/operators/spark_sql.py
@@ -63,7 +63,7 @@ class SparkSqlOperator(BaseOperator):
         self,
         *,
         sql: str,
-        conf: str | None = None,
+        conf: dict[str, Any] | str | None = None,
         conn_id: str = "spark_sql_default",
         total_executor_cores: int | None = None,
         executor_cores: int | None = None,
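
With this change SparkSqlOperator accepts either form. A hypothetical task
definition (the task id, query, and conf values are illustrative, not from the
commit); the dict below is equivalent to the legacy string
"spark.executor.memory=2g,spark.sql.shuffle.partitions=8":

    from airflow.providers.apache.spark.operators.spark_sql import SparkSqlOperator

    spark_sql_task = SparkSqlOperator(
        task_id="run_spark_sql",          # illustrative task id
        sql="SELECT COUNT(*) FROM logs",  # illustrative query
        conf={
            "spark.executor.memory": "2g",
            "spark.sql.shuffle.partitions": "8",
        },
        conn_id="spark_sql_default",
    )
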
diff --git a/providers/tests/apache/spark/hooks/test_spark_sql.py b/providers/tests/apache/spark/hooks/test_spark_sql.py
index 3ea9ab20e9..b6d9bd704c 100644
--- a/providers/tests/apache/spark/hooks/test_spark_sql.py
+++ b/providers/tests/apache/spark/hooks/test_spark_sql.py
@@ -42,6 +42,18 @@ def get_after(sentinel, iterable):
 
 class TestSparkSqlHook:
     _config = {
+        "conn_id": "spark_default",
+        "executor_cores": 4,
+        "executor_memory": "22g",
+        "keytab": "privileged_user.keytab",
+        "name": "spark-job",
+        "num_executors": 10,
+        "verbose": True,
+        "sql": " /path/to/sql/file.sql ",
+        "conf": {"key": "value", "PROP": "VALUE"},
+    }
+
+    _config_str = {
         "conn_id": "spark_default",
         "executor_cores": 4,
         "executor_memory": "22g",
@@ -78,7 +90,29 @@ class TestSparkSqlHook:
         assert self._config["sql"].strip() == sql_path
 
         # Check if all config settings are there
-        for key_value in self._config["conf"].split(","):
+        for k, v in self._config["conf"].items():
+            assert f"--conf {k}={v}" in cmd
+
+        if self._config["verbose"]:
+            assert "--verbose" in cmd
+
+    def test_build_command_with_str_conf(self):
+        hook = SparkSqlHook(**self._config_str)
+
+        # The subprocess requires an array but we build the cmd by joining on a space
+        cmd = " ".join(hook._prepare_command(""))
+
+        # Check all the parameters
+        assert f"--executor-cores {self._config_str['executor_cores']}" in cmd
+        assert f"--executor-memory {self._config_str['executor_memory']}" in 
cmd
+        assert f"--keytab {self._config_str['keytab']}" in cmd
+        assert f"--name {self._config_str['name']}" in cmd
+        assert f"--num-executors {self._config_str['num_executors']}" in cmd
+        sql_path = get_after("-f", hook._prepare_command(""))
+        assert self._config_str["sql"].strip() == sql_path
+
+        # Check if all config settings are there
+        for key_value in self._config_str["conf"].split(","):
             k, v = key_value.split("=")
             assert f"--conf {k}={v}" in cmd
 
