This is an automated email from the ASF dual-hosted git repository.

bolke pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9103ea1702 Add support for Spark Connect to pyspark decorator (#35665)
9103ea1702 is described below

commit 9103ea170236f8761520aaa70656fcb010ea8e3e
Author: Bolke de Bruin <[email protected]>
AuthorDate: Thu Nov 16 09:51:00 2023 +0100

    Add support for Spark Connect to pyspark decorator (#35665)
    
    * Add support for Spark Connect to pyspark decorator
    
    In Apache Spark 3.4 Spark Connect was introduced which
    allows remote connectivity to a remote Spark cluster using
    the DataFrame API.
---
 .../providers/apache/spark/decorators/pyspark.py   |  27 +++--
 .../providers/apache/spark/hooks/spark_connect.py  |  99 +++++++++++++++++
 airflow/providers/apache/spark/provider.yaml       |   4 +
 .../connections/spark.rst                          |  11 +-
 .../decorators/pyspark.rst                         |  23 +++-
 generated/provider_dependencies.json               |   1 +
 .../apache/spark/decorators/test_pyspark.py        | 117 ++++++++++++++++++++-
 .../apache/spark/hooks/test_spark_connect.py       |  72 +++++++++++++
 8 files changed, 339 insertions(+), 15 deletions(-)

diff --git a/airflow/providers/apache/spark/decorators/pyspark.py 
b/airflow/providers/apache/spark/decorators/pyspark.py
index 6f576b03a2..c460d09f4f 100644
--- a/airflow/providers/apache/spark/decorators/pyspark.py
+++ b/airflow/providers/apache/spark/decorators/pyspark.py
@@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Callable, Sequence
 from airflow.decorators.base import DecoratedOperator, TaskDecorator, 
task_decorator_factory
 from airflow.hooks.base import BaseHook
 from airflow.operators.python import PythonOperator
+from airflow.providers.apache.spark.hooks.spark_connect import SparkConnectHook
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -73,34 +74,44 @@ class _PySparkDecoratedOperator(DecoratedOperator, 
PythonOperator):
         from pyspark import SparkConf
         from pyspark.sql import SparkSession
 
-        conf = SparkConf().setAppName(f"{self.dag_id}-{self.task_id}")
+        conf = SparkConf()
+        conf.set("spark.app.name", f"{self.dag_id}-{self.task_id}")
 
-        master = "local[*]"
+        url = "local[*]"
         if self.conn_id:
+            # we handle both spark connect and spark standalone
             conn = BaseHook.get_connection(self.conn_id)
-            if conn.port:
-                master = f"{conn.host}:{conn.port}"
+            if conn.conn_type == SparkConnectHook.conn_type:
+                url = SparkConnectHook(self.conn_id).get_connection_url()
+            elif conn.port:
+                url = f"{conn.host}:{conn.port}"
             elif conn.host:
-                master = conn.host
+                url = conn.host
 
             for key, value in conn.extra_dejson.items():
                 conf.set(key, value)
 
-        conf.setMaster(master)
+        # you cannot have both remote and master
+        if url.startswith("sc://"):
+            conf.set("spark.remote", url)
 
         # task can override connection config
         for key, value in self.config_kwargs.items():
             conf.set(key, value)
 
+        if not conf.get("spark.remote") and not conf.get("spark.master"):
+            conf.set("spark.master", url)
+
         spark = SparkSession.builder.config(conf=conf).getOrCreate()
-        sc = spark.sparkContext
 
         if not self.op_kwargs:
             self.op_kwargs = {}
 
         op_kwargs: dict[str, Any] = dict(self.op_kwargs)
         op_kwargs["spark"] = spark
-        op_kwargs["sc"] = sc
+
+        # spark context is not available when using spark connect
+        op_kwargs["sc"] = spark.sparkContext if not conf.get("spark.remote") 
else None
 
         self.op_kwargs = op_kwargs
         return super().execute(context)
diff --git a/airflow/providers/apache/spark/hooks/spark_connect.py 
b/airflow/providers/apache/spark/hooks/spark_connect.py
new file mode 100644
index 0000000000..29828b0b78
--- /dev/null
+++ b/airflow/providers/apache/spark/hooks/spark_connect.py
@@ -0,0 +1,99 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import Any
+from urllib.parse import quote, urlparse, urlunparse
+
+from airflow.hooks.base import BaseHook
+from airflow.utils.log.logging_mixin import LoggingMixin
+
+
+class SparkConnectHook(BaseHook, LoggingMixin):
+    """Hook for Spark Connect."""
+
+    # from pyspark's ChannelBuilder
+    PARAM_USE_SSL = "use_ssl"
+    PARAM_TOKEN = "token"
+    PARAM_USER_ID = "user_id"
+
+    conn_name_attr = "conn_id"
+    default_conn_name = "spark_connect_default"
+    conn_type = "spark_connect"
+    hook_name = "Spark Connect"
+
+    @staticmethod
+    def get_ui_field_behaviour() -> dict[str, Any]:
+        """Return custom field behaviour."""
+        return {
+            "hidden_fields": [
+                "schema",
+            ],
+            "relabeling": {"password": "Token", "login": "User ID"},
+        }
+
+    @staticmethod
+    def get_connection_form_widgets() -> dict[str, Any]:
+        """Return connection widgets to add to connection form."""
+        from flask_babel import lazy_gettext
+        from wtforms import BooleanField
+
+        return {
+            SparkConnectHook.PARAM_USE_SSL: BooleanField(lazy_gettext("Use 
SSL"), default=False),
+        }
+
+    def __init__(self, conn_id: str = default_conn_name) -> None:
+        super().__init__()
+        self._conn_id = conn_id
+
+    def get_connection_url(self) -> str:
+        conn = self.get_connection(self._conn_id)
+
+        host = conn.host
+        if conn.host.find("://") == -1:
+            host = f"sc://{conn.host}"
+        if conn.port:
+            host = f"{host}:{conn.port}"
+
+        url = urlparse(host)
+
+        if url.path:
+            raise ValueError(f"Path {url.path} is not supported in Spark Connect connection URL")
+
+        params = []
+
+        if conn.login:
+            
params.append(f"{SparkConnectHook.PARAM_USER_ID}={quote(conn.login)}")
+
+        if conn.password:
+            
params.append(f"{SparkConnectHook.PARAM_TOKEN}={quote(conn.password)}")
+
+        use_ssl = conn.extra_dejson.get(SparkConnectHook.PARAM_USE_SSL)
+        if use_ssl is not None:
+            
params.append(f"{SparkConnectHook.PARAM_USE_SSL}={quote(str(use_ssl))}")
+
+        return urlunparse(
+            (
+                "sc",
+                url.netloc,
+                "/",
+                ";".join(params),  # params
+                "",
+                url.fragment,
+            )
+        )
diff --git a/airflow/providers/apache/spark/provider.yaml 
b/airflow/providers/apache/spark/provider.yaml
index c7b58ec727..31f2728a5c 100644
--- a/airflow/providers/apache/spark/provider.yaml
+++ b/airflow/providers/apache/spark/provider.yaml
@@ -51,6 +51,7 @@ versions:
 dependencies:
   - apache-airflow>=2.5.0
   - pyspark
+  - grpcio-status
 
 integrations:
   - integration-name: Apache Spark
@@ -70,6 +71,7 @@ operators:
 hooks:
   - integration-name: Apache Spark
     python-modules:
+      - airflow.providers.apache.spark.hooks.spark_connect
       - airflow.providers.apache.spark.hooks.spark_jdbc
       - airflow.providers.apache.spark.hooks.spark_jdbc_script
       - airflow.providers.apache.spark.hooks.spark_sql
@@ -77,6 +79,8 @@ hooks:
 
 
 connection-types:
+  - hook-class-name: 
airflow.providers.apache.spark.hooks.spark_connect.SparkConnectHook
+    connection-type: spark_connect
   - hook-class-name: 
airflow.providers.apache.spark.hooks.spark_jdbc.SparkJDBCHook
     connection-type: spark_jdbc
   - hook-class-name: 
airflow.providers.apache.spark.hooks.spark_sql.SparkSqlHook
diff --git a/docs/apache-airflow-providers-apache-spark/connections/spark.rst 
b/docs/apache-airflow-providers-apache-spark/connections/spark.rst
index 28249f3af9..05b92ce75c 100644
--- a/docs/apache-airflow-providers-apache-spark/connections/spark.rst
+++ b/docs/apache-airflow-providers-apache-spark/connections/spark.rst
@@ -27,7 +27,7 @@ The Apache Spark connection type enables connection to Apache 
Spark.
 Default Connection IDs
 ----------------------
 
-Spark Submit and Spark JDBC hooks and operators use ``spark_default`` by 
default. Spark SQL hooks and operators point to ``spark_sql_default`` by 
default.
+Spark Submit and Spark JDBC hooks and operators use ``spark_default`` by 
default. Spark SQL hooks and operators point to ``spark_sql_default`` by 
default. The Spark Connect hook uses ``spark_connect_default`` by default.
 
 Configuring the Connection
 --------------------------
@@ -45,6 +45,15 @@ Extra (optional)
     * ``spark-binary`` - The command to use for Spark submit. Some distros may 
use ``spark2-submit``. Default ``spark-submit``. Only ``spark-submit``, 
``spark2-submit`` or ``spark3-submit`` are allowed as value.
     * ``namespace`` - Kubernetes namespace (``spark.kubernetes.namespace``) to 
divide cluster resources between multiple users (via resource quota).
 
+User ID (optional, only applies to Spark Connect)
+    The user ID to authenticate with the proxy.
+
+Token (optional, only applies to Spark Connect)
+    The token to authenticate with the proxy.
+
+Use SSL (optional, only applies to Spark Connect)
+    Whether to use SSL when connecting.
+
 When specifying the connection in environment variable you should specify
 it using URI syntax.
 
diff --git a/docs/apache-airflow-providers-apache-spark/decorators/pyspark.rst 
b/docs/apache-airflow-providers-apache-spark/decorators/pyspark.rst
index 28b51ec848..1755e079b9 100644
--- a/docs/apache-airflow-providers-apache-spark/decorators/pyspark.rst
+++ b/docs/apache-airflow-providers-apache-spark/decorators/pyspark.rst
@@ -23,7 +23,7 @@ PySpark Decorator
 =================
 
 Python callable wrapped within the ``@task.pyspark`` decorator
-is injected with a SparkContext object.
+is injected with a SparkSession and SparkContext object if available.
 
 Parameters
 ----------
@@ -49,3 +49,24 @@ that the ``spark`` and ``sc`` objects are injected into the 
function.
     :dedent: 4
     :start-after: [START task_pyspark]
     :end-before: [END task_pyspark]
+
+
+Spark Connect
+-------------
+
+In `Apache Spark 3.4 
<https://spark.apache.org/docs/latest/spark-connect-overview.html>`_,
+Spark Connect introduced a decoupled client-server architecture
+that allows remote connectivity to Spark clusters using the DataFrame API. 
Using
+Spark Connect is the preferred way in Airflow to make use of the PySpark 
decorator,
+because it does not require running the Spark driver on the same host as Airflow.
+To make use of Spark Connect, you prepend your host URL with ``sc://``. For example,
+``sc://spark-cluster:15002``.
+
+
+Authentication
+^^^^^^^^^^^^^^
+
+Spark Connect does not have built-in authentication. The gRPC HTTP/2 interface 
however
+allows the use of authentication to communicate with the Spark Connect server 
through
+authenticating proxies. To make use of authentication, make sure to create a 
``Spark Connect``
+connection and set the right credentials.
diff --git a/generated/provider_dependencies.json 
b/generated/provider_dependencies.json
index c7c939670a..0b234b4e12 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -190,6 +190,7 @@
   "apache.spark": {
     "deps": [
       "apache-airflow>=2.5.0",
+      "grpcio-status",
       "pyspark"
     ],
     "cross-providers-deps": [
diff --git a/tests/providers/apache/spark/decorators/test_pyspark.py 
b/tests/providers/apache/spark/decorators/test_pyspark.py
index ea307c9c3c..e312e09d40 100644
--- a/tests/providers/apache/spark/decorators/test_pyspark.py
+++ b/tests/providers/apache/spark/decorators/test_pyspark.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+from typing import Any
 from unittest import mock
 
 import pytest
@@ -27,6 +28,22 @@ from airflow.utils import db, timezone
 DEFAULT_DATE = timezone.datetime(2021, 9, 1)
 
 
+class FakeConfig:
+    data: dict[str, Any]
+
+    def __init__(self, data: dict[str, Any] | None = None):
+        if data:
+            self.data = data
+        else:
+            self.data = {}
+
+    def get(self, key: str, default: Any = None) -> Any:
+        return self.data.get(key, default)
+
+    def set(self, key: str, value: Any) -> None:
+        self.data[key] = value
+
+
 class TestPysparkDecorator:
     def setup_method(self):
         db.merge_conn(
@@ -38,14 +55,47 @@ class TestPysparkDecorator:
             )
         )
 
+        db.merge_conn(
+            Connection(
+                conn_id="spark-connect",
+                conn_type="spark",
+                host="sc://localhost",
+                extra="",
+            )
+        )
+
+        db.merge_conn(
+            Connection(
+                conn_id="spark-connect-auth",
+                conn_type="spark_connect",
+                host="sc://localhost",
+                password="1234",
+                login="connect",
+                extra={
+                    "use_ssl": True,
+                },
+            )
+        )
+
     @pytest.mark.db_test
-    @mock.patch("pyspark.SparkConf.setAppName")
+    @mock.patch("pyspark.SparkConf")
     @mock.patch("pyspark.sql.SparkSession")
     def test_pyspark_decorator_with_connection(self, spark_mock, conf_mock, 
dag_maker):
+        config = FakeConfig()
+
+        builder = mock.MagicMock()
+        spark_mock.builder.config.return_value = builder
+        builder.getOrCreate.return_value = builder
+        builder.sparkContext.return_value = builder
+
+        conf_mock.return_value = config
+
         @task.pyspark(conn_id="pyspark_local", 
config_kwargs={"spark.executor.memory": "2g"})
         def f(spark, sc):
             import random
 
+            assert spark is not None
+            assert sc is not None
             return [random.random() for _ in range(100)]
 
         with dag_maker():
@@ -55,14 +105,20 @@ class TestPysparkDecorator:
         ret.operator.run(start_date=dr.execution_date, 
end_date=dr.execution_date)
         ti = dr.get_task_instances()[0]
         assert len(ti.xcom_pull()) == 100
-        conf_mock().set.assert_called_with("spark.executor.memory", "2g")
-        conf_mock().setMaster.assert_called_once_with("spark://none")
+        assert config.get("spark.master") == "spark://none"
+        assert config.get("spark.executor.memory") == "2g"
+        assert config.get("spark.remote") is None
+        assert config.get("spark.app.name")
+
         spark_mock.builder.config.assert_called_once_with(conf=conf_mock())
 
     @pytest.mark.db_test
-    @mock.patch("pyspark.SparkConf.setAppName")
+    @mock.patch("pyspark.SparkConf")
     @mock.patch("pyspark.sql.SparkSession")
     def test_simple_pyspark_decorator(self, spark_mock, conf_mock, dag_maker):
+        config = FakeConfig()
+        conf_mock.return_value = config
+
         e = 2
 
         @task.pyspark
@@ -76,5 +132,56 @@ class TestPysparkDecorator:
         ret.operator.run(start_date=dr.execution_date, 
end_date=dr.execution_date)
         ti = dr.get_task_instances()[0]
         assert ti.xcom_pull() == e
-        conf_mock().setMaster.assert_called_once_with("local[*]")
+        assert config.get("spark.master") == "local[*]"
         spark_mock.builder.config.assert_called_once_with(conf=conf_mock())
+
+    @pytest.mark.db_test
+    @mock.patch("pyspark.SparkConf")
+    @mock.patch("pyspark.sql.SparkSession")
+    def test_spark_connect(self, spark_mock, conf_mock, dag_maker):
+        config = FakeConfig()
+        conf_mock.return_value = config
+
+        @task.pyspark(conn_id="spark-connect")
+        def f(spark, sc):
+            assert spark is not None
+            assert sc is None
+
+            return True
+
+        with dag_maker():
+            ret = f()
+
+        dr = dag_maker.create_dagrun()
+        ret.operator.run(start_date=dr.execution_date, 
end_date=dr.execution_date)
+        ti = dr.get_task_instances()[0]
+        assert ti.xcom_pull()
+        assert config.get("spark.remote") == "sc://localhost"
+        assert config.get("spark.master") is None
+        assert config.get("spark.app.name")
+        spark_mock.builder.config.assert_called_once_with(conf=conf_mock())
+
+    @pytest.mark.db_test
+    @mock.patch("pyspark.SparkConf")
+    @mock.patch("pyspark.sql.SparkSession")
+    def test_spark_connect_auth(self, spark_mock, conf_mock, dag_maker):
+        config = FakeConfig()
+        conf_mock.return_value = config
+
+        @task.pyspark(conn_id="spark-connect-auth")
+        def f(spark, sc):
+            assert spark is not None
+            assert sc is None
+
+            return True
+
+        with dag_maker():
+            ret = f()
+
+        dr = dag_maker.create_dagrun()
+        ret.operator.run(start_date=dr.execution_date, 
end_date=dr.execution_date)
+        ti = dr.get_task_instances()[0]
+        assert ti.xcom_pull()
+        assert config.get("spark.remote") == 
"sc://localhost/;user_id=connect;token=1234;use_ssl=True"
+        assert config.get("spark.master") is None
+        assert config.get("spark.app.name")
diff --git a/tests/providers/apache/spark/hooks/test_spark_connect.py 
b/tests/providers/apache/spark/hooks/test_spark_connect.py
new file mode 100644
index 0000000000..936432e955
--- /dev/null
+++ b/tests/providers/apache/spark/hooks/test_spark_connect.py
@@ -0,0 +1,72 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow.models import Connection
+from airflow.providers.apache.spark.hooks.spark_connect import SparkConnectHook
+from airflow.utils import db
+
+pytestmark = pytest.mark.db_test
+
+
+class TestSparkConnectHook:
+    def setup_method(self):
+        db.merge_conn(
+            Connection(
+                conn_id="spark-default",
+                conn_type="spark_connect",
+                host="sc://spark-host",
+                port=1000,
+                login="spark-user",
+                password="1234",
+                extra='{"queue": "root.etl", "deploy-mode": "cluster"}',
+            )
+        )
+
+        db.merge_conn(
+            Connection(
+                conn_id="spark-test",
+                conn_type="spark_connect",
+                host="nowhere",
+                login="spark-user",
+            )
+        )
+
+        db.merge_conn(
+            Connection(
+                conn_id="spark-app",
+                conn_type="spark_connect",
+                host="sc://cluster/app",
+                login="spark-user",
+            )
+        )
+
+    def test_get_connection_url(self):
+        expected_url = "sc://spark-host:1000/;user_id=spark-user;token=1234"
+        hook = SparkConnectHook(conn_id="spark-default")
+        assert hook.get_connection_url() == expected_url
+
+        expected_url = "sc://nowhere/;user_id=spark-user"
+        hook = SparkConnectHook(conn_id="spark-test")
+        assert hook.get_connection_url() == expected_url
+
+        hook = SparkConnectHook(conn_id="spark-app")
+        with pytest.raises(ValueError):
+            hook.get_connection_url()

Reply via email to