This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new cd513f5705b7 [SPARK-51156][CONNECT] Static token authentication 
support in Spark Connect
cd513f5705b7 is described below

commit cd513f5705b706a5d9061049b6551f434e0a9034
Author: Adam Binford <[email protected]>
AuthorDate: Sun Feb 23 10:25:59 2025 +0900

    [SPARK-51156][CONNECT] Static token authentication support in Spark Connect
    
    ### What changes were proposed in this pull request?
    
    Adds static token authentication support to Spark Connect, which is used by 
default for automatically started servers locally.
    
    ### Why are the changes needed?
    
    To add authentication support to Spark Connect so a connect server isn't 
started that could be accessible to other users inadvertently.
    
    ### Does this PR introduce _any_ user-facing change?
    
    The local authentication should be transparent to users, but adds the 
option for users manually starting connect servers to specify an authentication 
token.
    
    ### How was this patch tested?
    
    Existing UTs
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #50006 from Kimahriman/spark-connect-local-auth.
    
    Lead-authored-by: Adam Binford <[email protected]>
    Co-authored-by: Hyukjin Kwon <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
    (cherry picked from commit 7e9547c6329334e26118e873afaf0b1173019169)
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/connect/client/core.py          | 29 ++++++++-----
 python/pyspark/sql/connect/session.py              | 11 ++++-
 .../sql/tests/connect/test_connect_session.py      | 18 ++++++++-
 .../pandas/test_pandas_grouped_map_with_state.py   |  3 +-
 python/pyspark/testing/connectutils.py             |  3 ++
 .../SparkConnectClientBuilderParseTestSuite.scala  |  2 +-
 .../connect/client/SparkConnectClientSuite.scala   |  4 +-
 .../apache/spark/sql/connect/SparkSession.scala    |  4 +-
 .../sql/connect/client/SparkConnectClient.scala    | 27 +++++++------
 .../apache/spark/sql/connect/config/Connect.scala  | 18 +++++++++
 .../planner/StreamingForeachBatchHelper.scala      |  6 ++-
 .../planner/StreamingQueryListenerHelper.scala     |  6 ++-
 .../PreSharedKeyAuthenticationInterceptor.scala    | 47 ++++++++++++++++++++++
 .../sql/connect/service/SparkConnectService.scala  |  6 ++-
 .../connect/service/SparkConnectAuthSuite.scala    | 46 +++++++++++++++++++++
 15 files changed, 197 insertions(+), 33 deletions(-)

diff --git a/python/pyspark/sql/connect/client/core.py 
b/python/pyspark/sql/connect/client/core.py
index 918540fa756b..360f391de6c1 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -220,7 +220,9 @@ class ChannelBuilder:
 
     @property
     def token(self) -> Optional[str]:
-        return self._params.get(ChannelBuilder.PARAM_TOKEN, None)
+        return self._params.get(
+            ChannelBuilder.PARAM_TOKEN, 
os.environ.get("SPARK_CONNECT_AUTHENTICATE_TOKEN")
+        )
 
     def metadata(self) -> Iterable[Tuple[str, str]]:
         """
@@ -410,10 +412,11 @@ class DefaultChannelBuilder(ChannelBuilder):
 
     @property
     def secure(self) -> bool:
-        return (
-            self.getDefault(ChannelBuilder.PARAM_USE_SSL, "").lower() == "true"
-            or self.token is not None
-        )
+        return self.use_ssl or self.token is not None
+
+    @property
+    def use_ssl(self) -> bool:
+        return self.getDefault(ChannelBuilder.PARAM_USE_SSL, "").lower() == 
"true"
 
     @property
     def host(self) -> str:
@@ -439,14 +442,20 @@ class DefaultChannelBuilder(ChannelBuilder):
 
         if not self.secure:
             return self._insecure_channel(self.endpoint)
+        elif not self.use_ssl and self._host == "localhost":
+            creds = grpc.local_channel_credentials()
+
+            if self.token is not None:
+                creds = grpc.composite_channel_credentials(
+                    creds, grpc.access_token_call_credentials(self.token)
+                )
+            return self._secure_channel(self.endpoint, creds)
         else:
-            ssl_creds = grpc.ssl_channel_credentials()
+            creds = grpc.ssl_channel_credentials()
 
-            if self.token is None:
-                creds = ssl_creds
-            else:
+            if self.token is not None:
                 creds = grpc.composite_channel_credentials(
-                    ssl_creds, grpc.access_token_call_credentials(self.token)
+                    creds, grpc.access_token_call_credentials(self.token)
                 )
 
             return self._secure_channel(self.endpoint, creds)
diff --git a/python/pyspark/sql/connect/session.py 
b/python/pyspark/sql/connect/session.py
index c863af3265dc..4918762d240e 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import uuid
 from pyspark.sql.connect.utils import check_dependencies
 
 check_dependencies(__name__)
@@ -1030,6 +1031,8 @@ class SparkSession:
 
         2. Starts a regular Spark session that automatically starts a Spark 
Connect server
            via ``spark.plugins`` feature.
+
+        Returns the authentication token that should be used to connect to 
this session.
         """
         from pyspark import SparkContext, SparkConf
 
