This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b3259877fa Add spark3-submit to list of allowed spark-binary values 
(#30068)
b3259877fa is described below

commit b3259877fac7330d2b65ca7f96fcfc27243582d6
Author: Andrew Otto <[email protected]>
AuthorDate: Wed Mar 15 14:17:54 2023 +0000

    Add spark3-submit to list of allowed spark-binary values (#30068)
    
    * Add spark3-submit to list of allowed spark-binary values
    
    The list of allowed values for spark-binary was restricted in
    apache/airflow#27646.  Add spark3-submit to this list to allow for 
distributions
    of Spark 3 that install the binary this way.
    
    See also apache/airflow#30065.
    
    * Fix lint errors in spark.rst and test_spark_submit.py
---
 .../providers/apache/spark/hooks/spark_submit.py   | 11 ++++----
 .../apache/spark/operators/spark_submit.py         |  2 +-
 .../connections/spark.rst                          |  2 +-
 .../apache/spark/hooks/test_spark_submit.py        | 33 ++++++++++++++++++++--
 4 files changed, 38 insertions(+), 10 deletions(-)

diff --git a/airflow/providers/apache/spark/hooks/spark_submit.py 
b/airflow/providers/apache/spark/hooks/spark_submit.py
index bfc08eda64..459a0916b9 100644
--- a/airflow/providers/apache/spark/hooks/spark_submit.py
+++ b/airflow/providers/apache/spark/hooks/spark_submit.py
@@ -33,7 +33,7 @@ from airflow.utils.log.logging_mixin import LoggingMixin
 with contextlib.suppress(ImportError, NameError):
     from airflow.kubernetes import kube_client
 
-ALLOWED_SPARK_BINARIES = ["spark-submit", "spark2-submit"]
+ALLOWED_SPARK_BINARIES = ["spark-submit", "spark2-submit", "spark3-submit"]
 
 
 class SparkSubmitHook(BaseHook, LoggingMixin):
@@ -78,7 +78,7 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
         supports yarn and k8s mode too.
     :param verbose: Whether to pass the verbose flag to spark-submit process 
for debugging
     :param spark_binary: The command to use for spark submit.
-                         Some distros may use spark2-submit.
+                         Some distros may use spark2-submit or spark3-submit.
     """
 
     conn_name_attr = "conn_id"
@@ -206,15 +206,16 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
             spark_binary = self._spark_binary or extra.get("spark-binary", 
"spark-submit")
             if spark_binary not in ALLOWED_SPARK_BINARIES:
                 raise RuntimeError(
-                    f"The `spark-binary` extra can be on of 
{ALLOWED_SPARK_BINARIES} and it"
+                    f"The `spark-binary` extra can be one of 
{ALLOWED_SPARK_BINARIES} and it"
                     f" was `{spark_binary}`. Please make sure your spark 
binary is one of the"
                     " allowed ones and that it is available on the PATH"
                 )
             conn_spark_home = extra.get("spark-home")
             if conn_spark_home:
                 raise RuntimeError(
-                    "The `spark-home` extra is not allowed any more. Please 
make sure your `spark-submit` or"
-                    " `spark2-submit` are available on the PATH."
+                    "The `spark-home` extra is not allowed any more. Please 
make sure one of"
+                    f" {ALLOWED_SPARK_BINARIES} is available on the PATH, and 
set `spark-binary`"
+                    " if needed."
                 )
             conn_data["spark_binary"] = spark_binary
             conn_data["namespace"] = extra.get("namespace")
diff --git a/airflow/providers/apache/spark/operators/spark_submit.py 
b/airflow/providers/apache/spark/operators/spark_submit.py
index b0b2961c73..f7ca92815a 100644
--- a/airflow/providers/apache/spark/operators/spark_submit.py
+++ b/airflow/providers/apache/spark/operators/spark_submit.py
@@ -69,7 +69,7 @@ class SparkSubmitOperator(BaseOperator):
     :param env_vars: Environment variables for spark-submit. It supports yarn 
and k8s mode too. (templated)
     :param verbose: Whether to pass the verbose flag to spark-submit process 
for debugging
     :param spark_binary: The command to use for spark submit.
-                         Some distros may use spark2-submit.
+                         Some distros may use spark2-submit or spark3-submit.
     """
 
     template_fields: Sequence[str] = (
diff --git a/docs/apache-airflow-providers-apache-spark/connections/spark.rst 
b/docs/apache-airflow-providers-apache-spark/connections/spark.rst
index 2d3247757a..be6981e82a 100644
--- a/docs/apache-airflow-providers-apache-spark/connections/spark.rst
+++ b/docs/apache-airflow-providers-apache-spark/connections/spark.rst
@@ -42,7 +42,7 @@ Extra (optional)
 
     * ``queue`` - The name of the YARN queue to which the application is 
submitted.
     * ``deploy-mode`` - Whether to deploy your driver on the worker nodes 
(cluster) or locally as an external client (client).
-    * ``spark-binary`` - The command to use for Spark submit. Some distros may 
use ``spark2-submit``. Default ``spark-submit``. Only ``spark-submit`` and 
``spark2-submit`` are allowed as value.
+    * ``spark-binary`` - The command to use for Spark submit. Some distros may 
use ``spark2-submit`` or ``spark3-submit``. Default ``spark-submit``. Only ``spark-submit``, 
``spark2-submit`` or ``spark3-submit`` are allowed as values.
     * ``namespace`` - Kubernetes namespace (``spark.kubernetes.namespace``) to 
divide cluster resources between multiple users (via resource quota).
 
 When specifying the connection in environment variable you should specify
diff --git a/tests/providers/apache/spark/hooks/test_spark_submit.py 
b/tests/providers/apache/spark/hooks/test_spark_submit.py
index 34c441edf1..874d0e3f8c 100644
--- a/tests/providers/apache/spark/hooks/test_spark_submit.py
+++ b/tests/providers/apache/spark/hooks/test_spark_submit.py
@@ -102,6 +102,14 @@ class TestSparkSubmitHook:
                 extra='{"spark-binary": "spark2-submit"}',
             )
         )
+        db.merge_conn(
+            Connection(
+                conn_id="spark_binary_set_spark3_submit",
+                conn_type="spark",
+                host="yarn",
+                extra='{"spark-binary": "spark3-submit"}',
+            )
+        )
         db.merge_conn(
             Connection(
                 conn_id="spark_custom_binary_set",
@@ -434,6 +442,25 @@ class TestSparkSubmitHook:
         assert connection == expected_spark_connection
         assert cmd[0] == "spark2-submit"
 
+    def 
test_resolve_connection_spark_binary_spark3_submit_set_connection(self):
+        # Given
+        hook = SparkSubmitHook(conn_id="spark_binary_set_spark3_submit")
+
+        # When
+        connection = hook._resolve_connection()
+        cmd = hook._build_spark_submit_command(self._spark_job_file)
+
+        # Then
+        expected_spark_connection = {
+            "master": "yarn",
+            "spark_binary": "spark3-submit",
+            "deploy_mode": None,
+            "queue": None,
+            "namespace": None,
+        }
+        assert connection == expected_spark_connection
+        assert cmd[0] == "spark3-submit"
+
     def 
test_resolve_connection_custom_spark_binary_not_allowed_runtime_error(self):
         with pytest.raises(RuntimeError):
             SparkSubmitHook(conn_id="spark_binary_set", 
spark_binary="another-custom-spark-submit")
@@ -448,7 +475,7 @@ class TestSparkSubmitHook:
 
     def test_resolve_connection_spark_binary_default_value_override(self):
         # Given
-        hook = SparkSubmitHook(conn_id="spark_binary_set", 
spark_binary="spark2-submit")
+        hook = SparkSubmitHook(conn_id="spark_binary_set", 
spark_binary="spark3-submit")
 
         # When
         connection = hook._resolve_connection()
@@ -457,13 +484,13 @@ class TestSparkSubmitHook:
         # Then
         expected_spark_connection = {
             "master": "yarn",
-            "spark_binary": "spark2-submit",
+            "spark_binary": "spark3-submit",
             "deploy_mode": None,
             "queue": None,
             "namespace": None,
         }
         assert connection == expected_spark_connection
-        assert cmd[0] == "spark2-submit"
+        assert cmd[0] == "spark3-submit"
 
     def test_resolve_connection_spark_binary_default_value(self):
         # Given

Reply via email to the commits mailing list.