This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new c4619b503e5 [SPARK-41738][CONNECT] Mix ClientId in SparkSession cache c4619b503e5 is described below commit c4619b503e58da38c6223020a73c2ca2e0a8c0fa Author: Martin Grund <martin.gr...@databricks.com> AuthorDate: Wed Dec 28 20:04:45 2022 +0900 [SPARK-41738][CONNECT] Mix ClientId in SparkSession cache ### What changes were proposed in this pull request? This PR mixes the client ID into the cache for the SparkSessions on the server. This is necessary to allow concurrent SparkSessions of the same user to run without interfering. The client ID was added to be used in that way, but until now the functionality was not implemented. On the client side, the Python client now validates that the result received is actually from the same session. ### Why are the changes needed? Stability ### Does this PR introduce _any_ user-facing change? NO ### How was this patch tested? Existing UT Closes #39256 from grundprinzip/SPARK-41738. 
Authored-by: Martin Grund <martin.gr...@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../spark/sql/connect/service/SparkConnectService.scala | 14 ++++++++------ .../sql/connect/service/SparkConnectStreamHandler.scala | 4 +++- .../sql/connect/planner/SparkConnectServiceSuite.scala | 3 ++- python/pyspark/sql/connect/client.py | 12 ++++++++++++ 4 files changed, 25 insertions(+), 8 deletions(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala index 3046c8eebfc..bfcea3d2252 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala @@ -94,7 +94,9 @@ class SparkConnectService(debug: Boolean) s"${request.getPlan.getOpTypeCase} not supported for analysis.")) } val session = - SparkConnectService.getOrCreateIsolatedSession(request.getUserContext.getUserId).session + SparkConnectService + .getOrCreateIsolatedSession(request.getUserContext.getUserId, request.getClientId) + .session val explainMode = request.getExplain.getExplainMode match { case proto.Explain.ExplainMode.SIMPLE => SimpleMode @@ -145,7 +147,7 @@ class SparkConnectService(debug: Boolean) * @param userId * @param session */ -case class SessionHolder(userId: String, session: SparkSession) +case class SessionHolder(userId: String, sessionId: String, session: SparkSession) /** * Static instance of the SparkConnectService. @@ -161,7 +163,7 @@ object SparkConnectService { // Type alias for the SessionCacheKey. Right now this is a String but allows us to switch to a // different or complex type easily. 
- private type SessionCacheKey = String; + private type SessionCacheKey = (String, String); private var server: Server = _ @@ -183,11 +185,11 @@ object SparkConnectService { /** * Based on the `key` find or create a new SparkSession. */ - def getOrCreateIsolatedSession(key: SessionCacheKey): SessionHolder = { + def getOrCreateIsolatedSession(userId: String, sessionId: String): SessionHolder = { userSessionMapping.get( - key, + (userId, sessionId), () => { - SessionHolder(key, newIsolatedSession()) + SessionHolder(userId, sessionId, newIsolatedSession()) }) } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala index fcae3501cef..9631b93f6e9 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala @@ -42,7 +42,9 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp def handle(v: ExecutePlanRequest): Unit = { val session = - SparkConnectService.getOrCreateIsolatedSession(v.getUserContext.getUserId).session + SparkConnectService + .getOrCreateIsolatedSession(v.getUserContext.getUserId, v.getClientId) + .session v.getPlan.getOpTypeCase match { case proto.Plan.OpTypeCase.COMMAND => handleCommand(session, v) case proto.Plan.OpTypeCase.ROOT => handlePlan(session, v) diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala index 8f4268b904b..6dcce0926dc 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +++ 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -153,7 +153,7 @@ class SparkConnectServiceSuite extends SharedSparkSession { val instance = new SparkConnectService(false) // Add an always crashing UDF - val session = SparkConnectService.getOrCreateIsolatedSession("c1").session + val session = SparkConnectService.getOrCreateIsolatedSession("c1", "session").session val instaKill: Long => Long = { _ => throw new Exception("Kaboom") } @@ -172,6 +172,7 @@ class SparkConnectServiceSuite extends SharedSparkSession { .newBuilder() .setPlan(plan) .setUserContext(context) + .setClientId("session") .build() // The observer is executed inside this thread. So diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py index 2e32a87676f..e258dbd92b4 100644 --- a/python/pyspark/sql/connect/client.py +++ b/python/pyspark/sql/connect/client.py @@ -290,6 +290,10 @@ class SparkConnectClient(object): # Parse the connection string. self._builder = ChannelBuilder(connectionString) self._user_id = None + # Generate a unique session ID for this client. This UUID must be unique to allow + # concurrent Spark sessions of the same user. If the channel is closed, creating + # a new client will create a new session ID. 
+ self._session_id = str(uuid.uuid4()) if self._builder.userId is not None: self._user_id = self._builder.userId elif userId is not None: @@ -367,6 +371,7 @@ class SparkConnectClient(object): def _execute_plan_request_with_metadata(self) -> pb2.ExecutePlanRequest: req = pb2.ExecutePlanRequest() + req.client_id = self._session_id req.client_type = "_SPARK_CONNECT_PYTHON" if self._user_id: req.user_context.user_id = self._user_id @@ -374,6 +379,7 @@ class SparkConnectClient(object): def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest: req = pb2.AnalyzePlanRequest() + req.client_id = self._session_id req.client_type = "_SPARK_CONNECT_PYTHON" if self._user_id: req.user_context.user_id = self._user_id @@ -401,6 +407,8 @@ class SparkConnectClient(object): req.explain.explain_mode = pb2.Explain.ExplainMode.FORMATTED resp = self._stub.AnalyzePlan(req, metadata=self._builder.metadata()) + if resp.client_id != self._session_id: + raise ValueError("Received incorrect session identifier for request.") return AnalyzeResult.fromProto(resp) def _process_batch(self, arrow_batch: pb2.ExecutePlanResponse.ArrowBatch) -> "pandas.DataFrame": @@ -409,6 +417,8 @@ class SparkConnectClient(object): def _execute(self, req: pb2.ExecutePlanRequest) -> None: for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()): + if b.client_id != self._session_id: + raise ValueError("Received incorrect session identifier for request.") continue return @@ -419,6 +429,8 @@ class SparkConnectClient(object): result_dfs = [] for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()): + if b.client_id != self._session_id: + raise ValueError("Received incorrect session identifier for request.") if b.metrics is not None: m = b.metrics if b.HasField("arrow_batch"): --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org