This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 871ca429ac0 [SPARK-41822][CONNECT] Setup gRPC connection for Scala/JVM client
871ca429ac0 is described below
commit 871ca429ac02fb65d9c20fb11621641c0c28e26a
Author: vicennial <[email protected]>
AuthorDate: Tue Jan 10 22:43:03 2023 +0800
[SPARK-41822][CONNECT] Setup gRPC connection for Scala/JVM client
### What changes were proposed in this pull request?
Set up the gRPC connection logic for the Scala/JVM client and create
"common configs" to share settings between server/client.
### Why are the changes needed?
Enables the client to communicate with the Spark Connect server.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Simple unit test.
Closes #39361 from vicennial/setupClientCon.
Authored-by: vicennial <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
connector/connect/client/jvm/pom.xml | 24 +++-
.../sql/connect/client/SparkConnectClient.scala | 121 ++++++++++++++++-
.../connect/client/SparkConnectClientSuite.scala | 144 ++++++++++++++++++++-
.../sql/connect/common/config/ConnectCommon.scala} | 17 +--
.../apache/spark/sql/connect/config/Connect.scala | 3 +-
project/SparkBuild.scala | 8 ++
6 files changed, 294 insertions(+), 23 deletions(-)
diff --git a/connector/connect/client/jvm/pom.xml b/connector/connect/client/jvm/pom.xml
index 39de7725de2..29a00a71cf5 100644
--- a/connector/connect/client/jvm/pom.xml
+++ b/connector/connect/client/jvm/pom.xml
@@ -32,6 +32,7 @@
<url>https://spark.apache.org/</url>
<properties>
<sbt.project.name>connect-client-jvm</sbt.project.name>
+ <guava.version>31.0.1-jre</guava.version>
</properties>
<dependencies>
@@ -52,6 +53,12 @@
<version>${protobuf.version}</version>
<scope>compile</scope>
</dependency>
+ <dependency>
+ <groupId>com.google.guava</groupId>
+ <artifactId>guava</artifactId>
+ <version>${guava.version}</version>
+ <scope>compile</scope>
+ </dependency>
<dependency>
<groupId>org.scalacheck</groupId>
<artifactId>scalacheck_${scala.binary.version}</artifactId>
@@ -66,7 +73,7 @@
<build>
<testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
<plugins>
- <!-- Shade all Protobuf dependencies of this build -->
+ <!-- Shade all Guava / Protobuf dependencies of this build -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-shade-plugin</artifactId>
@@ -74,6 +81,7 @@
<shadedArtifactAttached>false</shadedArtifactAttached>
<artifactSet>
<includes>
+ <include>com.google.guava:*</include>
<include>com.google.protobuf:*</include>
<include>org.apache.spark:spark-connect-common_${scala.binary.version}</include>
</includes>
@@ -86,6 +94,20 @@
<include>com.google.protobuf.**</include>
</includes>
</relocation>
+ <relocation>
+ <pattern>com.google.common</pattern>
+ <shadedPattern>${spark.shade.packageName}.connect.client.guava</shadedPattern>
+ <includes>
+ <include>com.google.common.**</include>
+ </includes>
+ </relocation>
+ <relocation>
+ <pattern>com.google.thirdparty</pattern>
+ <shadedPattern>${spark.shade.packageName}.connect.client.guava</shadedPattern>
+ <includes>
+ <include>com.google.thirdparty.**</include>
+ </includes>
+ </relocation>
</relocations>
</configuration>
</plugin>
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index e188ef0d409..cdae9f0ceea 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -17,9 +17,22 @@
package org.apache.spark.sql.connect.client
+import scala.language.existentials
+
+import io.grpc.{ManagedChannel, ManagedChannelBuilder}
+import java.net.URI
+
import org.apache.spark.connect.proto
+import org.apache.spark.sql.connect.common.config.ConnectCommon
+
+/**
+ * Conceptually the remote spark session that communicates with the server.
+ */
+class SparkConnectClient(
+ private val userContext: proto.UserContext,
+ private val channel: ManagedChannel) {
-class SparkConnectClient(private val userContext: proto.UserContext) {
+ private[this] val stub = proto.SparkConnectServiceGrpc.newBlockingStub(channel)
/**
* Placeholder method.
@@ -27,21 +40,125 @@ class SparkConnectClient(private val userContext: proto.UserContext) {
* User ID.
*/
def userId: String = userContext.getUserId()
+
+ /**
+ * Dispatch the [[proto.AnalyzePlanRequest]] to the Spark Connect server.
+ * @return
+ * A [[proto.AnalyzePlanResponse]] from the Spark Connect server.
+ */
+ def analyze(request: proto.AnalyzePlanRequest): proto.AnalyzePlanResponse =
+ stub.analyzePlan(request)
+
+ /**
+ * Shutdown the client's connection to the server.
+ */
+ def shutdown(): Unit = {
+ channel.shutdownNow()
+ }
}
object SparkConnectClient {
def builder(): Builder = new Builder()
+ /**
+ * This is a helper class that is used to create a GRPC channel based on either a set host and
+ * port or a NameResolver-compliant URI connection string.
+ */
class Builder() {
private val userContextBuilder = proto.UserContext.newBuilder()
+ private var host: String = "localhost"
+ private var port: Int = ConnectCommon.CONNECT_GRPC_BINDING_PORT
def userId(id: String): Builder = {
userContextBuilder.setUserId(id)
this
}
+ def host(inputHost: String): Builder = {
+ require(inputHost != null)
+ host = inputHost
+ this
+ }
+
+ def port(inputPort: Int): Builder = {
+ port = inputPort
+ this
+ }
+
+ object URIParams {
+ val PARAM_USER_ID = "user_id"
+ val PARAM_USE_SSL = "use_ssl"
+ val PARAM_TOKEN = "token"
+ }
+
+ private def verifyURI(uri: URI): Unit = {
+ if (uri.getScheme != "sc") {
+ throw new IllegalArgumentException("Scheme for connection URI must be 'sc'.")
+ }
+ if (uri.getHost == null) {
+ throw new IllegalArgumentException(s"Host for connection URI must be defined.")
+ }
+ // Java URI considers everything after the authority segment as "path" until the
+ // ? (query)/# (fragment) components as shown in the regex
+ // [scheme:][//authority][path][?query][#fragment].
+ // However, with the Spark Connect definition, configuration parameters are passed in the
+ // style of the HTTP URL Path Parameter Syntax (e.g
+ // sc://hostname:port/;param1=value;param2=value).
+ // Thus, we manually parse the "java path" to get the "correct path" and configuration
+ // parameters.
+ val pathAndParams = uri.getPath.split(';')
+ if (pathAndParams.nonEmpty && (pathAndParams(0) != "/" && pathAndParams(0) != "")) {
+ throw new IllegalArgumentException(
+ s"Path component for connection URI must be empty: " +
+ s"${pathAndParams(0)}")
+ }
+ }
+
+ private def parseURIParams(uri: URI): Unit = {
+ val params = uri.getPath.split(';').drop(1).filter(_ != "")
+ params.foreach { kv =>
+ val (key, value) = {
+ val arr = kv.split('=')
+ if (arr.length != 2) {
+ throw new IllegalArgumentException(
+ s"Parameter $kv is not a valid parameter" +
+ s" key-value pair")
+ }
+ (arr(0), arr(1))
+ }
+ if (key == URIParams.PARAM_USER_ID) {
+ userContextBuilder.setUserId(value)
+ } else {
+ // TODO(SPARK-41917): Support SSL and Auth tokens.
+ throw new UnsupportedOperationException(
+ "Parameters apart from user_id" +
+ " are currently unsupported.")
+ }
+ }
+ }
+
+ /**
+ * Creates the channel with a target connection string, per the documentation of Spark
+ * Connect.
+ *
+ * Note: The connection string, if used, will override any previous host/port settings.
+ */
+ def connectionString(connectionString: String): Builder = {
+ // TODO(SPARK-41917): Support SSL and Auth tokens.
+ val uri = new URI(connectionString)
+ verifyURI(uri)
+ parseURIParams(uri)
+ host = uri.getHost
+ val inputPort = uri.getPort
+ if (inputPort != -1) {
+ port = inputPort
+ }
+ this
+ }
+
def build(): SparkConnectClient = {
- new SparkConnectClient(userContextBuilder.build())
+ val channelBuilder = ManagedChannelBuilder.forAddress(host, port).usePlaintext()
+ new SparkConnectClient(userContextBuilder.build(), channelBuilder.build())
}
}
}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
index e0265bb210f..ea810a11272 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
@@ -16,17 +16,151 @@
*/
package org.apache.spark.sql.connect.client
+import java.util.concurrent.TimeUnit
+
+import io.grpc.Server
+import io.grpc.netty.NettyServerBuilder
+import io.grpc.stub.StreamObserver
+import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite
-import org.apache.spark.connect.proto
-class SparkConnectClientSuite extends AnyFunSuite { // scalastyle:ignore funsuite
+import org.apache.spark.connect.proto.{AnalyzePlanRequest, AnalyzePlanResponse, SparkConnectServiceGrpc}
+import org.apache.spark.sql.connect.common.config.ConnectCommon
+
+class SparkConnectClientSuite
+ extends AnyFunSuite // scalastyle:ignore funsuite
+ with BeforeAndAfterEach {
+
+ private var client: SparkConnectClient = _
+ private var server: Server = _
+
+ private def startDummyServer(port: Int): Unit = {
+ val sb = NettyServerBuilder
+ .forPort(port)
+ .addService(new DummySparkConnectService())
+
+ server = sb.build
+ server.start()
+ }
+ override def beforeEach(): Unit = {
+ super.beforeEach()
+ client = null
+ server = null
+ }
+
+ override def afterEach(): Unit = {
+ if (server != null) {
+ server.shutdownNow()
+ assert(server.awaitTermination(5, TimeUnit.SECONDS), "server failed to shutdown")
+ }
- private def createClient = {
- new SparkConnectClient(proto.UserContext.newBuilder().build())
+ if (client != null) {
+ client.shutdown()
+ }
}
test("Placeholder test: Create SparkConnectClient") {
- val client = SparkConnectClient.builder().userId("abc123").build()
+ client = SparkConnectClient.builder().userId("abc123").build()
assert(client.userId == "abc123")
}
+
+ private def testClientConnection(
+ client: SparkConnectClient,
+ serverPort: Int = ConnectCommon.CONNECT_GRPC_BINDING_PORT): Unit = {
+ startDummyServer(serverPort)
+ val request = AnalyzePlanRequest
+ .newBuilder()
+ .setClientId("abc123")
+ .build()
+
+ val response = client.analyze(request)
+ assert(response.getClientId === "abc123")
+ }
+
+ test("Test connection") {
+ val testPort = 16000
+ client = SparkConnectClient.builder().port(testPort).build()
+ testClientConnection(client, testPort)
+ }
+
+ test("Test connection string") {
+ val testPort = 16000
+ client = SparkConnectClient.builder().connectionString("sc://localhost:16000").build()
+ testClientConnection(client, testPort)
+ }
+
+ private case class TestPackURI(
+ connectionString: String,
+ isCorrect: Boolean,
+ extraChecks: SparkConnectClient => Unit = _ => {})
+
+ private val URIs = Seq[TestPackURI](
+ TestPackURI("sc://host", isCorrect = true),
+ TestPackURI("sc://localhost/", isCorrect = true, client => testClientConnection(client)),
+ TestPackURI(
+ "sc://localhost:123/",
+ isCorrect = true,
+ client => testClientConnection(client, 123)),
+ TestPackURI("sc://localhost/;", isCorrect = true, client => testClientConnection(client)),
+ TestPackURI("sc://host:123", isCorrect = true),
+ TestPackURI(
+ "sc://host:123/;user_id=a94",
+ isCorrect = true,
+ client => assert(client.userId == "a94")),
+ TestPackURI("scc://host:12", isCorrect = false),
+ TestPackURI("http://host", isCorrect = false),
+ TestPackURI("sc:/host:1234/path", isCorrect = false),
+ TestPackURI("sc://host/path", isCorrect = false),
+ TestPackURI("sc://host/;parm1;param2", isCorrect = false),
+ TestPackURI("sc://host:123;user_id=a94", isCorrect = false),
+ TestPackURI("sc:///user_id=123", isCorrect = false),
+ TestPackURI("sc://host:-4", isCorrect = false),
+ TestPackURI("sc://:123/", isCorrect = false))
+
+ private def checkTestPack(testPack: TestPackURI): Unit = {
+ val client = SparkConnectClient.builder().connectionString(testPack.connectionString).build()
+ testPack.extraChecks(client)
+ }
+
+ URIs.foreach { testPack =>
+ test(s"Check URI: ${testPack.connectionString}, isCorrect: ${testPack.isCorrect}") {
+ if (!testPack.isCorrect) {
+ assertThrows[IllegalArgumentException](checkTestPack(testPack))
+ } else {
+ checkTestPack(testPack)
+ }
+ }
+ }
+
+ // TODO(SPARK-41917): Remove test once SSL and Auth tokens are supported.
+ test("Non user-id parameters throw unsupported errors") {
+ assertThrows[UnsupportedOperationException] {
+ SparkConnectClient.builder().connectionString("sc://host/;use_ssl=true").build()
+ }
+
+ assertThrows[UnsupportedOperationException] {
+ SparkConnectClient.builder().connectionString("sc://host/;token=abc").build()
+ }
+
+ assertThrows[UnsupportedOperationException] {
+ SparkConnectClient.builder().connectionString("sc://host/;xyz=abc").build()
+
+ }
+ }
+}
+
+class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectServiceImplBase {
+
+ override def analyzePlan(
+ request: AnalyzePlanRequest,
+ responseObserver: StreamObserver[AnalyzePlanResponse]): Unit = {
+ // Reply with a dummy response using the same client ID
+ val requestClientId = request.getClientId
+ val response = AnalyzePlanResponse
+ .newBuilder()
+ .setClientId(requestClientId)
+ .build()
+ responseObserver.onNext(response)
+ responseObserver.onCompleted()
+ }
}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala
similarity index 61%
copy from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
copy to connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala
index e0265bb210f..48ae4d2d77f 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala
@@ -14,19 +14,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.spark.sql.connect.client
+package org.apache.spark.sql.connect.common.config
-import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite
-import org.apache.spark.connect.proto
-
-class SparkConnectClientSuite extends AnyFunSuite { // scalastyle:ignore funsuite
-
- private def createClient = {
- new SparkConnectClient(proto.UserContext.newBuilder().build())
- }
-
- test("Placeholder test: Create SparkConnectClient") {
- val client = SparkConnectClient.builder().userId("abc123").build()
- assert(client.userId == "abc123")
- }
+private[connect] object ConnectCommon {
+ val CONNECT_GRPC_BINDING_PORT: Int = 15002
}
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
index 3f255378b97..b38f0f6f6d1 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.connect.config
import org.apache.spark.internal.config.ConfigBuilder
import org.apache.spark.network.util.ByteUnit
+import org.apache.spark.sql.connect.common.config.ConnectCommon
private[spark] object Connect {
@@ -25,7 +26,7 @@ private[spark] object Connect {
ConfigBuilder("spark.connect.grpc.binding.port")
.version("3.4.0")
.intConf
- .createWithDefault(15002)
+ .createWithDefault(ConnectCommon.CONNECT_GRPC_BINDING_PORT)
val CONNECT_GRPC_INTERCEPTOR_CLASSES =
ConfigBuilder("spark.connect.grpc.interceptor.classes")
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 9f427be1207..70bdfc9d8bb 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -834,13 +834,19 @@ object SparkConnectClient {
// For some reason the resolution from the imported Maven build does not work for some
// of these dependendencies that we need to shade later on.
libraryDependencies ++= {
+ val guavaVersion =
+ SbtPomKeys.effectivePom.value.getProperties.get("guava.version").asInstanceOf[String]
Seq(
+ "com.google.guava" % "guava" % guavaVersion,
"com.google.protobuf" % "protobuf-java" % protoVersion % "protobuf"
)
},
dependencyOverrides ++= {
+ val guavaVersion =
+ SbtPomKeys.effectivePom.value.getProperties.get("guava.version").asInstanceOf[String]
Seq(
+ "com.google.guava" % "guava" % guavaVersion,
"com.google.protobuf" % "protobuf-java" % protoVersion
)
},
@@ -865,6 +871,8 @@ object SparkConnectClient {
(assembly / assemblyShadeRules) := Seq(
ShadeRule.rename("com.google.protobuf.**" -> "org.sparkproject.connect.protobuf.@1").inAll,
+ ShadeRule.rename("com.google.common.**" -> "org.sparkproject.connect.client.guava.@1").inAll,
+ ShadeRule.rename("com.google.thirdparty.**" -> "org.sparkproject.connect.client.guava.@1").inAll,
),
(assembly / assemblyMergeStrategy) := {