This is an automated email from the ASF dual-hosted git repository.
rom pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 509428b1c1 Changed conf property from str to dict in SparkSqlOperator (#42835)
509428b1c1 is described below
commit 509428b1c1f3f0a639d79f0c9b02036b53d5e63c
Author: Aleksandar Milosevic <[email protected]>
AuthorDate: Tue Oct 22 07:38:51 2024 +0200
Changed conf property from str to dict in SparkSqlOperator (#42835)
---
.../providers/apache/spark/hooks/spark_sql.py | 11 +++++--
.../providers/apache/spark/operators/spark_sql.py | 2 +-
.../tests/apache/spark/hooks/test_spark_sql.py | 36 +++++++++++++++++++++-
3 files changed, 44 insertions(+), 5 deletions(-)
diff --git a/providers/src/airflow/providers/apache/spark/hooks/spark_sql.py b/providers/src/airflow/providers/apache/spark/hooks/spark_sql.py
index 4d5da04567..f31dcbaaa2 100644
--- a/providers/src/airflow/providers/apache/spark/hooks/spark_sql.py
+++ b/providers/src/airflow/providers/apache/spark/hooks/spark_sql.py
@@ -82,7 +82,7 @@ class SparkSqlHook(BaseHook):
def __init__(
self,
sql: str,
- conf: str | None = None,
+ conf: dict[str, Any] | str | None = None,
conn_id: str = default_conn_name,
total_executor_cores: int | None = None,
executor_cores: int | None = None,
@@ -143,8 +143,13 @@ class SparkSqlHook(BaseHook):
"""
connection_cmd = ["spark-sql"]
if self._conf:
- for conf_el in self._conf.split(","):
- connection_cmd += ["--conf", conf_el]
+ conf = self._conf
+ if isinstance(conf, dict):
+ for key, value in conf.items():
+ connection_cmd += ["--conf", f"{key}={value}"]
+ elif isinstance(conf, str):
+ for conf_el in conf.split(","):
+ connection_cmd += ["--conf", conf_el]
if self._total_executor_cores:
connection_cmd += ["--total-executor-cores", str(self._total_executor_cores)]
if self._executor_cores:
diff --git a/providers/src/airflow/providers/apache/spark/operators/spark_sql.py b/providers/src/airflow/providers/apache/spark/operators/spark_sql.py
index ba0227f453..d4e6fbc98c 100644
--- a/providers/src/airflow/providers/apache/spark/operators/spark_sql.py
+++ b/providers/src/airflow/providers/apache/spark/operators/spark_sql.py
@@ -63,7 +63,7 @@ class SparkSqlOperator(BaseOperator):
self,
*,
sql: str,
- conf: str | None = None,
+ conf: dict[str, Any] | str | None = None,
conn_id: str = "spark_sql_default",
total_executor_cores: int | None = None,
executor_cores: int | None = None,
diff --git a/providers/tests/apache/spark/hooks/test_spark_sql.py b/providers/tests/apache/spark/hooks/test_spark_sql.py
index 3ea9ab20e9..b6d9bd704c 100644
--- a/providers/tests/apache/spark/hooks/test_spark_sql.py
+++ b/providers/tests/apache/spark/hooks/test_spark_sql.py
@@ -42,6 +42,18 @@ def get_after(sentinel, iterable):
class TestSparkSqlHook:
_config = {
+ "conn_id": "spark_default",
+ "executor_cores": 4,
+ "executor_memory": "22g",
+ "keytab": "privileged_user.keytab",
+ "name": "spark-job",
+ "num_executors": 10,
+ "verbose": True,
+ "sql": " /path/to/sql/file.sql ",
+ "conf": {"key": "value", "PROP": "VALUE"},
+ }
+
+ _config_str = {
"conn_id": "spark_default",
"executor_cores": 4,
"executor_memory": "22g",
@@ -78,7 +90,29 @@ class TestSparkSqlHook:
assert self._config["sql"].strip() == sql_path
# Check if all config settings are there
- for key_value in self._config["conf"].split(","):
+ for k, v in self._config["conf"].items():
+ assert f"--conf {k}={v}" in cmd
+
+ if self._config["verbose"]:
+ assert "--verbose" in cmd
+
+ def test_build_command_with_str_conf(self):
+ hook = SparkSqlHook(**self._config_str)
+
+    # The subprocess requires an array but we build the cmd by joining on a space
+ cmd = " ".join(hook._prepare_command(""))
+
+ # Check all the parameters
+ assert f"--executor-cores {self._config_str['executor_cores']}" in cmd
+    assert f"--executor-memory {self._config_str['executor_memory']}" in cmd
+ assert f"--keytab {self._config_str['keytab']}" in cmd
+ assert f"--name {self._config_str['name']}" in cmd
+ assert f"--num-executors {self._config_str['num_executors']}" in cmd
+ sql_path = get_after("-f", hook._prepare_command(""))
+ assert self._config_str["sql"].strip() == sql_path
+
+ # Check if all config settings are there
+ for key_value in self._config_str["conf"].split(","):
k, v = key_value.split("=")
assert f"--conf {k}={v}" in cmd