This is an automated email from the ASF dual-hosted git repository.
bolke pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 9103ea1702 Add support for Spark Connect to pyspark decorator (#35665)
9103ea1702 is described below
commit 9103ea170236f8761520aaa70656fcb010ea8e3e
Author: Bolke de Bruin <[email protected]>
AuthorDate: Thu Nov 16 09:51:00 2023 +0100
Add support for Spark Connect to pyspark decorator (#35665)
* Add support for Spark Connect to pyspark decorator
In Apache Spark 3.4, Spark Connect was introduced, which
allows remote connectivity to a Spark cluster using
the DataFrame API.
---
.../providers/apache/spark/decorators/pyspark.py | 27 +++--
.../providers/apache/spark/hooks/spark_connect.py | 99 +++++++++++++++++
airflow/providers/apache/spark/provider.yaml | 4 +
.../connections/spark.rst | 11 +-
.../decorators/pyspark.rst | 23 +++-
generated/provider_dependencies.json | 1 +
.../apache/spark/decorators/test_pyspark.py | 117 ++++++++++++++++++++-
.../apache/spark/hooks/test_spark_connect.py | 72 +++++++++++++
8 files changed, 339 insertions(+), 15 deletions(-)
diff --git a/airflow/providers/apache/spark/decorators/pyspark.py
b/airflow/providers/apache/spark/decorators/pyspark.py
index 6f576b03a2..c460d09f4f 100644
--- a/airflow/providers/apache/spark/decorators/pyspark.py
+++ b/airflow/providers/apache/spark/decorators/pyspark.py
@@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Callable, Sequence
from airflow.decorators.base import DecoratedOperator, TaskDecorator,
task_decorator_factory
from airflow.hooks.base import BaseHook
from airflow.operators.python import PythonOperator
+from airflow.providers.apache.spark.hooks.spark_connect import SparkConnectHook
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -73,34 +74,44 @@ class _PySparkDecoratedOperator(DecoratedOperator,
PythonOperator):
from pyspark import SparkConf
from pyspark.sql import SparkSession
- conf = SparkConf().setAppName(f"{self.dag_id}-{self.task_id}")
+ conf = SparkConf()
+ conf.set("spark.app.name", f"{self.dag_id}-{self.task_id}")
- master = "local[*]"
+ url = "local[*]"
if self.conn_id:
+ # we handle both spark connect and spark standalone
conn = BaseHook.get_connection(self.conn_id)
- if conn.port:
- master = f"{conn.host}:{conn.port}"
+ if conn.conn_type == SparkConnectHook.conn_type:
+ url = SparkConnectHook(self.conn_id).get_connection_url()
+ elif conn.port:
+ url = f"{conn.host}:{conn.port}"
elif conn.host:
- master = conn.host
+ url = conn.host
for key, value in conn.extra_dejson.items():
conf.set(key, value)
- conf.setMaster(master)
+ # you cannot have both remote and master
+ if url.startswith("sc://"):
+ conf.set("spark.remote", url)
# task can override connection config
for key, value in self.config_kwargs.items():
conf.set(key, value)
+ if not conf.get("spark.remote") and not conf.get("spark.master"):
+ conf.set("spark.master", url)
+
spark = SparkSession.builder.config(conf=conf).getOrCreate()
- sc = spark.sparkContext
if not self.op_kwargs:
self.op_kwargs = {}
op_kwargs: dict[str, Any] = dict(self.op_kwargs)
op_kwargs["spark"] = spark
- op_kwargs["sc"] = sc
+
+ # spark context is not available when using spark connect
+ op_kwargs["sc"] = spark.sparkContext if not conf.get("spark.remote")
else None
self.op_kwargs = op_kwargs
return super().execute(context)
diff --git a/airflow/providers/apache/spark/hooks/spark_connect.py
b/airflow/providers/apache/spark/hooks/spark_connect.py
new file mode 100644
index 0000000000..29828b0b78
--- /dev/null
+++ b/airflow/providers/apache/spark/hooks/spark_connect.py
@@ -0,0 +1,99 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import Any
+from urllib.parse import quote, urlparse, urlunparse
+
+from airflow.hooks.base import BaseHook
+from airflow.utils.log.logging_mixin import LoggingMixin
+
+
+class SparkConnectHook(BaseHook, LoggingMixin):
+ """Hook for Spark Connect."""
+
+ # from pyspark's ChannelBuilder
+ PARAM_USE_SSL = "use_ssl"
+ PARAM_TOKEN = "token"
+ PARAM_USER_ID = "user_id"
+
+ conn_name_attr = "conn_id"
+ default_conn_name = "spark_connect_default"
+ conn_type = "spark_connect"
+ hook_name = "Spark Connect"
+
+ @staticmethod
+ def get_ui_field_behaviour() -> dict[str, Any]:
+ """Return custom field behaviour."""
+ return {
+ "hidden_fields": [
+ "schema",
+ ],
+ "relabeling": {"password": "Token", "login": "User ID"},
+ }
+
+ @staticmethod
+ def get_connection_form_widgets() -> dict[str, Any]:
+ """Returns connection widgets to add to connection form."""
+ from flask_babel import lazy_gettext
+ from wtforms import BooleanField
+
+ return {
+ SparkConnectHook.PARAM_USE_SSL: BooleanField(lazy_gettext("Use
SSL"), default=False),
+ }
+
+ def __init__(self, conn_id: str = default_conn_name) -> None:
+ super().__init__()
+ self._conn_id = conn_id
+
+ def get_connection_url(self) -> str:
+ conn = self.get_connection(self._conn_id)
+
+ host = conn.host
+ if conn.host.find("://") == -1:
+ host = f"sc://{conn.host}"
+ if conn.port:
+ host = f"{conn.host}:{conn.port}"
+
+ url = urlparse(host)
+
+ if url.path:
+ raise ValueError(f"Path {url.path} is not supported in Spark
Connect connection URL")
+
+ params = []
+
+ if conn.login:
+
params.append(f"{SparkConnectHook.PARAM_USER_ID}={quote(conn.login)}")
+
+ if conn.password:
+
params.append(f"{SparkConnectHook.PARAM_TOKEN}={quote(conn.password)}")
+
+ use_ssl = conn.extra_dejson.get(SparkConnectHook.PARAM_USE_SSL)
+ if use_ssl is not None:
+
params.append(f"{SparkConnectHook.PARAM_USE_SSL}={quote(str(use_ssl))}")
+
+ return urlunparse(
+ (
+ "sc",
+ url.netloc,
+ "/",
+ ";".join(params), # params
+ "",
+ url.fragment,
+ )
+ )
diff --git a/airflow/providers/apache/spark/provider.yaml
b/airflow/providers/apache/spark/provider.yaml
index c7b58ec727..31f2728a5c 100644
--- a/airflow/providers/apache/spark/provider.yaml
+++ b/airflow/providers/apache/spark/provider.yaml
@@ -51,6 +51,7 @@ versions:
dependencies:
- apache-airflow>=2.5.0
- pyspark
+ - grpcio-status
integrations:
- integration-name: Apache Spark
@@ -70,6 +71,7 @@ operators:
hooks:
- integration-name: Apache Spark
python-modules:
+ - airflow.providers.apache.spark.hooks.spark_connect
- airflow.providers.apache.spark.hooks.spark_jdbc
- airflow.providers.apache.spark.hooks.spark_jdbc_script
- airflow.providers.apache.spark.hooks.spark_sql
@@ -77,6 +79,8 @@ hooks:
connection-types:
+ - hook-class-name:
airflow.providers.apache.spark.hooks.spark_connect.SparkConnectHook
+ connection-type: spark_connect
- hook-class-name:
airflow.providers.apache.spark.hooks.spark_jdbc.SparkJDBCHook
connection-type: spark_jdbc
- hook-class-name:
airflow.providers.apache.spark.hooks.spark_sql.SparkSqlHook
diff --git a/docs/apache-airflow-providers-apache-spark/connections/spark.rst
b/docs/apache-airflow-providers-apache-spark/connections/spark.rst
index 28249f3af9..05b92ce75c 100644
--- a/docs/apache-airflow-providers-apache-spark/connections/spark.rst
+++ b/docs/apache-airflow-providers-apache-spark/connections/spark.rst
@@ -27,7 +27,7 @@ The Apache Spark connection type enables connection to Apache
Spark.
Default Connection IDs
----------------------
-Spark Submit and Spark JDBC hooks and operators use ``spark_default`` by
default. Spark SQL hooks and operators point to ``spark_sql_default`` by
default.
+Spark Submit and Spark JDBC hooks and operators use ``spark_default`` by
default. Spark SQL hooks and operators point to ``spark_sql_default`` by
default. The Spark Connect hook uses ``spark_connect_default`` by default.
Configuring the Connection
--------------------------
@@ -45,6 +45,15 @@ Extra (optional)
* ``spark-binary`` - The command to use for Spark submit. Some distros may
use ``spark2-submit``. Default ``spark-submit``. Only ``spark-submit``,
``spark2-submit`` or ``spark3-submit`` are allowed as value.
* ``namespace`` - Kubernetes namespace (``spark.kubernetes.namespace``) to
divide cluster resources between multiple users (via resource quota).
+User ID (optional, only applies to Spark Connect)
+ The user ID to authenticate with the proxy.
+
+Token (optional, only applies to Spark Connect)
+ The token to authenticate with the proxy.
+
+Use SSL (optional, only applies to Spark Connect)
+ Whether to use SSL when connecting.
+
When specifying the connection in environment variable you should specify
it using URI syntax.
diff --git a/docs/apache-airflow-providers-apache-spark/decorators/pyspark.rst
b/docs/apache-airflow-providers-apache-spark/decorators/pyspark.rst
index 28b51ec848..1755e079b9 100644
--- a/docs/apache-airflow-providers-apache-spark/decorators/pyspark.rst
+++ b/docs/apache-airflow-providers-apache-spark/decorators/pyspark.rst
@@ -23,7 +23,7 @@ PySpark Decorator
=================
Python callable wrapped within the ``@task.pyspark`` decorator
-is injected with a SparkContext object.
+is injected with a SparkSession and SparkContext object if available.
Parameters
----------
@@ -49,3 +49,24 @@ that the ``spark`` and ``sc`` objects are injected into the
function.
:dedent: 4
:start-after: [START task_pyspark]
:end-before: [END task_pyspark]
+
+
+Spark Connect
+-------------
+
+In `Apache Spark 3.4
<https://spark.apache.org/docs/latest/spark-connect-overview.html>`_,
+Spark Connect introduced a decoupled client-server architecture
+that allows remote connectivity to Spark clusters using the DataFrame API.
Using
+Spark Connect is the preferred way in Airflow to make use of the PySpark
decorator,
+because it does not require running the Spark driver on the same host as
Airflow.
+To make use of Spark Connect, you prepend your host URL with ``sc://``. For
example,
+``sc://spark-cluster:15002``.
+
+
+Authentication
+^^^^^^^^^^^^^^
+
+Spark Connect does not have built-in authentication. The gRPC HTTP/2 interface
however
+allows the use of authentication to communicate with the Spark Connect server
through
+authenticating proxies. To make use of authentication make sure to create a
``Spark Connect``
+connection and set the right credentials.
diff --git a/generated/provider_dependencies.json
b/generated/provider_dependencies.json
index c7c939670a..0b234b4e12 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -190,6 +190,7 @@
"apache.spark": {
"deps": [
"apache-airflow>=2.5.0",
+ "grpcio-status",
"pyspark"
],
"cross-providers-deps": [
diff --git a/tests/providers/apache/spark/decorators/test_pyspark.py
b/tests/providers/apache/spark/decorators/test_pyspark.py
index ea307c9c3c..e312e09d40 100644
--- a/tests/providers/apache/spark/decorators/test_pyspark.py
+++ b/tests/providers/apache/spark/decorators/test_pyspark.py
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations
+from typing import Any
from unittest import mock
import pytest
@@ -27,6 +28,22 @@ from airflow.utils import db, timezone
DEFAULT_DATE = timezone.datetime(2021, 9, 1)
+class FakeConfig:
+ data: dict[str, Any]
+
+ def __init__(self, data: dict[str, Any] | None = None):
+ if data:
+ self.data = data
+ else:
+ self.data = {}
+
+ def get(self, key: str, default: Any = None) -> Any:
+ return self.data.get(key, default)
+
+ def set(self, key: str, value: Any) -> None:
+ self.data[key] = value
+
+
class TestPysparkDecorator:
def setup_method(self):
db.merge_conn(
@@ -38,14 +55,47 @@ class TestPysparkDecorator:
)
)
+ db.merge_conn(
+ Connection(
+ conn_id="spark-connect",
+ conn_type="spark",
+ host="sc://localhost",
+ extra="",
+ )
+ )
+
+ db.merge_conn(
+ Connection(
+ conn_id="spark-connect-auth",
+ conn_type="spark_connect",
+ host="sc://localhost",
+ password="1234",
+ login="connect",
+ extra={
+ "use_ssl": True,
+ },
+ )
+ )
+
@pytest.mark.db_test
- @mock.patch("pyspark.SparkConf.setAppName")
+ @mock.patch("pyspark.SparkConf")
@mock.patch("pyspark.sql.SparkSession")
def test_pyspark_decorator_with_connection(self, spark_mock, conf_mock,
dag_maker):
+ config = FakeConfig()
+
+ builder = mock.MagicMock()
+ spark_mock.builder.config.return_value = builder
+ builder.getOrCreate.return_value = builder
+ builder.sparkContext.return_value = builder
+
+ conf_mock.return_value = config
+
@task.pyspark(conn_id="pyspark_local",
config_kwargs={"spark.executor.memory": "2g"})
def f(spark, sc):
import random
+ assert spark is not None
+ assert sc is not None
return [random.random() for _ in range(100)]
with dag_maker():
@@ -55,14 +105,20 @@ class TestPysparkDecorator:
ret.operator.run(start_date=dr.execution_date,
end_date=dr.execution_date)
ti = dr.get_task_instances()[0]
assert len(ti.xcom_pull()) == 100
- conf_mock().set.assert_called_with("spark.executor.memory", "2g")
- conf_mock().setMaster.assert_called_once_with("spark://none")
+ assert config.get("spark.master") == "spark://none"
+ assert config.get("spark.executor.memory") == "2g"
+ assert config.get("spark.remote") is None
+ assert config.get("spark.app.name")
+
spark_mock.builder.config.assert_called_once_with(conf=conf_mock())
@pytest.mark.db_test
- @mock.patch("pyspark.SparkConf.setAppName")
+ @mock.patch("pyspark.SparkConf")
@mock.patch("pyspark.sql.SparkSession")
def test_simple_pyspark_decorator(self, spark_mock, conf_mock, dag_maker):
+ config = FakeConfig()
+ conf_mock.return_value = config
+
e = 2
@task.pyspark
@@ -76,5 +132,56 @@ class TestPysparkDecorator:
ret.operator.run(start_date=dr.execution_date,
end_date=dr.execution_date)
ti = dr.get_task_instances()[0]
assert ti.xcom_pull() == e
- conf_mock().setMaster.assert_called_once_with("local[*]")
+ assert config.get("spark.master") == "local[*]"
spark_mock.builder.config.assert_called_once_with(conf=conf_mock())
+
+ @pytest.mark.db_test
+ @mock.patch("pyspark.SparkConf")
+ @mock.patch("pyspark.sql.SparkSession")
+ def test_spark_connect(self, spark_mock, conf_mock, dag_maker):
+ config = FakeConfig()
+ conf_mock.return_value = config
+
+ @task.pyspark(conn_id="spark-connect")
+ def f(spark, sc):
+ assert spark is not None
+ assert sc is None
+
+ return True
+
+ with dag_maker():
+ ret = f()
+
+ dr = dag_maker.create_dagrun()
+ ret.operator.run(start_date=dr.execution_date,
end_date=dr.execution_date)
+ ti = dr.get_task_instances()[0]
+ assert ti.xcom_pull()
+ assert config.get("spark.remote") == "sc://localhost"
+ assert config.get("spark.master") is None
+ assert config.get("spark.app.name")
+ spark_mock.builder.config.assert_called_once_with(conf=conf_mock())
+
+ @pytest.mark.db_test
+ @mock.patch("pyspark.SparkConf")
+ @mock.patch("pyspark.sql.SparkSession")
+ def test_spark_connect_auth(self, spark_mock, conf_mock, dag_maker):
+ config = FakeConfig()
+ conf_mock.return_value = config
+
+ @task.pyspark(conn_id="spark-connect-auth")
+ def f(spark, sc):
+ assert spark is not None
+ assert sc is None
+
+ return True
+
+ with dag_maker():
+ ret = f()
+
+ dr = dag_maker.create_dagrun()
+ ret.operator.run(start_date=dr.execution_date,
end_date=dr.execution_date)
+ ti = dr.get_task_instances()[0]
+ assert ti.xcom_pull()
+ assert config.get("spark.remote") ==
"sc://localhost/;user_id=connect;token=1234;use_ssl=True"
+ assert config.get("spark.master") is None
+ assert config.get("spark.app.name")
diff --git a/tests/providers/apache/spark/hooks/test_spark_connect.py
b/tests/providers/apache/spark/hooks/test_spark_connect.py
new file mode 100644
index 0000000000..936432e955
--- /dev/null
+++ b/tests/providers/apache/spark/hooks/test_spark_connect.py
@@ -0,0 +1,72 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow.models import Connection
+from airflow.providers.apache.spark.hooks.spark_connect import SparkConnectHook
+from airflow.utils import db
+
+pytestmark = pytest.mark.db_test
+
+
+class TestSparkConnectHook:
+ def setup_method(self):
+ db.merge_conn(
+ Connection(
+ conn_id="spark-default",
+ conn_type="spark_connect",
+ host="sc://spark-host",
+ port=1000,
+ login="spark-user",
+ password="1234",
+ extra='{"queue": "root.etl", "deploy-mode": "cluster"}',
+ )
+ )
+
+ db.merge_conn(
+ Connection(
+ conn_id="spark-test",
+ conn_type="spark_connect",
+ host="nowhere",
+ login="spark-user",
+ )
+ )
+
+ db.merge_conn(
+ Connection(
+ conn_id="spark-app",
+ conn_type="spark_connect",
+ host="sc://cluster/app",
+ login="spark-user",
+ )
+ )
+
+ def test_get_connection_url(self):
+ expected_url = "sc://spark-host:1000/;user_id=spark-user;token=1234"
+ hook = SparkConnectHook(conn_id="spark-default")
+ assert hook.get_connection_url() == expected_url
+
+ expected_url = "sc://nowhere/;user_id=spark-user"
+ hook = SparkConnectHook(conn_id="spark-test")
+ assert hook.get_connection_url() == expected_url
+
+ hook = SparkConnectHook(conn_id="spark-app")
+ with pytest.raises(ValueError):
+ hook.get_connection_url()