This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 8e75fc96af99 [SPARK-54637][CONNECT][TESTS] Add SQL API test helpers to
SparkConnectServerTest
8e75fc96af99 is described below
commit 8e75fc96af99b128c7db7abb5518bd57836ca757
Author: Juliusz Sompolski <Juliusz Sompolski>
AuthorDate: Tue Dec 9 08:02:21 2025 -0800
[SPARK-54637][CONNECT][TESTS] Add SQL API test helpers to
SparkConnectServerTest
### What changes were proposed in this pull request?
Add testing helpers to SparkConnectServerTest to enable using connect Spark
SQL APIs in tests using that helper.
### Why are the changes needed?
In Spark 3.5, a testing trait SparkConnectServerTest was introduced that
helped test Spark Connect Service with a SparkConnectClient in the same JVM
process, which tested real Spark Connect code paths (SparkConnectClient
communicating with the server over actual connection to the localhost server).
Before that, using RemoteSparkSession, the server was started in a separate process.
It helped
* testability: operations can be triggered from the client, and verification
code can then check the resulting state on the server side. Some more internal
server-side setup can also be done to test specific things.
* debugging, as both client and server can be easily connected to by a
debugger.
At that time, it was impossible to test Spark Connect client SQL APIs
(SparkSession, Dataset) this way, because they were in the same namespace as the
server, and hence couldn't be class-loaded together.
Since Spark 4.0, there is a new API layer that makes it possible for
connect and classic implementation of the interfaces to coexist. With that,
testing can be extended to use actual SparkSession and other APIs, instead of
having to construct tests using more raw APIs.
### Does this PR introduce _any_ user-facing change?
No. It's testing only.
### How was this patch tested?
Added SparkConnectServerTestSuite showcasing the new APIs.
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Claude Code opus 4.5
Closes #53384 from juliuszsompolski/spark-connect-server-client-test.
Authored-by: Juliusz Sompolski <Juliusz Sompolski>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
.../spark/sql/connect/SparkConnectServerTest.scala | 72 ++++++-
.../sql/connect/SparkConnectServerTestSuite.scala | 207 +++++++++++++++++++++
2 files changed, 278 insertions(+), 1 deletion(-)
diff --git
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
index 1b2b7ab42029..7b9052bb9d2c 100644
---
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
+++
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
@@ -26,14 +26,17 @@ import org.scalatest.time.Span
import org.scalatest.time.SpanSugar._
import org.apache.spark.connect.proto
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.classic
+import org.apache.spark.sql.connect
import org.apache.spark.sql.connect.client.{CloseableIterator,
CustomSparkConnectBlockingStub, ExecutePlanResponseReattachableIterator,
RetryPolicy, SparkConnectClient, SparkConnectStubState}
import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
import org.apache.spark.sql.connect.common.config.ConnectCommon
import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.dsl.MockRemoteSession
import org.apache.spark.sql.connect.dsl.plans._
-import org.apache.spark.sql.connect.service.{ExecuteHolder,
SparkConnectService}
+import org.apache.spark.sql.connect.service.{ExecuteHolder, SessionKey,
SparkConnectService}
import org.apache.spark.sql.test.SharedSparkSession
/**
@@ -320,4 +323,71 @@ trait SparkConnectServerTest extends SharedSparkSession {
val plan = buildPlan(query)
runQuery(plan, queryTimeout, iterSleep)
}
+
+ /**
+ * Helper method to create a connect SparkSession that connects to the
localhost server. Similar
+ * to withClient, but provides a full SparkSession API instead of just a
client.
+ *
+ * @param sessionId
+ * Optional session ID (defaults to defaultSessionId)
+ * @param userId
+ * Optional user ID (defaults to defaultUserId)
+ * @param f
+ * Function to execute with the session
+ */
+ protected def withSession(sessionId: String = defaultSessionId, userId:
String = defaultUserId)(
+ f: SparkSession => Unit): Unit = {
+ withSession(f, sessionId, userId)
+ }
+
+ /**
+ * Helper method to create a connect SparkSession with default session and
user IDs.
+ *
+ * @param f
+ * Function to execute with the session
+ */
+ protected def withSession(f: SparkSession => Unit): Unit = {
+ withSession(f, defaultSessionId, defaultUserId)
+ }
+
+ private def withSession(f: SparkSession => Unit, sessionId: String, userId:
String): Unit = {
+ val client = SparkConnectClient
+ .builder()
+ .port(serverPort)
+ .sessionId(sessionId)
+ .userId(userId)
+ .build()
+
+ val session = connect.SparkSession
+ .builder()
+ .client(client)
+ .create()
+ try f(session)
+ finally {
+ session.close()
+ }
+ }
+
+ /**
+ * Get the server-side SparkSession corresponding to a client SparkSession.
+ *
+ * This helper takes a sql.SparkSession (which is assumed to be a
connect.SparkSession),
+ * extracts the userId and sessionId from it, and looks up the corresponding
server-side classic
+ * SparkSession using SparkConnectSessionManager.
+ *
+ * @param clientSession
+ * The client SparkSession (must be a connect.SparkSession)
+ * @return
+ * The server-side classic SparkSession
+ */
+ protected def getServerSession(clientSession: SparkSession):
classic.SparkSession = {
+ val connectSession = clientSession.asInstanceOf[connect.SparkSession]
+ val userId = connectSession.client.userId
+ val sessionId = connectSession.sessionId
+ val key = SessionKey(userId, sessionId)
+ SparkConnectService.sessionManager
+ .getIsolatedSessionIfPresent(key)
+ .get
+ .session
+ }
}
diff --git
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTestSuite.scala
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTestSuite.scala
new file mode 100644
index 000000000000..c14114ced663
--- /dev/null
+++
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTestSuite.scala
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect
+
+import org.scalatest.time.SpanSugar._
+
+/**
+ * Test suite showcasing the APIs provided by SparkConnectServerTest trait.
+ *
+ * This suite demonstrates:
+ * - Session and client helper methods (withSession, withClient,
getServerSession)
+ * - Low-level stub helpers (withRawBlockingStub, withCustomBlockingStub)
+ * - Plan building helpers (buildPlan, buildExecutePlanRequest, etc.)
+ * - Assertion helpers for execution state
+ */
+class SparkConnectServerTestSuite extends SparkConnectServerTest {
+
+ test("withSession: execute SQL and collect results") {
+ withSession { session =>
+ val df = session.sql("SELECT 1 as value")
+ val result = df.collect()
+ assert(result.length == 1)
+ assert(result(0).getInt(0) == 1)
+ }
+ }
+
+ test("withSession: with custom session and user IDs") {
+ val customSessionId = java.util.UUID.randomUUID().toString
+ val customUserId = "test-user"
+ withSession(sessionId = customSessionId, userId = customUserId) { session
=>
+ val df = session.sql("SELECT 'hello' as greeting")
+ val result = df.collect()
+ assert(result.length == 1)
+ assert(result(0).getString(0) == "hello")
+ }
+ }
+
+ test("withSession: DataFrame operations") {
+ withSession { session =>
+ val df = session.range(10)
+ assert(df.count() == 10)
+
+ val sum = df.selectExpr("sum(id)").collect()(0).getLong(0)
+ assert(sum == 45) // 0 + 1 + ... + 9 = 45
+ }
+ }
+
+ test("withClient: execute plan and iterate results") {
+ withClient { client =>
+ val plan = buildPlan("SELECT 1 as x, 2 as y")
+ val iter = client.execute(plan)
+ var hasResults = false
+ while (iter.hasNext) {
+ iter.next()
+ hasResults = true
+ }
+ assert(hasResults)
+ }
+ }
+
+ test("withClient: with custom session and user IDs") {
+ val customSessionId = java.util.UUID.randomUUID().toString
+ val customUserId = "custom-user"
+ withClient(sessionId = customSessionId, userId = customUserId) { client =>
+ val plan = buildPlan("SELECT 42")
+ val iter = client.execute(plan)
+ while (iter.hasNext) iter.next()
+ }
+ }
+
+ test("getServerSession: returns server-side classic session") {
+ withSession { clientSession =>
+ clientSession.sql("SELECT 1").collect()
+
+ val serverSession = getServerSession(clientSession)
+
+ assert(serverSession != null)
+ assert(serverSession.sparkContext != null)
+ }
+ }
+
+ test("getServerSession: client and server share configuration") {
+ withSession { clientSession =>
+ clientSession.sql("SET spark.sql.shuffle.partitions=17").collect()
+
+ val serverSession = getServerSession(clientSession)
+ assert(serverSession.conf.get("spark.sql.shuffle.partitions") == "17")
+ }
+ }
+
+ test("getServerSession: register and use temporary view from server") {
+ withSession { clientSession =>
+ clientSession.sql("SELECT 1 as a, 2 as b").collect()
+
+ val serverSession = getServerSession(clientSession)
+
+ // Create a temp view on the server side
+ import serverSession.implicits._
+ val serverDf = Seq((100, "server"), (200, "side")).toDF("num", "source")
+ serverDf.createOrReplaceTempView("server_view")
+
+ // Access the view from the client
+ val result = clientSession.sql("SELECT * FROM server_view ORDER BY
num").collect()
+ assert(result.length == 2)
+ assert(result(0).getInt(0) == 100)
+ assert(result(0).getString(1) == "server")
+ assert(result(1).getInt(0) == 200)
+ assert(result(1).getString(1) == "side")
+ }
+ }
+
+ test("withRawBlockingStub: execute plan via raw gRPC stub") {
+ withRawBlockingStub { stub =>
+ val request = buildExecutePlanRequest(buildPlan("SELECT 'raw' as mode"))
+ val iter = stub.executePlan(request)
+ assert(iter.hasNext)
+ while (iter.hasNext) iter.next()
+ }
+ }
+
+ test("withCustomBlockingStub: execute plan via custom blocking stub") {
+ withCustomBlockingStub() { stub =>
+ val request = buildExecutePlanRequest(buildPlan("SELECT 'custom' as
mode"))
+ val iter = stub.executePlan(request)
+ while (iter.hasNext) iter.next()
+ }
+ }
+
+ test("buildPlan: creates plan from SQL query") {
+ val plan = buildPlan("SELECT 1, 2, 3")
+ assert(plan.hasRoot)
+ }
+
+ test("buildSqlCommandPlan: creates command plan") {
+ val plan = buildSqlCommandPlan("SET spark.sql.adaptive.enabled=true")
+ assert(plan.hasCommand)
+ assert(plan.getCommand.hasSqlCommand)
+ }
+
+ test("buildLocalRelation: creates plan from local data") {
+ val data = Seq((1, "a"), (2, "b"), (3, "c"))
+ val plan = buildLocalRelation(data)
+ assert(plan.hasRoot)
+ assert(plan.getRoot.hasLocalRelation)
+ }
+
+ test("buildExecutePlanRequest: creates request with options") {
+ val plan = buildPlan("SELECT 1")
+ val request = buildExecutePlanRequest(plan)
+ assert(request.hasPlan)
+ assert(request.hasUserContext)
+ assert(request.getSessionId == defaultSessionId)
+ }
+
+ test("buildExecutePlanRequest: with custom session and operation IDs") {
+ val plan = buildPlan("SELECT 1")
+ val customSessionId = "my-session"
+ val customOperationId = "my-operation"
+ val request =
+ buildExecutePlanRequest(plan, sessionId = customSessionId, operationId =
customOperationId)
+ assert(request.getSessionId == customSessionId)
+ assert(request.getOperationId == customOperationId)
+ }
+
+ test("runQuery: executes query string with timeout") {
+ runQuery("SELECT * FROM range(100)", 30.seconds)
+ }
+
+ test("runQuery: executes plan with timeout and iter sleep") {
+ val plan = buildPlan("SELECT * FROM range(10)")
+ runQuery(plan, 30.seconds, iterSleep = 10)
+ }
+
+ test("assertNoActiveExecutions: verifies clean state") {
+ assertNoActiveExecutions()
+ }
+
+ test("assertNoActiveRpcs: verifies no active RPCs") {
+ assertNoActiveRpcs()
+ }
+
+ test("eventuallyGetExecutionHolder: retrieves active execution") {
+ withRawBlockingStub { stub =>
+ val request = buildExecutePlanRequest(buildPlan("SELECT * FROM
range(1000000)"))
+ val iter = stub.executePlan(request)
+ iter.hasNext // trigger execution
+
+ val holder = eventuallyGetExecutionHolder
+ assert(holder != null)
+ assert(holder.operationId == request.getOperationId)
+ }
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]