This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new c17c77e8cd9 [SPARK-43418][CONNECT] Add SparkSession.Builder.getOrCreate
c17c77e8cd9 is described below
commit c17c77e8cd91b722979e18b54fb1a7e6a6f18e74
Author: Herman van Hovell <[email protected]>
AuthorDate: Tue May 9 15:14:48 2023 -0400
[SPARK-43418][CONNECT] Add SparkSession.Builder.getOrCreate
### What changes were proposed in this pull request?
This PR adds `SparkSession.Builder.getOrCreate()` to the scala client.
`getOrCreate` is used in many existing examples, so it is good to support it.
Spark Connect is a bit different from the old API in that it allows you to
connect to multiple servers in the same process. This means that we cannot
entirely match the existing semantics (one session for all). I opted to cache a
number of `SparkSessions`, and make `getOrCreate` return a cached session if
they share the same client configuration. I have also added a `create()`
method for cases where you want a newly instantiated session.
### Why are the changes needed?
Improve compatibility with the existing code.
### Does this PR introduce _any_ user-facing change?
Yes, it adds API.
### How was this patch tested?
Added tests for this.
Closes #41097 from hvanhovell/SPARK-43418.
Authored-by: Herman van Hovell <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
---
.../scala/org/apache/spark/sql/SparkSession.scala | 81 +++++++++--
.../apache/spark/sql/application/ConnectRepl.scala | 2 +-
.../sql/connect/client/SparkConnectClient.scala | 159 ++++++++++++---------
.../org/apache/spark/sql/ClientE2ETestSuite.scala | 2 +-
.../scala/org/apache/spark/sql/DatasetSuite.scala | 6 +-
.../apache/spark/sql/PlanGenerationTestSuite.scala | 4 +-
.../apache/spark/sql/SQLImplicitsTestSuite.scala | 5 +-
.../org/apache/spark/sql/SparkSessionSuite.scala | 76 ++++++++++
.../connect/client/SparkConnectClientSuite.scala | 2 +-
.../connect/client/util/RemoteSparkSession.scala | 6 +-
10 files changed, 245 insertions(+), 98 deletions(-)
diff --git
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index bc6cf32379f..4e5474a33b7 100644
---
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -24,6 +24,7 @@ import java.util.concurrent.atomic.AtomicLong
import scala.collection.JavaConverters._
import scala.reflect.runtime.universe.TypeTag
+import com.google.common.cache.{CacheBuilder, CacheLoader}
import org.apache.arrow.memory.RootAllocator
import org.apache.spark.annotation.{DeveloperApi, Experimental}
@@ -35,6 +36,7 @@ import org.apache.spark.sql.catalyst.{JavaTypeInference,
ScalaReflection}
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
import
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BoxedLongEncoder,
UnboundRowEncoder}
import org.apache.spark.sql.connect.client.{ClassFinder, SparkConnectClient,
SparkResult}
+import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
import org.apache.spark.sql.connect.client.util.{Cleaner, ConvertToArrow}
import
org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto
import org.apache.spark.sql.internal.CatalogImpl
@@ -55,14 +57,12 @@ import org.apache.spark.sql.types.StructType
*
* {{{
* SparkSession.builder
- * .master("local")
- * .appName("Word Count")
- * .config("spark.some.config.option", "some-value")
+ * .remote("sc://localhost:15001/myapp")
* .getOrCreate()
* }}}
*/
class SparkSession private[sql] (
- private val client: SparkConnectClient,
+ private[sql] val client: SparkConnectClient,
private val cleaner: Cleaner,
private val planIdGenerator: AtomicLong)
extends Serializable
@@ -395,7 +395,7 @@ class SparkSession private[sql] (
// scalastyle:on
def newSession(): SparkSession = {
- SparkSession.builder().client(client.copy()).build()
+ SparkSession.builder().client(client.copy()).create()
}
private def range(
@@ -569,14 +569,38 @@ class SparkSession private[sql] (
override def close(): Unit = {
client.shutdown()
allocator.close()
+ SparkSession.onSessionClose(this)
}
}
// The minimal builder needed to create a spark session.
// TODO: implements all methods mentioned in the scaladoc of [[SparkSession]]
object SparkSession extends Logging {
+ private val MAX_CACHED_SESSIONS = 100
private val planIdGenerator = new AtomicLong
+ private val sessions = CacheBuilder
+ .newBuilder()
+ .weakValues()
+ .maximumSize(MAX_CACHED_SESSIONS)
+ .build(new CacheLoader[Configuration, SparkSession] {
+ override def load(c: Configuration): SparkSession = create(c)
+ })
+
+ /**
+ * Create a new [[SparkSession]] based on the connect client
[[Configuration]].
+ */
+ private[sql] def create(configuration: Configuration): SparkSession = {
+ new SparkSession(new SparkConnectClient(configuration), cleaner,
planIdGenerator)
+ }
+
+ /**
+ * Hook called when a session is closed.
+ */
+ private[sql] def onSessionClose(session: SparkSession): Unit = {
+ sessions.invalidate(session.client.configuration)
+ }
+
def builder(): Builder = new Builder()
private[sql] lazy val cleaner = {
@@ -586,23 +610,56 @@ object SparkSession extends Logging {
}
class Builder() extends Logging {
- private var _client: SparkConnectClient = _
+ private val builder = SparkConnectClient.builder()
+ private var client: SparkConnectClient = _
def remote(connectionString: String): Builder = {
-
client(SparkConnectClient.builder().connectionString(connectionString).build())
+ builder.connectionString(connectionString)
this
}
private[sql] def client(client: SparkConnectClient): Builder = {
- _client = client
+ this.client = client
this
}
- def build(): SparkSession = {
- if (_client == null) {
- _client = SparkConnectClient.builder().build()
+ private def tryCreateSessionFromClient(): Option[SparkSession] = {
+ if (client != null) {
+ Option(new SparkSession(client, cleaner, planIdGenerator))
+ } else {
+ None
}
- new SparkSession(_client, cleaner, planIdGenerator)
+ }
+
+ /**
+ * Build the [[SparkSession]].
+ *
+ * This will always return a newly created session.
+ */
+ @deprecated(message = "Please use create() instead.", since = "3.5.0")
+ def build(): SparkSession = create()
+
+ /**
+ * Create a new [[SparkSession]].
+ *
+ * This will always return a newly created session.
+ *
+ * @since 3.5.0
+ */
+ def create(): SparkSession = {
+
tryCreateSessionFromClient().getOrElse(SparkSession.this.create(builder.configuration))
+ }
+
+ /**
+ * Get or create a [[SparkSession]].
+ *
+ * If a session with the same configuration already exists, it is returned
+ * instead of creating a new session.
+ *
+ * @since 3.5.0
+ */
+ def getOrCreate(): SparkSession = {
+
tryCreateSessionFromClient().getOrElse(sessions.get(builder.configuration))
}
}
diff --git
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala
index d119fd60230..d3bd0a6a8b0 100644
---
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala
+++
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala
@@ -74,7 +74,7 @@ object ConnectRepl {
}
// Build the session.
- val spark = SparkSession.builder().client(client).build()
+ val spark = SparkSession.builder().client(client).getOrCreate()
val sparkBind = new Bind("spark", spark)
// Add the proper imports and register a [[ClassFinder]].
diff --git
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index ca00eb74a20..a5aabe62ae4 100644
---
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -17,14 +17,14 @@
package org.apache.spark.sql.connect.client
-import com.google.protobuf.ByteString
-import io.grpc.{CallCredentials, CallOptions, Channel, ClientCall,
ClientInterceptor, CompositeChannelCredentials, ForwardingClientCall, Grpc,
InsecureChannelCredentials, ManagedChannel, ManagedChannelBuilder, Metadata,
MethodDescriptor, Status, TlsChannelCredentials}
import java.net.URI
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets
import java.util.UUID
import java.util.concurrent.Executor
-import scala.language.existentials
+
+import com.google.protobuf.ByteString
+import io.grpc.{CallCredentials, CallOptions, Channel, ChannelCredentials,
ClientCall, ClientInterceptor, CompositeChannelCredentials,
ForwardingClientCall, Grpc, InsecureChannelCredentials, ManagedChannel,
Metadata, MethodDescriptor, Status, TlsChannelCredentials}
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.UserContext
@@ -34,20 +34,24 @@ import
org.apache.spark.sql.connect.common.config.ConnectCommon
* Conceptually the remote spark session that communicates with the server.
*/
private[sql] class SparkConnectClient(
- private val userContext: proto.UserContext,
- private val channelBuilder: ManagedChannelBuilder[_],
- private[client] val userAgent: String) {
+ private[sql] val configuration: SparkConnectClient.Configuration,
+ private val channel: ManagedChannel) {
- private[this] lazy val channel: ManagedChannel = channelBuilder.build()
+ def this(configuration: SparkConnectClient.Configuration) =
+ this(configuration, configuration.createChannel())
+
+ private val userContext: UserContext = configuration.userContext
private[this] val stub =
proto.SparkConnectServiceGrpc.newBlockingStub(channel)
+ private[client] def userAgent: String = configuration.userAgent
+
/**
* Placeholder method.
* @return
* User ID.
*/
- private[sql] def userId: String = userContext.getUserId()
+ private[sql] def userId: String = userContext.getUserId
// Generate a unique session ID for this client. This UUID must be unique to
allow
// concurrent Spark sessions of the same user. If the channel is closed,
creating
@@ -195,9 +199,7 @@ private[sql] class SparkConnectClient(
stub.interrupt(request)
}
- def copy(): SparkConnectClient = {
- new SparkConnectClient(userContext, channelBuilder, userAgent)
- }
+ def copy(): SparkConnectClient = new SparkConnectClient(configuration)
/**
* Add a single artifact to the client session.
@@ -260,10 +262,9 @@ object SparkConnectClient {
"Either remove 'token' or set 'use_ssl=true'"
// for internal tests
- private[sql] def apply(
- userContext: UserContext,
- builder: ManagedChannelBuilder[_]): SparkConnectClient =
- new SparkConnectClient(userContext, builder, DEFAULT_USER_AGENT)
+ private[sql] def apply(channel: ManagedChannel): SparkConnectClient = {
+ new SparkConnectClient(Configuration(), channel)
+ }
def builder(): Builder = new Builder()
@@ -271,50 +272,42 @@ object SparkConnectClient {
* This is a helper class that is used to create a GRPC channel based on
either a set host and
* port or a NameResolver-compliant URI connection string.
*/
- class Builder() {
- private val userContextBuilder = proto.UserContext.newBuilder()
- private var _userAgent: String = DEFAULT_USER_AGENT
-
- private var _host: String = "localhost"
- private var _port: Int = ConnectCommon.CONNECT_GRPC_BINDING_PORT
+ class Builder(private var _configuration: Configuration) {
+ def this() = this(Configuration())
- private var _token: Option[String] = None
- // If no value specified for isSslEnabled, default to false
- private var isSslEnabled: Option[Boolean] = None
-
- private var metadata: Map[String, String] = Map.empty
+ def configuration: Configuration = _configuration
def userId(id: String): Builder = {
// TODO this is not an optional field!
require(id != null && id.nonEmpty)
- userContextBuilder.setUserId(id)
+ _configuration = _configuration.copy(userId = id)
this
}
- def userId: Option[String] =
Option(userContextBuilder.getUserId).filter(_.nonEmpty)
+ def userId: Option[String] = Option(_configuration.userId)
def userName(name: String): Builder = {
require(name != null && name.nonEmpty)
- userContextBuilder.setUserName(name)
+ _configuration = _configuration.copy(userName = name)
this
}
- def userName: Option[String] =
Option(userContextBuilder.getUserName).filter(_.nonEmpty)
+ def userName: Option[String] = Option(_configuration.userName)
def host(inputHost: String): Builder = {
require(inputHost != null)
- _host = inputHost
+ _configuration = _configuration.copy(host = inputHost)
this
}
- def host: String = _host
+ def host: String = _configuration.host
def port(inputPort: Int): Builder = {
- _port = inputPort
+ _configuration = _configuration.copy(port = inputPort)
this
}
- def port: Int = _port
+ def port: Int = _configuration.port
/**
* Setting the token implicitly sets the use_ssl=true. All the following
examples yield the
@@ -335,21 +328,18 @@ object SparkConnectClient {
*/
def token(inputToken: String): Builder = {
require(inputToken != null && inputToken.nonEmpty)
- _token = Some(inputToken)
- // Only set the isSSlEnabled if it is not yet set
- isSslEnabled match {
- case None => isSslEnabled = Some(true)
- case Some(false) =>
- throw new
IllegalArgumentException(AUTH_TOKEN_ON_INSECURE_CONN_ERROR_MSG)
- case Some(true) => // Good, the ssl is enabled
+ if (_configuration.isSslEnabled.contains(false)) {
+ throw new
IllegalArgumentException(AUTH_TOKEN_ON_INSECURE_CONN_ERROR_MSG)
}
+ _configuration =
+ _configuration.copy(token = Option(inputToken), isSslEnabled =
Option(true))
this
}
- def token: Option[String] = _token
+ def token: Option[String] = _configuration.token
def enableSsl(): Builder = {
- isSslEnabled = Some(true)
+ _configuration = _configuration.copy(isSslEnabled = Option(true))
this
}
@@ -360,12 +350,12 @@ object SparkConnectClient {
* this builder.
*/
def disableSsl(): Builder = {
- require(_token.isEmpty, AUTH_TOKEN_ON_INSECURE_CONN_ERROR_MSG)
- isSslEnabled = Some(false)
+ require(token.isEmpty, AUTH_TOKEN_ON_INSECURE_CONN_ERROR_MSG)
+ _configuration = _configuration.copy(isSslEnabled = Option(false))
this
}
- def sslEnabled: Boolean = isSslEnabled.contains(true)
+ def sslEnabled: Boolean = _configuration.isSslEnabled.contains(true)
private object URIParams {
val PARAM_USER_ID = "user_id"
@@ -399,18 +389,18 @@ object SparkConnectClient {
def userAgent(value: String): Builder = {
require(value != null)
- _userAgent = value
+ _configuration = _configuration.copy(userAgent = value)
this
}
- def userAgent: String = _userAgent
+ def userAgent: String = _configuration.userAgent
def option(key: String, value: String): Builder = {
- metadata += ((key, value))
+ _configuration = _configuration.copy(metadata = _configuration.metadata
+ ((key, value)))
this
}
- def options: Map[String, String] = metadata
+ def options: Map[String, String] = _configuration.metadata
private def parseURIParams(uri: URI): Unit = {
val params = uri.getPath.split(';').drop(1).filter(_ != "")
@@ -430,7 +420,7 @@ object SparkConnectClient {
case URIParams.PARAM_TOKEN => token(value)
case URIParams.PARAM_USE_SSL =>
if (java.lang.Boolean.valueOf(value)) enableSsl() else disableSsl()
- case _ => this.metadata = this.metadata + (key -> value)
+ case _ => option(key, value)
}
}
}
@@ -439,7 +429,7 @@ object SparkConnectClient {
* Configure the builder using the env SPARK_REMOTE environment variable.
*/
def loadFromEnvironment(): Builder = {
- sys.env.get("SPARK_REMOTE").foreach(connectionString)
+ sys.env.get(SparkConnectClient.SPARK_REMOTE).foreach(connectionString)
this
}
@@ -453,10 +443,10 @@ object SparkConnectClient {
val uri = new URI(connectionString)
verifyURI(uri)
parseURIParams(uri)
- _host = uri.getHost
+ host(uri.getHost)
val inputPort = uri.getPort
if (inputPort != -1) {
- _port = inputPort
+ port(uri.getPort)
}
this
}
@@ -469,27 +459,56 @@ object SparkConnectClient {
this
}
- def build(): SparkConnectClient = {
- val creds = isSslEnabled match {
- case Some(false) | None => InsecureChannelCredentials.create()
- case Some(true) =>
- _token match {
- case Some(t) =>
- // With access token added in the http header.
- CompositeChannelCredentials.create(
- TlsChannelCredentials.create,
- new AccessTokenCallCredentials(t))
- case None =>
- TlsChannelCredentials.create
- }
+ def build(): SparkConnectClient = new SparkConnectClient(_configuration)
+ }
+
+ /**
+ * Helper class that fully captures the configuration for a
[[SparkConnectClient]].
+ */
+ private[sql] case class Configuration(
+ userId: String = null,
+ userName: String = null,
+ host: String = "localhost",
+ port: Int = ConnectCommon.CONNECT_GRPC_BINDING_PORT,
+ token: Option[String] = None,
+ isSslEnabled: Option[Boolean] = None,
+ metadata: Map[String, String] = Map.empty,
+ userAgent: String = DEFAULT_USER_AGENT) {
+
+ def userContext: proto.UserContext = {
+ val builder = proto.UserContext.newBuilder()
+ if (userId != null) {
+ builder.setUserId(userId)
+ }
+ if (userName != null) {
+ builder.setUserName(userName)
}
+ builder.build()
+ }
+
+ def credentials: ChannelCredentials = {
+ if (isSslEnabled.contains(true)) {
+ token match {
+ case Some(t) =>
+ // With access token added in the http header.
+ CompositeChannelCredentials.create(
+ TlsChannelCredentials.create,
+ new AccessTokenCallCredentials(t))
+ case None =>
+ TlsChannelCredentials.create()
+ }
+ } else {
+ InsecureChannelCredentials.create()
+ }
+ }
- val channelBuilder = Grpc.newChannelBuilderForAddress(_host, _port,
creds)
+ def createChannel(): ManagedChannel = {
+ val channelBuilder = Grpc.newChannelBuilderForAddress(host, port,
credentials)
if (metadata.nonEmpty) {
channelBuilder.intercept(new MetadataHeaderClientInterceptor(metadata))
}
channelBuilder.maxInboundMessageSize(ConnectCommon.CONNECT_GRPC_MAX_MESSAGE_SIZE)
- new SparkConnectClient(userContextBuilder.build(), channelBuilder,
_userAgent)
+ channelBuilder.build()
}
}
@@ -508,7 +527,7 @@ object SparkConnectClient {
appExecutor.execute(() => {
try {
val headers = new Metadata()
- headers.put(AUTH_TOKEN_META_DATA_KEY, s"Bearer $token");
+ headers.put(AUTH_TOKEN_META_DATA_KEY, s"Bearer $token")
applier.apply(headers)
} catch {
case e: Throwable =>
diff --git
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index 33fc5d5a4a2..223ff0e9ee2 100644
---
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -205,7 +205,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with
SQLHelper {
.builder()
.port(port)
.build())
- .build()
+ .create()
val df2 = spark2.range(10).limit(3)
diff --git
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index e5738fe7acd..95a1f68179d 100644
---
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -40,10 +40,8 @@ class DatasetSuite extends ConnectFunSuite with
BeforeAndAfterEach {
private var ss: SparkSession = _
private def newSparkSession(): SparkSession = {
- val client = new SparkConnectClient(
- proto.UserContext.newBuilder().build(),
- InProcessChannelBuilder.forName(getClass.getName).directExecutor(),
- "test")
+ val client = SparkConnectClient(
+
InProcessChannelBuilder.forName(getClass.getName).directExecutor().build())
new SparkSession(client, cleaner = SparkSession.cleaner, planIdGenerator =
new AtomicLong)
}
diff --git
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index e256c57cfd2..b9b776420c2 100644
---
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -97,9 +97,7 @@ class PlanGenerationTestSuite
override protected def beforeAll(): Unit = {
super.beforeAll()
- val client = SparkConnectClient(
- proto.UserContext.newBuilder().build(),
- InProcessChannelBuilder.forName("/dev/null"))
+ val client =
SparkConnectClient(InProcessChannelBuilder.forName("/dev/null").build())
session =
new SparkSession(client, cleaner = SparkSession.cleaner, planIdGenerator
= new AtomicLong)
}
diff --git
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala
index 470736fbebe..3b5d7dae1b3 100644
---
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala
+++
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala
@@ -25,7 +25,6 @@ import io.grpc.inprocess.InProcessChannelBuilder
import org.apache.commons.lang3.{JavaVersion, SystemUtils}
import org.scalatest.BeforeAndAfterAll
-import org.apache.spark.connect.proto
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder,
ExpressionEncoder}
import org.apache.spark.sql.connect.client.SparkConnectClient
import org.apache.spark.sql.connect.client.util.ConnectFunSuite
@@ -38,9 +37,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with
BeforeAndAfterAll {
override protected def beforeAll(): Unit = {
super.beforeAll()
- val client = SparkConnectClient(
- proto.UserContext.newBuilder().build(),
- InProcessChannelBuilder.forName("/dev/null"))
+ val client =
SparkConnectClient(InProcessChannelBuilder.forName("/dev/null").build())
session =
new SparkSession(client, cleaner = SparkSession.cleaner, planIdGenerator
= new AtomicLong)
}
diff --git
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala
new file mode 100644
index 00000000000..48d55311da0
--- /dev/null
+++
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import org.apache.spark.sql.connect.client.util.ConnectFunSuite
+
+/**
+ * Tests for non-dataframe related SparkSession operations.
+ */
+class SparkSessionSuite extends ConnectFunSuite {
+ test("default") {
+ val session = SparkSession.builder().getOrCreate()
+ assert(session.client.configuration.host == "localhost")
+ assert(session.client.configuration.port == 15002)
+ session.close()
+ }
+
+ test("remote") {
+ val session =
SparkSession.builder().remote("sc://test.me:14099").getOrCreate()
+ assert(session.client.configuration.host == "test.me")
+ assert(session.client.configuration.port == 14099)
+ session.close()
+ }
+
+ test("getOrCreate") {
+ val connectionString = "sc://test.it:17865"
+ val session1 =
SparkSession.builder().remote(connectionString).getOrCreate()
+ val session2 =
SparkSession.builder().remote(connectionString).getOrCreate()
+ try {
+ assert(session1 eq session2)
+ } finally {
+ session1.close()
+ session2.close()
+ }
+ }
+
+ test("create") {
+ val connectionString = "sc://test.it:17845"
+ val session1 = SparkSession.builder().remote(connectionString).create()
+ val session2 = SparkSession.builder().remote(connectionString).create()
+ try {
+ assert(session1 ne session2)
+ assert(session1.client.configuration == session2.client.configuration)
+ } finally {
+ session1.close()
+ session2.close()
+ }
+ }
+
+ test("newSession") {
+ val connectionString = "sc://doit:16845"
+ val session1 = SparkSession.builder().remote(connectionString).create()
+ val session2 = session1.newSession()
+ try {
+ assert(session1 ne session2)
+ assert(session1.client.configuration == session2.client.configuration)
+ } finally {
+ session1.close()
+ session2.close()
+ }
+ }
+}
diff --git
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
index bc600e5a071..3cccd8eeb33 100755
---
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
+++
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
@@ -111,7 +111,7 @@ class SparkConnectClientSuite extends ConnectFunSuite with
BeforeAndAfterEach {
val testPort = 16002
client =
SparkConnectClient.builder().connectionString(s"sc://localhost:$testPort").build()
startDummyServer(testPort)
- val session = SparkSession.builder().client(client).build()
+ val session = SparkSession.builder().client(client).create()
val df = session.range(10)
df.analyze // Trigger RPC
assert(df.plan === service.getAndClearLatestInputPlan())
diff --git
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
index 1476b16da5d..b9edf9ac1a5 100644
---
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
+++
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
@@ -168,8 +168,10 @@ trait RemoteSparkSession extends ConnectFunSuite with
BeforeAndAfterAll {
override def beforeAll(): Unit = {
super.beforeAll()
SparkConnectServerUtils.start()
- spark =
-
SparkSession.builder().client(SparkConnectClient.builder().port(serverPort).build()).build()
+ spark = SparkSession
+ .builder()
+ .client(SparkConnectClient.builder().port(serverPort).build())
+ .create()
// Retry and wait for the server to start
val stop = System.nanoTime() + TimeUnit.MINUTES.toNanos(1) // ~1 min
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]