This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.0 by this push:
new cd513f5705b7 [SPARK-51156][CONNECT] Static token authentication
support in Spark Connect
cd513f5705b7 is described below
commit cd513f5705b706a5d9061049b6551f434e0a9034
Author: Adam Binford <[email protected]>
AuthorDate: Sun Feb 23 10:25:59 2025 +0900
[SPARK-51156][CONNECT] Static token authentication support in Spark Connect
### What changes were proposed in this pull request?
Adds static token authentication support to Spark Connect, which is used by
default for automatically started servers locally.
### Why are the changes needed?
To add authentication support to Spark Connect, so that a Connect server isn't
inadvertently started in a way that leaves it accessible to other users.
### Does this PR introduce _any_ user-facing change?
The local authentication should be transparent to users, but adds the
option for users manually starting connect servers to specify an authentication
token.
### How was this patch tested?
Existing UTs
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #50006 from Kimahriman/spark-connect-local-auth.
Lead-authored-by: Adam Binford <[email protected]>
Co-authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
(cherry picked from commit 7e9547c6329334e26118e873afaf0b1173019169)
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/connect/client/core.py | 29 ++++++++-----
python/pyspark/sql/connect/session.py | 11 ++++-
.../sql/tests/connect/test_connect_session.py | 18 ++++++++-
.../pandas/test_pandas_grouped_map_with_state.py | 3 +-
python/pyspark/testing/connectutils.py | 3 ++
.../SparkConnectClientBuilderParseTestSuite.scala | 2 +-
.../connect/client/SparkConnectClientSuite.scala | 4 +-
.../apache/spark/sql/connect/SparkSession.scala | 4 +-
.../sql/connect/client/SparkConnectClient.scala | 27 +++++++------
.../apache/spark/sql/connect/config/Connect.scala | 18 +++++++++
.../planner/StreamingForeachBatchHelper.scala | 6 ++-
.../planner/StreamingQueryListenerHelper.scala | 6 ++-
.../PreSharedKeyAuthenticationInterceptor.scala | 47 ++++++++++++++++++++++
.../sql/connect/service/SparkConnectService.scala | 6 ++-
.../connect/service/SparkConnectAuthSuite.scala | 46 +++++++++++++++++++++
15 files changed, 197 insertions(+), 33 deletions(-)
diff --git a/python/pyspark/sql/connect/client/core.py
b/python/pyspark/sql/connect/client/core.py
index 918540fa756b..360f391de6c1 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -220,7 +220,9 @@ class ChannelBuilder:
@property
def token(self) -> Optional[str]:
- return self._params.get(ChannelBuilder.PARAM_TOKEN, None)
+ return self._params.get(
+ ChannelBuilder.PARAM_TOKEN,
os.environ.get("SPARK_CONNECT_AUTHENTICATE_TOKEN")
+ )
def metadata(self) -> Iterable[Tuple[str, str]]:
"""
@@ -410,10 +412,11 @@ class DefaultChannelBuilder(ChannelBuilder):
@property
def secure(self) -> bool:
- return (
- self.getDefault(ChannelBuilder.PARAM_USE_SSL, "").lower() == "true"
- or self.token is not None
- )
+ return self.use_ssl or self.token is not None
+
+ @property
+ def use_ssl(self) -> bool:
+ return self.getDefault(ChannelBuilder.PARAM_USE_SSL, "").lower() ==
"true"
@property
def host(self) -> str:
@@ -439,14 +442,20 @@ class DefaultChannelBuilder(ChannelBuilder):
if not self.secure:
return self._insecure_channel(self.endpoint)
+ elif not self.use_ssl and self._host == "localhost":
+ creds = grpc.local_channel_credentials()
+
+ if self.token is not None:
+ creds = grpc.composite_channel_credentials(
+ creds, grpc.access_token_call_credentials(self.token)
+ )
+ return self._secure_channel(self.endpoint, creds)
else:
- ssl_creds = grpc.ssl_channel_credentials()
+ creds = grpc.ssl_channel_credentials()
- if self.token is None:
- creds = ssl_creds
- else:
+ if self.token is not None:
creds = grpc.composite_channel_credentials(
- ssl_creds, grpc.access_token_call_credentials(self.token)
+ creds, grpc.access_token_call_credentials(self.token)
)
return self._secure_channel(self.endpoint, creds)
diff --git a/python/pyspark/sql/connect/session.py
b/python/pyspark/sql/connect/session.py
index c863af3265dc..4918762d240e 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import uuid
from pyspark.sql.connect.utils import check_dependencies
check_dependencies(__name__)
@@ -1030,6 +1031,8 @@ class SparkSession:
2. Starts a regular Spark session that automatically starts a Spark
Connect server
via ``spark.plugins`` feature.
+
+ Returns the authentication token that should be used to connect to
this session.
"""
from pyspark import SparkContext, SparkConf
@@ -1049,6 +1052,13 @@ class SparkSession:
if "spark.api.mode" in overwrite_conf:
del overwrite_conf["spark.api.mode"]
+ # Check for a user-provided authentication token, creating a new
one if not,
+ # and make sure it's set in the environment.
+ if "SPARK_CONNECT_AUTHENTICATE_TOKEN" not in os.environ:
+ os.environ["SPARK_CONNECT_AUTHENTICATE_TOKEN"] = opts.get(
+ "spark.connect.authenticate.token", str(uuid.uuid4())
+ )
+
# Configurations to be set if unset.
default_conf = {
"spark.plugins":
"org.apache.spark.sql.connect.SparkConnectPlugin",
@@ -1081,7 +1091,6 @@ class SparkSession:
new_opts = {k: opts[k] for k in opts if k in runtime_conf_keys}
opts.clear()
opts.update(new_opts)
-
finally:
if origin_remote is not None:
os.environ["SPARK_REMOTE"] = origin_remote
diff --git a/python/pyspark/sql/tests/connect/test_connect_session.py
b/python/pyspark/sql/tests/connect/test_connect_session.py
index 1ab069a4025c..1fd59609d450 100644
--- a/python/pyspark/sql/tests/connect/test_connect_session.py
+++ b/python/pyspark/sql/tests/connect/test_connect_session.py
@@ -43,6 +43,7 @@ if should_test_connect:
from pyspark.errors.exceptions.connect import (
AnalysisException,
SparkConnectException,
+ SparkConnectGrpcException,
SparkUpgradeException,
)
@@ -237,7 +238,13 @@ class SparkConnectSessionTests(ReusedConnectTestCase):
class CustomChannelBuilder(ChannelBuilder):
def toChannel(self):
- return self._insecure_channel(endpoint)
+ creds = grpc.local_channel_credentials()
+
+ if self.token is not None:
+ creds = grpc.composite_channel_credentials(
+ creds, grpc.access_token_call_credentials(self.token)
+ )
+ return self._secure_channel(endpoint, creds)
session =
RemoteSparkSession.builder.channelBuilder(CustomChannelBuilder()).create()
session.sql("select 1 + 1")
@@ -290,6 +297,15 @@ class SparkConnectSessionTests(ReusedConnectTestCase):
self.assertEqual(session.range(1).first()[0], 0)
self.assertIsInstance(session, RemoteSparkSession)
+ def test_authentication(self):
+ # All servers start with a default token of "deadbeef", so supply an
invalid one
+ session =
RemoteSparkSession.builder.remote("sc://localhost/;token=invalid").create()
+
+ with self.assertRaises(SparkConnectGrpcException) as e:
+ session.range(3).collect()
+
+ self.assertTrue("Invalid authentication token" in str(e.exception))
+
@unittest.skipIf(not should_test_connect, connect_requirement_message)
class SparkConnectSessionWithOptionsTest(unittest.TestCase):
diff --git
a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py
b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py
index 47f7d672cc8c..e1b8d7c76d18 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py
@@ -24,7 +24,6 @@ import tempfile
import unittest
from typing import cast
-from pyspark import SparkConf
from pyspark.sql.streaming.state import GroupStateTimeout, GroupState
from pyspark.sql.types import (
LongType,
@@ -56,7 +55,7 @@ if have_pyarrow:
class GroupedApplyInPandasWithStateTestsMixin:
@classmethod
def conf(cls):
- cfg = SparkConf()
+ cfg = super().conf()
cfg.set("spark.sql.shuffle.partitions", "5")
return cfg
diff --git a/python/pyspark/testing/connectutils.py
b/python/pyspark/testing/connectutils.py
index 7dea8a2103c3..423a717e8ab5 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -155,6 +155,9 @@ class ReusedConnectTestCase(unittest.TestCase,
SQLTestUtils, PySparkErrorTestUti
conf._jconf.remove("spark.master")
conf.set("spark.connect.execute.reattachable.senderMaxStreamDuration",
"1s")
conf.set("spark.connect.execute.reattachable.senderMaxStreamSize",
"123")
+ # Set a static token for all tests so the parallelism doesn't
overwrite each
+ # tests' environment variables
+ conf.set("spark.connect.authenticate.token", "deadbeef")
return conf
@classmethod
diff --git
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala
index b342d5b41569..a3c022066532 100644
---
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala
+++
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala
@@ -125,7 +125,7 @@ class SparkConnectClientBuilderParseTestSuite extends
ConnectFunSuite {
assert(builder.host === "localhost")
assert(builder.port === 15002)
assert(builder.userAgent.contains("_SPARK_CONNECT_SCALA"))
- assert(builder.sslEnabled)
+ assert(!builder.sslEnabled)
assert(builder.token.contains("thisismysecret"))
assert(builder.userId.isEmpty)
assert(builder.userName.isEmpty)
diff --git
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
index acee1b2775f1..3d1ba71b9f90 100644
---
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
+++
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
@@ -299,8 +299,8 @@ class SparkConnectClientSuite extends ConnectFunSuite with
BeforeAndAfterEach {
TestPackURI(s"sc://host:123/;session_id=${UUID.randomUUID().toString}",
isCorrect = true),
TestPackURI("sc://host:123/;use_ssl=true;token=mySecretToken", isCorrect =
true),
TestPackURI("sc://host:123/;token=mySecretToken;use_ssl=true", isCorrect =
true),
- TestPackURI("sc://host:123/;use_ssl=false;token=mySecretToken", isCorrect
= false),
- TestPackURI("sc://host:123/;token=mySecretToken;use_ssl=false", isCorrect
= false),
+ TestPackURI("sc://host:123/;use_ssl=false;token=mySecretToken", isCorrect
= true),
+ TestPackURI("sc://host:123/;token=mySecretToken;use_ssl=false", isCorrect
= true),
TestPackURI("sc://host:123/;param1=value1;param2=value2", isCorrect =
true),
TestPackURI(
"sc://SPARK-45486",
diff --git
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
index c4067ea3ac33..0af7c7b6d97a 100644
---
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
+++
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
@@ -762,6 +762,7 @@ object SparkSession extends SparkSessionCompanion with
Logging {
(remoteString.exists(_.startsWith("local")) ||
(remoteString.isDefined && isAPIModeConnect)) &&
maybeConnectStartScript.exists(Files.exists(_))) {
+ val token = java.util.UUID.randomUUID().toString()
val serverId = UUID.randomUUID().toString
server = Some {
val args =
@@ -779,6 +780,7 @@ object SparkSession extends SparkSessionCompanion with
Logging {
pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
pb.environment().put("SPARK_IDENT_STRING", serverId)
pb.environment().put("HOSTNAME", "local")
+ pb.environment().put("SPARK_CONNECT_AUTHENTICATE_TOKEN", token)
pb.start()
}
@@ -800,7 +802,7 @@ object SparkSession extends SparkSessionCompanion with
Logging {
}
}
- System.setProperty("spark.remote", "sc://localhost")
+ System.setProperty("spark.remote", s"sc://localhost/;token=$token")
// scalastyle:off runtimeaddshutdownhook
Runtime.getRuntime.addShutdownHook(new Thread() {
diff --git
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index dd241c50c934..57ed45418316 100644
---
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -468,8 +468,6 @@ object SparkConnectClient {
* sc://localhost/;token=aaa;use_ssl=true
* }}}
*
- * Throws exception if the token is set but use_ssl=false.
- *
* @param inputToken
* the user token.
* @return
@@ -477,11 +475,7 @@ object SparkConnectClient {
*/
def token(inputToken: String): Builder = {
require(inputToken != null && inputToken.nonEmpty)
- if (_configuration.isSslEnabled.contains(false)) {
- throw new
IllegalArgumentException(AUTH_TOKEN_ON_INSECURE_CONN_ERROR_MSG)
- }
- _configuration =
- _configuration.copy(token = Option(inputToken), isSslEnabled =
Option(true))
+ _configuration = _configuration.copy(token = Option(inputToken))
this
}
@@ -499,7 +493,6 @@ object SparkConnectClient {
* this builder.
*/
def disableSsl(): Builder = {
- require(token.isEmpty, AUTH_TOKEN_ON_INSECURE_CONN_ERROR_MSG)
_configuration = _configuration.copy(isSslEnabled = Option(false))
this
}
@@ -737,6 +730,8 @@ object SparkConnectClient {
grpcMaxMessageSize: Int = ConnectCommon.CONNECT_GRPC_MAX_MESSAGE_SIZE,
grpcMaxRecursionLimit: Int =
ConnectCommon.CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT) {
+ private def isLocal = host.equals("localhost")
+
def userContext: proto.UserContext = {
val builder = proto.UserContext.newBuilder()
if (userId != null) {
@@ -749,7 +744,7 @@ object SparkConnectClient {
}
def credentials: ChannelCredentials = {
- if (isSslEnabled.contains(true)) {
+ if (isSslEnabled.contains(true) || (token.isDefined && !isLocal)) {
token match {
case Some(t) =>
// With access token added in the http header.
@@ -765,10 +760,18 @@ object SparkConnectClient {
}
def createChannel(): ManagedChannel = {
- val channelBuilder = Grpc.newChannelBuilderForAddress(host, port,
credentials)
+ val creds = credentials
+ val channelBuilder = Grpc.newChannelBuilderForAddress(host, port, creds)
+
+ // Workaround LocalChannelCredentials are added in
+ // https://github.com/grpc/grpc-java/issues/9900
+ var metadataWithOptionalToken = metadata
+ if (!isSslEnabled.contains(true) && isLocal && token.isDefined) {
+ metadataWithOptionalToken = metadata + (("Authorization", s"Bearer
${token.get}"))
+ }
- if (metadata.nonEmpty) {
- channelBuilder.intercept(new MetadataHeaderClientInterceptor(metadata))
+ if (metadataWithOptionalToken.nonEmpty) {
+ channelBuilder.intercept(new
MetadataHeaderClientInterceptor(metadataWithOptionalToken))
}
interceptors.foreach(channelBuilder.intercept(_))
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
index b0c5a2a055b5..9f884b683079 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.connect.config
import java.util.concurrent.TimeUnit
+import org.apache.spark.SparkEnv
import org.apache.spark.network.util.ByteUnit
import org.apache.spark.sql.connect.common.config.ConnectCommon
import org.apache.spark.sql.internal.SQLConf
@@ -313,4 +314,21 @@ object Connect {
.internal()
.booleanConf
.createWithDefault(true)
+
+ val CONNECT_AUTHENTICATE_TOKEN =
+ buildStaticConf("spark.connect.authenticate.token")
+ .doc("A pre-shared token that will be used to authenticate clients. This
secret must be" +
" passed as a bearer token for clients to connect.")
+ .version("4.0.0")
+ .internal()
+ .stringConf
+ .createOptional
+
+ val CONNECT_AUTHENTICATE_TOKEN_ENV = "SPARK_CONNECT_AUTHENTICATE_TOKEN"
+
+ def getAuthenticateToken: Option[String] = {
+ SparkEnv.get.conf.get(CONNECT_AUTHENTICATE_TOKEN).orElse {
+ Option(System.getenv.get(CONNECT_AUTHENTICATE_TOKEN_ENV))
+ }
+ }
}
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
index 07c5da9744cc..cc6b58216ed7 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
@@ -31,6 +31,7 @@ import org.apache.spark.internal.LogKeys.{DATAFRAME_ID,
QUERY_ID, RUN_ID_STRING,
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder,
AgnosticEncoders}
import org.apache.spark.sql.connect.common.ForeachWriterPacket
+import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.service.SessionHolder
import org.apache.spark.sql.connect.service.SparkConnectService
import org.apache.spark.sql.streaming.StreamingQuery
@@ -131,7 +132,10 @@ object StreamingForeachBatchHelper extends Logging {
sessionHolder: SessionHolder): (ForeachBatchFnType, AutoCloseable) = {
val port = SparkConnectService.localPort
- val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
+ var connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
+ Connect.getAuthenticateToken.foreach { token =>
+ connectUrl = s"$connectUrl;token=$token"
+ }
val runner = StreamingPythonRunner(
pythonFn,
connectUrl,
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
index c342050a212e..42c090d43f06 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
@@ -23,6 +23,7 @@ import org.apache.spark.SparkException
import org.apache.spark.api.python.{PythonException, PythonWorkerUtils,
SimplePythonFunction, SpecialLengths, StreamingPythonRunner}
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.FUNCTION_NAME
+import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.service.{SessionHolder,
SparkConnectService}
import org.apache.spark.sql.streaming.StreamingQueryListener
@@ -36,7 +37,10 @@ class PythonStreamingQueryListener(listener:
SimplePythonFunction, sessionHolder
with Logging {
private val port = SparkConnectService.localPort
- private val connectUrl =
s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
+ private var connectUrl =
s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
+ Connect.getAuthenticateToken.foreach { token =>
+ connectUrl = s"$connectUrl;token=$token"
+ }
// Scoped for testing
private[connect] val runner = StreamingPythonRunner(
listener,
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/PreSharedKeyAuthenticationInterceptor.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/PreSharedKeyAuthenticationInterceptor.scala
new file mode 100644
index 000000000000..5d7cc65358eb
--- /dev/null
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/PreSharedKeyAuthenticationInterceptor.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import io.grpc.{Metadata, ServerCall, ServerCallHandler, ServerInterceptor,
Status}
+
+class PreSharedKeyAuthenticationInterceptor(token: String) extends
ServerInterceptor {
+
+ val authorizationMetadataKey =
+ Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER)
+
+ val expectedValue = s"Bearer $token"
+
+ override def interceptCall[ReqT, RespT](
+ call: ServerCall[ReqT, RespT],
+ metadata: Metadata,
+ next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = {
+ val authHeaderValue = metadata.get(authorizationMetadataKey)
+
+ if (authHeaderValue == null) {
+ val status = Status.UNAUTHENTICATED.withDescription("No authentication
token provided")
+ call.close(status, new Metadata())
+ new ServerCall.Listener[ReqT]() {}
+ } else if (authHeaderValue != expectedValue) {
+ val status = Status.UNAUTHENTICATED.withDescription("Invalid
authentication token")
+ call.close(status, new Metadata())
+ new ServerCall.Listener[ReqT]() {}
+ } else {
+ next.startCall(call, metadata)
+ }
+ }
+}
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index e62c19b66c8e..8fa64ddcce49 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -39,7 +39,7 @@ import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.HOST
import org.apache.spark.internal.config.UI.UI_ENABLED
import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerEvent}
-import
org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_ADDRESS,
CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT,
CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE, CONNECT_GRPC_PORT_MAX_RETRIES}
+import org.apache.spark.sql.connect.config.Connect.{getAuthenticateToken,
CONNECT_GRPC_BINDING_ADDRESS, CONNECT_GRPC_BINDING_PORT,
CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE,
CONNECT_GRPC_PORT_MAX_RETRIES}
import org.apache.spark.sql.connect.execution.ConnectProgressExecutionListener
import org.apache.spark.sql.connect.ui.{SparkConnectServerAppStatusStore,
SparkConnectServerListener, SparkConnectServerTab}
import org.apache.spark.sql.connect.utils.ErrorUtils
@@ -381,6 +381,10 @@ object SparkConnectService extends Logging {
sb.maxInboundMessageSize(SparkEnv.get.conf.get(CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE).toInt)
.addService(sparkConnectService)
+ getAuthenticateToken.foreach { token =>
+ sb.intercept(new PreSharedKeyAuthenticationInterceptor(token))
+ }
+
// Add all registered interceptors to the server builder.
SparkConnectInterceptorRegistry.chainInterceptors(sb,
configuredInterceptors)
diff --git
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectAuthSuite.scala
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectAuthSuite.scala
new file mode 100644
index 000000000000..30f186ab7c2b
--- /dev/null
+++
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectAuthSuite.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.service
+
+// import io.grpc.StatusRuntimeException
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.connect.{SparkConnectServerTest, SparkSession}
+import org.apache.spark.sql.connect.service.SparkConnectService
+
+class SparkConnectAuthSuite extends SparkConnectServerTest {
+ override protected def sparkConf = {
+ super.sparkConf.set("spark.connect.authenticate.token", "deadbeef")
+ }
+
+ test("Test local authentication") {
+ val session = SparkSession
+ .builder()
+
.remote(s"sc://localhost:${SparkConnectService.localPort}/;token=deadbeef")
+ .create()
+ session.range(5).collect()
+
+ val invalidSession = SparkSession
+ .builder()
+
.remote(s"sc://localhost:${SparkConnectService.localPort}/;token=invalid")
+ .create()
+ val exception = intercept[SparkException] {
+ invalidSession.range(5).collect()
+ }
+ assert(exception.getMessage.contains("Invalid authentication token"))
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]