This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5d1f201bb0 Only restrict spark binary passed via extra (#30213)
5d1f201bb0 is described below

commit 5d1f201bb0411d7060fd4fe49807fd49495f973e
Author: Jarek Potiuk <[email protected]>
AuthorDate: Wed Mar 22 10:34:20 2023 +0100

    Only restrict spark binary passed via extra (#30213)
    
    As discussed in #30064 - the security vulnerability fix from
    #27646 restricted the spark binaries a little too much (the
    binaries should be restricted only when passed via extra).
    
    This PR fixes it, spark submit is only restricted when passed
    via extra, you can still pass any binary via Hook parameter.
---
 .../providers/apache/spark/hooks/spark_submit.py   | 28 +++++++++-------------
 .../apache/spark/hooks/test_spark_submit.py        |  5 ++--
 2 files changed, 13 insertions(+), 20 deletions(-)

diff --git a/airflow/providers/apache/spark/hooks/spark_submit.py 
b/airflow/providers/apache/spark/hooks/spark_submit.py
index 459a0916b9..842df0e28a 100644
--- a/airflow/providers/apache/spark/hooks/spark_submit.py
+++ b/airflow/providers/apache/spark/hooks/spark_submit.py
@@ -150,14 +150,7 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
         self._submit_sp: Any | None = None
         self._yarn_application_id: str | None = None
         self._kubernetes_driver_pod: str | None = None
-        self._spark_binary = spark_binary
-        if self._spark_binary is not None and self._spark_binary not in 
ALLOWED_SPARK_BINARIES:
-            raise RuntimeError(
-                f"The spark-binary extra can be on of {ALLOWED_SPARK_BINARIES} 
and it"
-                f" was `{spark_binary}`. Please make sure your spark binary is 
one of the"
-                f" allowed ones and that it is available on the PATH"
-            )
-
+        self.spark_binary = spark_binary
         self._connection = self._resolve_connection()
         self._is_yarn = "yarn" in self._connection["master"]
         self._is_kubernetes = "k8s" in self._connection["master"]
@@ -186,7 +179,7 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
             "master": "yarn",
             "queue": None,
             "deploy_mode": None,
-            "spark_binary": self._spark_binary or "spark-submit",
+            "spark_binary": self.spark_binary or "spark-submit",
             "namespace": None,
         }
 
@@ -203,13 +196,14 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
             extra = conn.extra_dejson
             conn_data["queue"] = extra.get("queue")
             conn_data["deploy_mode"] = extra.get("deploy-mode")
-            spark_binary = self._spark_binary or extra.get("spark-binary", 
"spark-submit")
-            if spark_binary not in ALLOWED_SPARK_BINARIES:
-                raise RuntimeError(
-                    f"The `spark-binary` extra can be one of 
{ALLOWED_SPARK_BINARIES} and it"
-                    f" was `{spark_binary}`. Please make sure your spark 
binary is one of the"
-                    " allowed ones and that it is available on the PATH"
-                )
+            if not self.spark_binary:
+                self.spark_binary = extra.get("spark-binary", "spark-submit")
+                if self.spark_binary is not None and self.spark_binary not in 
ALLOWED_SPARK_BINARIES:
+                    raise RuntimeError(
+                        f"The spark-binary extra can be on of 
{ALLOWED_SPARK_BINARIES} and it"
+                        f" was `{self.spark_binary}`. Please make sure your 
spark binary is one of the"
+                        f" allowed ones and that it is available on the PATH"
+                    )
             conn_spark_home = extra.get("spark-home")
             if conn_spark_home:
                 raise RuntimeError(
@@ -217,7 +211,7 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
                     f" {ALLOWED_SPARK_BINARIES} is available on the PATH, and 
set `spark-binary`"
                     " if needed."
                 )
-            conn_data["spark_binary"] = spark_binary
+            conn_data["spark_binary"] = self.spark_binary
             conn_data["namespace"] = extra.get("namespace")
         except AirflowException:
             self.log.info(
diff --git a/tests/providers/apache/spark/hooks/test_spark_submit.py 
b/tests/providers/apache/spark/hooks/test_spark_submit.py
index 874d0e3f8c..052b15aeb2 100644
--- a/tests/providers/apache/spark/hooks/test_spark_submit.py
+++ b/tests/providers/apache/spark/hooks/test_spark_submit.py
@@ -461,9 +461,8 @@ class TestSparkSubmitHook:
         assert connection == expected_spark_connection
         assert cmd[0] == "spark3-submit"
 
-    def 
test_resolve_connection_custom_spark_binary_not_allowed_runtime_error(self):
-        with pytest.raises(RuntimeError):
-            SparkSubmitHook(conn_id="spark_binary_set", 
spark_binary="another-custom-spark-submit")
+    def test_resolve_connection_custom_spark_binary_allowed_in_hook(self):
+        SparkSubmitHook(conn_id="spark_binary_set", 
spark_binary="another-custom-spark-submit")
 
     def 
test_resolve_connection_spark_binary_extra_not_allowed_runtime_error(self):
         with pytest.raises(RuntimeError):

Reply via email to