This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new c4619b503e5 [SPARK-41738][CONNECT] Mix ClientId in SparkSession cache
c4619b503e5 is described below
commit c4619b503e58da38c6223020a73c2ca2e0a8c0fa
Author: Martin Grund <[email protected]>
AuthorDate: Wed Dec 28 20:04:45 2022 +0900
[SPARK-41738][CONNECT] Mix ClientId in SparkSession cache
### What changes were proposed in this pull request?
This PR mixes the client ID into the cache for the SparkSessions on the
server. This is necessary to allow concurrent SparkSessions of the same
user to run without interfering with each other.
The client ID was added for exactly this purpose, but the functionality
had not been implemented until now.
On the client side, the Python client now validates that the result it
receives actually belongs to the same session.
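To make the new keying concrete, here is a minimal, self-contained Python
sketch of the idea. The real cache lives in the Scala SparkConnectService
and holds actual SparkSessions; the IsolatedSession type and function
names below are illustrative stand-ins only.

    import uuid
    from threading import Lock
    from typing import Dict, Tuple

    # Hypothetical stand-in for a real isolated SparkSession.
    class IsolatedSession:
        def __init__(self, user_id: str, session_id: str) -> None:
            self.user_id = user_id
            self.session_id = session_id

    # Cache keyed by (user_id, session_id) instead of user_id alone, so
    # two concurrent clients of the same user no longer share a session.
    _sessions: Dict[Tuple[str, str], IsolatedSession] = {}
    _lock = Lock()

    def get_or_create_isolated_session(user_id: str, session_id: str) -> IsolatedSession:
        key = (user_id, session_id)
        with _lock:
            if key not in _sessions:
                _sessions[key] = IsolatedSession(user_id, session_id)
            return _sessions[key]

    # Same user, different client IDs -> distinct sessions.
    a = get_or_create_isolated_session("alice", str(uuid.uuid4()))
    b = get_or_create_isolated_session("alice", str(uuid.uuid4()))
    assert a is not b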
### Why are the changes needed?
Stability
### Does this PR introduce _any_ user-facing change?
NO
### How was this patch tested?
Existing UT
Closes #39256 from grundprinzip/SPARK-41738.
Authored-by: Martin Grund <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../spark/sql/connect/service/SparkConnectService.scala | 14 ++++++++------
.../sql/connect/service/SparkConnectStreamHandler.scala | 4 +++-
.../sql/connect/planner/SparkConnectServiceSuite.scala | 3 ++-
python/pyspark/sql/connect/client.py | 12 ++++++++++++
4 files changed, 25 insertions(+), 8 deletions(-)
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index 3046c8eebfc..bfcea3d2252 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -94,7 +94,9 @@ class SparkConnectService(debug: Boolean)
s"${request.getPlan.getOpTypeCase} not supported for analysis."))
}
val session =
- SparkConnectService.getOrCreateIsolatedSession(request.getUserContext.getUserId).session
+ SparkConnectService
+ .getOrCreateIsolatedSession(request.getUserContext.getUserId, request.getClientId)
+ .session
val explainMode = request.getExplain.getExplainMode match {
case proto.Explain.ExplainMode.SIMPLE => SimpleMode
@@ -145,7 +147,7 @@ class SparkConnectService(debug: Boolean)
* @param userId
* @param session
*/
-case class SessionHolder(userId: String, session: SparkSession)
+case class SessionHolder(userId: String, sessionId: String, session: SparkSession)
/**
* Static instance of the SparkConnectService.
@@ -161,7 +163,7 @@ object SparkConnectService {
// Type alias for the SessionCacheKey. Right now this is a String but allows us to switch to a
// different or complex type easily.
- private type SessionCacheKey = String;
+ private type SessionCacheKey = (String, String);
private var server: Server = _
@@ -183,11 +185,11 @@ object SparkConnectService {
/**
* Based on the `key` find or create a new SparkSession.
*/
- def getOrCreateIsolatedSession(key: SessionCacheKey): SessionHolder = {
+ def getOrCreateIsolatedSession(userId: String, sessionId: String): SessionHolder = {
userSessionMapping.get(
- key,
+ (userId, sessionId),
() => {
- SessionHolder(key, newIsolatedSession())
+ SessionHolder(userId, sessionId, newIsolatedSession())
})
}
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index fcae3501cef..9631b93f6e9 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -42,7 +42,9 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
def handle(v: ExecutePlanRequest): Unit = {
val session =
- SparkConnectService.getOrCreateIsolatedSession(v.getUserContext.getUserId).session
+ SparkConnectService
+ .getOrCreateIsolatedSession(v.getUserContext.getUserId, v.getClientId)
+ .session
v.getPlan.getOpTypeCase match {
case proto.Plan.OpTypeCase.COMMAND => handleCommand(session, v)
case proto.Plan.OpTypeCase.ROOT => handlePlan(session, v)
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
index 8f4268b904b..6dcce0926dc 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -153,7 +153,7 @@ class SparkConnectServiceSuite extends SharedSparkSession {
val instance = new SparkConnectService(false)
// Add an always crashing UDF
- val session = SparkConnectService.getOrCreateIsolatedSession("c1").session
+ val session = SparkConnectService.getOrCreateIsolatedSession("c1", "session").session
val instaKill: Long => Long = { _ =>
throw new Exception("Kaboom")
}
@@ -172,6 +172,7 @@ class SparkConnectServiceSuite extends SharedSparkSession {
.newBuilder()
.setPlan(plan)
.setUserContext(context)
+ .setClientId("session")
.build()
// The observer is executed inside this thread. So
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index 2e32a87676f..e258dbd92b4 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -290,6 +290,10 @@ class SparkConnectClient(object):
# Parse the connection string.
self._builder = ChannelBuilder(connectionString)
self._user_id = None
+ # Generate a unique session ID for this client. This UUID must be unique to allow
+ # concurrent Spark sessions of the same user. If the channel is closed, creating
+ # a new client will create a new session ID.
+ self._session_id = str(uuid.uuid4())
if self._builder.userId is not None:
self._user_id = self._builder.userId
elif userId is not None:
@@ -367,6 +371,7 @@ class SparkConnectClient(object):
def _execute_plan_request_with_metadata(self) -> pb2.ExecutePlanRequest:
req = pb2.ExecutePlanRequest()
+ req.client_id = self._session_id
req.client_type = "_SPARK_CONNECT_PYTHON"
if self._user_id:
req.user_context.user_id = self._user_id
@@ -374,6 +379,7 @@ class SparkConnectClient(object):
def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest:
req = pb2.AnalyzePlanRequest()
+ req.client_id = self._session_id
req.client_type = "_SPARK_CONNECT_PYTHON"
if self._user_id:
req.user_context.user_id = self._user_id
@@ -401,6 +407,8 @@ class SparkConnectClient(object):
req.explain.explain_mode = pb2.Explain.ExplainMode.FORMATTED
resp = self._stub.AnalyzePlan(req, metadata=self._builder.metadata())
+ if resp.client_id != self._session_id:
+ raise ValueError("Received incorrect session identifier for request.")
return AnalyzeResult.fromProto(resp)
def _process_batch(self, arrow_batch: pb2.ExecutePlanResponse.ArrowBatch) -> "pandas.DataFrame":
@@ -409,6 +417,8 @@ class SparkConnectClient(object):
def _execute(self, req: pb2.ExecutePlanRequest) -> None:
for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()):
+ if b.client_id != self._session_id:
+ raise ValueError("Received incorrect session identifier for request.")
continue
return
@@ -419,6 +429,8 @@ class SparkConnectClient(object):
result_dfs = []
for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()):
+ if b.client_id != self._session_id:
+ raise ValueError("Received incorrect session identifier for request.")
if b.metrics is not None:
m = b.metrics
if b.HasField("arrow_batch"):
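For reference, the client-side protection added above reduces to a simple
invariant: remember the UUID stamped on each outgoing request and reject
any response that carries a different one. A minimal sketch follows; the
SessionGuard class is a made-up name for illustration, while the real
check is inlined in SparkConnectClient as shown in the diff.

    import uuid

    class SessionGuard:
        def __init__(self) -> None:
            # One UUID per client instance; a new client gets a fresh ID.
            self.session_id = str(uuid.uuid4())

        def stamp(self, request) -> None:
            # Attach the session ID to an outgoing request
            # (the client_id field in the proto messages).
            request.client_id = self.session_id

        def check(self, response) -> None:
            # Reject any response that belongs to a different session.
            if response.client_id != self.session_id:
                raise ValueError("Received incorrect session identifier for request.")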
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]