@@ -1049,6 +1052,13 @@ class SparkSession:
             if "spark.api.mode" in overwrite_conf:
                 del overwrite_conf["spark.api.mode"]
 
+            # Check for a user-provided authentication token, creating a new one if not,
+            # and make sure it's set in the environment.
+            if "SPARK_CONNECT_AUTHENTICATE_TOKEN" not in os.environ:
+                os.environ["SPARK_CONNECT_AUTHENTICATE_TOKEN"] = opts.get(
+                    "spark.connect.authenticate.token", str(uuid.uuid4())
+                )
+
             # Configurations to be set if unset.
             default_conf = {
                 "spark.plugins": 
"org.apache.spark.sql.connect.SparkConnectPlugin",
@@ -1081,7 +1091,6 @@ class SparkSession:
                 new_opts = {k: opts[k] for k in opts if k in runtime_conf_keys}
                 opts.clear()
                 opts.update(new_opts)
-
             finally:
                 if origin_remote is not None:
                     os.environ["SPARK_REMOTE"] = origin_remote
diff --git a/python/pyspark/sql/tests/connect/test_connect_session.py 
b/python/pyspark/sql/tests/connect/test_connect_session.py
index 1ab069a4025c..1fd59609d450 100644
--- a/python/pyspark/sql/tests/connect/test_connect_session.py
+++ b/python/pyspark/sql/tests/connect/test_connect_session.py
@@ -43,6 +43,7 @@ if should_test_connect:
     from pyspark.errors.exceptions.connect import (
         AnalysisException,
         SparkConnectException,
+        SparkConnectGrpcException,
         SparkUpgradeException,
     )
 
@@ -237,7 +238,13 @@ class SparkConnectSessionTests(ReusedConnectTestCase):
 
         class CustomChannelBuilder(ChannelBuilder):
             def toChannel(self):
-                return self._insecure_channel(endpoint)
+                creds = grpc.local_channel_credentials()
+
+                if self.token is not None:
+                    creds = grpc.composite_channel_credentials(
+                        creds, grpc.access_token_call_credentials(self.token)
+                    )
+                return self._secure_channel(endpoint, creds)
 
         session = 
RemoteSparkSession.builder.channelBuilder(CustomChannelBuilder()).create()
         session.sql("select 1 + 1")
@@ -290,6 +297,15 @@ class SparkConnectSessionTests(ReusedConnectTestCase):
         self.assertEqual(session.range(1).first()[0], 0)
         self.assertIsInstance(session, RemoteSparkSession)
 
+    def test_authentication(self):
+        # All servers start with a default token of "deadbeef", so supply an invalid one
+        session = 
RemoteSparkSession.builder.remote("sc://localhost/;token=invalid").create()
+
+        with self.assertRaises(SparkConnectGrpcException) as e:
+            session.range(3).collect()
+
+        self.assertTrue("Invalid authentication token" in str(e.exception))
+
 
 @unittest.skipIf(not should_test_connect, connect_requirement_message)
 class SparkConnectSessionWithOptionsTest(unittest.TestCase):
diff --git 
a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py 
b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py
index 47f7d672cc8c..e1b8d7c76d18 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py
@@ -24,7 +24,6 @@ import tempfile
 import unittest
 from typing import cast
 
-from pyspark import SparkConf
 from pyspark.sql.streaming.state import GroupStateTimeout, GroupState
 from pyspark.sql.types import (
     LongType,
@@ -56,7 +55,7 @@ if have_pyarrow:
 class GroupedApplyInPandasWithStateTestsMixin:
     @classmethod
     def conf(cls):
-        cfg = SparkConf()
+        cfg = super().conf()
         cfg.set("spark.sql.shuffle.partitions", "5")
         return cfg
 
diff --git a/python/pyspark/testing/connectutils.py 
b/python/pyspark/testing/connectutils.py
index 7dea8a2103c3..423a717e8ab5 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -155,6 +155,9 @@ class ReusedConnectTestCase(unittest.TestCase, 
SQLTestUtils, PySparkErrorTestUti
             conf._jconf.remove("spark.master")
         conf.set("spark.connect.execute.reattachable.senderMaxStreamDuration", 
"1s")
         conf.set("spark.connect.execute.reattachable.senderMaxStreamSize", 
"123")
+        # Set a static token for all tests so the parallelism doesn't overwrite each
+        # test's environment variables
+        conf.set("spark.connect.authenticate.token", "deadbeef")
         return conf
 
     @classmethod
diff --git 
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala
 
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala
index b342d5b41569..a3c022066532 100644
--- 
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala
+++ 
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala
@@ -125,7 +125,7 @@ class SparkConnectClientBuilderParseTestSuite extends 
ConnectFunSuite {
       assert(builder.host === "localhost")
       assert(builder.port === 15002)
       assert(builder.userAgent.contains("_SPARK_CONNECT_SCALA"))
-      assert(builder.sslEnabled)
+      assert(!builder.sslEnabled)
       assert(builder.token.contains("thisismysecret"))
       assert(builder.userId.isEmpty)
       assert(builder.userName.isEmpty)
diff --git 
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
 
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
index acee1b2775f1..3d1ba71b9f90 100644
--- 
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
+++ 
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
@@ -299,8 +299,8 @@ class SparkConnectClientSuite extends ConnectFunSuite with 
BeforeAndAfterEach {
     TestPackURI(s"sc://host:123/;session_id=${UUID.randomUUID().toString}", 
isCorrect = true),
     TestPackURI("sc://host:123/;use_ssl=true;token=mySecretToken", isCorrect = 
true),
     TestPackURI("sc://host:123/;token=mySecretToken;use_ssl=true", isCorrect = 
true),
-    TestPackURI("sc://host:123/;use_ssl=false;token=mySecretToken", isCorrect 
= false),
-    TestPackURI("sc://host:123/;token=mySecretToken;use_ssl=false", isCorrect 
= false),
+    TestPackURI("sc://host:123/;use_ssl=false;token=mySecretToken", isCorrect 
= true),
+    TestPackURI("sc://host:123/;token=mySecretToken;use_ssl=false", isCorrect 
= true),
     TestPackURI("sc://host:123/;param1=value1;param2=value2", isCorrect = 
true),
     TestPackURI(
       "sc://SPARK-45486",
diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
index c4067ea3ac33..0af7c7b6d97a 100644
--- 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
+++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
@@ -762,6 +762,7 @@ object SparkSession extends SparkSessionCompanion with 
Logging {
         (remoteString.exists(_.startsWith("local")) ||
           (remoteString.isDefined && isAPIModeConnect)) &&
         maybeConnectStartScript.exists(Files.exists(_))) {
+        val token = java.util.UUID.randomUUID().toString()
         val serverId = UUID.randomUUID().toString
         server = Some {
           val args =
@@ -779,6 +780,7 @@ object SparkSession extends SparkSessionCompanion with 
Logging {
           pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
           pb.environment().put("SPARK_IDENT_STRING", serverId)
           pb.environment().put("HOSTNAME", "local")
+          pb.environment().put("SPARK_CONNECT_AUTHENTICATE_TOKEN", token)
           pb.start()
         }
 
@@ -800,7 +802,7 @@ object SparkSession extends SparkSessionCompanion with 
Logging {
             }
           }
 
-        System.setProperty("spark.remote", "sc://localhost")
+        System.setProperty("spark.remote", s"sc://localhost/;token=$token")
 
         // scalastyle:off runtimeaddshutdownhook
         Runtime.getRuntime.addShutdownHook(new Thread() {
diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index dd241c50c934..57ed45418316 100644
--- 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -468,8 +468,6 @@ object SparkConnectClient {
      * sc://localhost/;token=aaa;use_ssl=true
      * }}}
      *
-     * Throws exception if the token is set but use_ssl=false.
-     *
      * @param inputToken
      *   the user token.
      * @return
@@ -477,11 +475,7 @@ object SparkConnectClient {
      */
     def token(inputToken: String): Builder = {
       require(inputToken != null && inputToken.nonEmpty)
-      if (_configuration.isSslEnabled.contains(false)) {
-        throw new 
IllegalArgumentException(AUTH_TOKEN_ON_INSECURE_CONN_ERROR_MSG)
-      }
-      _configuration =
-        _configuration.copy(token = Option(inputToken), isSslEnabled = 
Option(true))
+      _configuration = _configuration.copy(token = Option(inputToken))
       this
     }
 
@@ -499,7 +493,6 @@ object SparkConnectClient {
      *   this builder.
      */
     def disableSsl(): Builder = {
-      require(token.isEmpty, AUTH_TOKEN_ON_INSECURE_CONN_ERROR_MSG)
       _configuration = _configuration.copy(isSslEnabled = Option(false))
       this
     }
@@ -737,6 +730,8 @@ object SparkConnectClient {
       grpcMaxMessageSize: Int = ConnectCommon.CONNECT_GRPC_MAX_MESSAGE_SIZE,
       grpcMaxRecursionLimit: Int = 
ConnectCommon.CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT) {
 
+    private def isLocal = host.equals("localhost")
+
     def userContext: proto.UserContext = {
       val builder = proto.UserContext.newBuilder()
       if (userId != null) {
@@ -749,7 +744,7 @@ object SparkConnectClient {
     }
 
     def credentials: ChannelCredentials = {
-      if (isSslEnabled.contains(true)) {
+      if (isSslEnabled.contains(true) || (token.isDefined && !isLocal)) {
         token match {
           case Some(t) =>
             // With access token added in the http header.
@@ -765,10 +760,18 @@ object SparkConnectClient {
     }
 
     def createChannel(): ManagedChannel = {
-      val channelBuilder = Grpc.newChannelBuilderForAddress(host, port, 
credentials)
+      val creds = credentials
+      val channelBuilder = Grpc.newChannelBuilderForAddress(host, port, creds)
+
+      // Workaround until LocalChannelCredentials are added in
+      // https://github.com/grpc/grpc-java/issues/9900
+      var metadataWithOptionalToken = metadata
+      if (!isSslEnabled.contains(true) && isLocal && token.isDefined) {
+        metadataWithOptionalToken = metadata + (("Authorization", s"Bearer 
${token.get}"))
+      }
 
-      if (metadata.nonEmpty) {
-        channelBuilder.intercept(new MetadataHeaderClientInterceptor(metadata))
+      if (metadataWithOptionalToken.nonEmpty) {
+        channelBuilder.intercept(new 
MetadataHeaderClientInterceptor(metadataWithOptionalToken))
       }
 
       interceptors.foreach(channelBuilder.intercept(_))
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
index b0c5a2a055b5..9f884b683079 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.connect.config
 
 import java.util.concurrent.TimeUnit
 
+import org.apache.spark.SparkEnv
 import org.apache.spark.network.util.ByteUnit
 import org.apache.spark.sql.connect.common.config.ConnectCommon
 import org.apache.spark.sql.internal.SQLConf
@@ -313,4 +314,21 @@ object Connect {
       .internal()
       .booleanConf
       .createWithDefault(true)
+
+  val CONNECT_AUTHENTICATE_TOKEN =
+    buildStaticConf("spark.connect.authenticate.token")
+      .doc("A pre-shared token that will be used to authenticate clients. This secret must be" +
+        " passed as a bearer token by clients to connect.")
+      .version("4.0.0")
+      .internal()
+      .stringConf
+      .createOptional
+
+  val CONNECT_AUTHENTICATE_TOKEN_ENV = "SPARK_CONNECT_AUTHENTICATE_TOKEN"
+
+  def getAuthenticateToken: Option[String] = {
+    SparkEnv.get.conf.get(CONNECT_AUTHENTICATE_TOKEN).orElse {
+      Option(System.getenv.get(CONNECT_AUTHENTICATE_TOKEN_ENV))
+    }
+  }
 }
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
index 07c5da9744cc..cc6b58216ed7 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
@@ -31,6 +31,7 @@ import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, 
QUERY_ID, RUN_ID_STRING,
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, 
AgnosticEncoders}
 import org.apache.spark.sql.connect.common.ForeachWriterPacket
+import org.apache.spark.sql.connect.config.Connect
 import org.apache.spark.sql.connect.service.SessionHolder
 import org.apache.spark.sql.connect.service.SparkConnectService
 import org.apache.spark.sql.streaming.StreamingQuery
@@ -131,7 +132,10 @@ object StreamingForeachBatchHelper extends Logging {
       sessionHolder: SessionHolder): (ForeachBatchFnType, AutoCloseable) = {
 
     val port = SparkConnectService.localPort
-    val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
+    var connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
+    Connect.getAuthenticateToken.foreach { token =>
+      connectUrl = s"$connectUrl;token=$token"
+    }
     val runner = StreamingPythonRunner(
       pythonFn,
       connectUrl,
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
index c342050a212e..42c090d43f06 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
@@ -23,6 +23,7 @@ import org.apache.spark.SparkException
 import org.apache.spark.api.python.{PythonException, PythonWorkerUtils, 
SimplePythonFunction, SpecialLengths, StreamingPythonRunner}
 import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys.FUNCTION_NAME
+import org.apache.spark.sql.connect.config.Connect
 import org.apache.spark.sql.connect.service.{SessionHolder, 
SparkConnectService}
 import org.apache.spark.sql.streaming.StreamingQueryListener
 
@@ -36,7 +37,10 @@ class PythonStreamingQueryListener(listener: 
SimplePythonFunction, sessionHolder
     with Logging {
 
   private val port = SparkConnectService.localPort
-  private val connectUrl = 
s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
+  private var connectUrl = 
s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
+  Connect.getAuthenticateToken.foreach { token =>
+    connectUrl = s"$connectUrl;token=$token"
+  }
   // Scoped for testing
   private[connect] val runner = StreamingPythonRunner(
     listener,
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/PreSharedKeyAuthenticationInterceptor.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/PreSharedKeyAuthenticationInterceptor.scala
new file mode 100644
index 000000000000..5d7cc65358eb
--- /dev/null
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/PreSharedKeyAuthenticationInterceptor.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import io.grpc.{Metadata, ServerCall, ServerCallHandler, ServerInterceptor, 
Status}
+
+class PreSharedKeyAuthenticationInterceptor(token: String) extends 
ServerInterceptor {
+
+  val authorizationMetadataKey =
+    Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER)
+
+  val expectedValue = s"Bearer $token"
+
+  override def interceptCall[ReqT, RespT](
+      call: ServerCall[ReqT, RespT],
+      metadata: Metadata,
+      next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = {
+    val authHeaderValue = metadata.get(authorizationMetadataKey)
+
+    if (authHeaderValue == null) {
+      val status = Status.UNAUTHENTICATED.withDescription("No authentication 
token provided")
+      call.close(status, new Metadata())
+      new ServerCall.Listener[ReqT]() {}
+    } else if (authHeaderValue != expectedValue) {
+      val status = Status.UNAUTHENTICATED.withDescription("Invalid 
authentication token")
+      call.close(status, new Metadata())
+      new ServerCall.Listener[ReqT]() {}
+    } else {
+      next.startCall(call, metadata)
+    }
+  }
+}
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index e62c19b66c8e..8fa64ddcce49 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -39,7 +39,7 @@ import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys.HOST
 import org.apache.spark.internal.config.UI.UI_ENABLED
 import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerEvent}
-import 
org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_ADDRESS, 
CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT, 
CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE, CONNECT_GRPC_PORT_MAX_RETRIES}
+import org.apache.spark.sql.connect.config.Connect.{getAuthenticateToken, 
CONNECT_GRPC_BINDING_ADDRESS, CONNECT_GRPC_BINDING_PORT, 
CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE, 
CONNECT_GRPC_PORT_MAX_RETRIES}
 import org.apache.spark.sql.connect.execution.ConnectProgressExecutionListener
 import org.apache.spark.sql.connect.ui.{SparkConnectServerAppStatusStore, 
SparkConnectServerListener, SparkConnectServerTab}
 import org.apache.spark.sql.connect.utils.ErrorUtils
@@ -381,6 +381,10 @@ object SparkConnectService extends Logging {
       
sb.maxInboundMessageSize(SparkEnv.get.conf.get(CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE).toInt)
         .addService(sparkConnectService)
 
+      getAuthenticateToken.foreach { token =>
+        sb.intercept(new PreSharedKeyAuthenticationInterceptor(token))
+      }
+
       // Add all registered interceptors to the server builder.
       SparkConnectInterceptorRegistry.chainInterceptors(sb, 
configuredInterceptors)
 
diff --git 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectAuthSuite.scala
 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectAuthSuite.scala
new file mode 100644
index 000000000000..30f186ab7c2b
--- /dev/null
+++ 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectAuthSuite.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.service
+
+// import io.grpc.StatusRuntimeException
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.connect.{SparkConnectServerTest, SparkSession}
+import org.apache.spark.sql.connect.service.SparkConnectService
+
+class SparkConnectAuthSuite extends SparkConnectServerTest {
+  override protected def sparkConf = {
+    super.sparkConf.set("spark.connect.authenticate.token", "deadbeef")
+  }
+
+  test("Test local authentication") {
+    val session = SparkSession
+      .builder()
+      
.remote(s"sc://localhost:${SparkConnectService.localPort}/;token=deadbeef")
+      .create()
+    session.range(5).collect()
+
+    val invalidSession = SparkSession
+      .builder()
+      
.remote(s"sc://localhost:${SparkConnectService.localPort}/;token=invalid")
+      .create()
+    val exception = intercept[SparkException] {
+      invalidSession.range(5).collect()
+    }
+    assert(exception.getMessage.contains("Invalid authentication token"))
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to