This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 871ca429ac0 [SPARK-41822][CONNECT] Setup gRPC connection for Scala/JVM client
871ca429ac0 is described below

commit 871ca429ac02fb65d9c20fb11621641c0c28e26a
Author: vicennial <[email protected]>
AuthorDate: Tue Jan 10 22:43:03 2023 +0800

    [SPARK-41822][CONNECT] Setup gRPC connection for Scala/JVM client
    
    ### What changes were proposed in this pull request?
    
    Set up the gRPC connection logic for the Scala/JVM client and create "common configs" to share settings between server/client.
    
    ### Why are the changes needed?
    
    Enables the client to communicate with the Spark Connect server.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Simple unit test.
    
    Closes #39361 from vicennial/setupClientCon.
    
    Authored-by: vicennial <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 connector/connect/client/jvm/pom.xml               |  24 +++-
 .../sql/connect/client/SparkConnectClient.scala    | 121 ++++++++++++++++-
 .../connect/client/SparkConnectClientSuite.scala   | 144 ++++++++++++++++++++-
 .../sql/connect/common/config/ConnectCommon.scala} |  17 +--
 .../apache/spark/sql/connect/config/Connect.scala  |   3 +-
 project/SparkBuild.scala                           |   8 ++
 6 files changed, 294 insertions(+), 23 deletions(-)

diff --git a/connector/connect/client/jvm/pom.xml b/connector/connect/client/jvm/pom.xml
index 39de7725de2..29a00a71cf5 100644
--- a/connector/connect/client/jvm/pom.xml
+++ b/connector/connect/client/jvm/pom.xml
@@ -32,6 +32,7 @@
   <url>https://spark.apache.org/</url>
   <properties>
     <sbt.project.name>connect-client-jvm</sbt.project.name>
+    <guava.version>31.0.1-jre</guava.version>
   </properties>
 
   <dependencies>
@@ -52,6 +53,12 @@
       <version>${protobuf.version}</version>
       <scope>compile</scope>
     </dependency>
+    <dependency>
+      <groupId>com.google.guava</groupId>
+      <artifactId>guava</artifactId>
+      <version>${guava.version}</version>
+      <scope>compile</scope>
+    </dependency>
     <dependency>
       <groupId>org.scalacheck</groupId>
       <artifactId>scalacheck_${scala.binary.version}</artifactId>
@@ -66,7 +73,7 @@
   <build>
     <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
     <plugins>
-      <!-- Shade all Protobuf dependencies of this build -->
+      <!-- Shade all Guava / Protobuf dependencies of this build -->
       <plugin>
         <groupId>org.apache.maven.plugins</groupId>
         <artifactId>maven-shade-plugin</artifactId>
@@ -74,6 +81,7 @@
           <shadedArtifactAttached>false</shadedArtifactAttached>
           <artifactSet>
             <includes>
+              <include>com.google.guava:*</include>
               <include>com.google.protobuf:*</include>
               <include>org.apache.spark:spark-connect-common_${scala.binary.version}</include>
             </includes>
@@ -86,6 +94,20 @@
                 <include>com.google.protobuf.**</include>
               </includes>
             </relocation>
+            <relocation>
+              <pattern>com.google.common</pattern>
+              <shadedPattern>${spark.shade.packageName}.connect.client.guava</shadedPattern>
+              <includes>
+                <include>com.google.common.**</include>
+              </includes>
+            </relocation>
+            <relocation>
+              <pattern>com.google.thirdparty</pattern>
+              <shadedPattern>${spark.shade.packageName}.connect.client.guava</shadedPattern>
+              <includes>
+                <include>com.google.thirdparty.**</include>
+              </includes>
+            </relocation>
           </relocations>
         </configuration>
       </plugin>
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index e188ef0d409..cdae9f0ceea 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -17,9 +17,22 @@
 
 package org.apache.spark.sql.connect.client
 
+import scala.language.existentials
+
+import io.grpc.{ManagedChannel, ManagedChannelBuilder}
+import java.net.URI
+
 import org.apache.spark.connect.proto
+import org.apache.spark.sql.connect.common.config.ConnectCommon
+
+/**
+ * Conceptually the remote spark session that communicates with the server.
+ */
+class SparkConnectClient(
+    private val userContext: proto.UserContext,
+    private val channel: ManagedChannel) {
 
-class SparkConnectClient(private val userContext: proto.UserContext) {
+  private[this] val stub = proto.SparkConnectServiceGrpc.newBlockingStub(channel)
 
   /**
    * Placeholder method.
@@ -27,21 +40,125 @@ class SparkConnectClient(private val userContext: proto.UserContext) {
    *   User ID.
    */
   def userId: String = userContext.getUserId()
+
+  /**
+   * Dispatch the [[proto.AnalyzePlanRequest]] to the Spark Connect server.
+   * @return
+   *   A [[proto.AnalyzePlanResponse]] from the Spark Connect server.
+   */
+  def analyze(request: proto.AnalyzePlanRequest): proto.AnalyzePlanResponse =
+    stub.analyzePlan(request)
+
+  /**
+   * Shutdown the client's connection to the server.
+   */
+  def shutdown(): Unit = {
+    channel.shutdownNow()
+  }
 }
 
 object SparkConnectClient {
   def builder(): Builder = new Builder()
 
+  /**
+   * This is a helper class that is used to create a GRPC channel based on either a set host and
+   * port or a NameResolver-compliant URI connection string.
+   */
   class Builder() {
     private val userContextBuilder = proto.UserContext.newBuilder()
+    private var host: String = "localhost"
+    private var port: Int = ConnectCommon.CONNECT_GRPC_BINDING_PORT
 
     def userId(id: String): Builder = {
       userContextBuilder.setUserId(id)
       this
     }
 
+    def host(inputHost: String): Builder = {
+      require(inputHost != null)
+      host = inputHost
+      this
+    }
+
+    def port(inputPort: Int): Builder = {
+      port = inputPort
+      this
+    }
+
+    object URIParams {
+      val PARAM_USER_ID = "user_id"
+      val PARAM_USE_SSL = "use_ssl"
+      val PARAM_TOKEN = "token"
+    }
+
+    private def verifyURI(uri: URI): Unit = {
+      if (uri.getScheme != "sc") {
+        throw new IllegalArgumentException("Scheme for connection URI must be 
'sc'.")
+      }
+      if (uri.getHost == null) {
+        throw new IllegalArgumentException(s"Host for connection URI must be defined.")
+      }
+      // Java URI considers everything after the authority segment as "path" until the
+      // ? (query)/# (fragment) components as shown in the regex
+      // [scheme:][//authority][path][?query][#fragment].
+      // However, with the Spark Connect definition, configuration parameter are passed in the
+      // style of the HTTP URL Path Parameter Syntax (e.g
+      // sc://hostname:port/;param1=value;param2=value).
+      // Thus, we manually parse the "java path" to get the "correct path" and configuration
+      // parameters.
+      val pathAndParams = uri.getPath.split(';')
+      if (pathAndParams.nonEmpty && (pathAndParams(0) != "/" && pathAndParams(0) != "")) {
+        throw new IllegalArgumentException(
+          s"Path component for connection URI must be empty: " +
+            s"${pathAndParams(0)}")
+      }
+    }
+
+    private def parseURIParams(uri: URI): Unit = {
+      val params = uri.getPath.split(';').drop(1).filter(_ != "")
+      params.foreach { kv =>
+        val (key, value) = {
+          val arr = kv.split('=')
+          if (arr.length != 2) {
+            throw new IllegalArgumentException(
+              s"Parameter $kv is not a valid parameter" +
+                s" key-value pair")
+          }
+          (arr(0), arr(1))
+        }
+        if (key == URIParams.PARAM_USER_ID) {
+          userContextBuilder.setUserId(value)
+        } else {
+          // TODO(SPARK-41917): Support SSL and Auth tokens.
+          throw new UnsupportedOperationException(
+            "Parameters apart from user_id" +
+              " are currently unsupported.")
+        }
+      }
+    }
+
+    /**
+     * Creates the channel with a target connection string, per the documentation of Spark
+     * Connect.
+     *
+     * Note: The connection string, if used, will override any previous host/port settings.
+     */
+    def connectionString(connectionString: String): Builder = {
+      // TODO(SPARK-41917): Support SSL and Auth tokens.
+      val uri = new URI(connectionString)
+      verifyURI(uri)
+      parseURIParams(uri)
+      host = uri.getHost
+      val inputPort = uri.getPort
+      if (inputPort != -1) {
+        port = inputPort
+      }
+      this
+    }
+
     def build(): SparkConnectClient = {
-      new SparkConnectClient(userContextBuilder.build())
+      val channelBuilder = ManagedChannelBuilder.forAddress(host, port).usePlaintext()
+      new SparkConnectClient(userContextBuilder.build(), channelBuilder.build())
     }
   }
 }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
index e0265bb210f..ea810a11272 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
@@ -16,17 +16,151 @@
  */
 package org.apache.spark.sql.connect.client
 
+import java.util.concurrent.TimeUnit
+
+import io.grpc.Server
+import io.grpc.netty.NettyServerBuilder
+import io.grpc.stub.StreamObserver
+import org.scalatest.BeforeAndAfterEach
 import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite
-import org.apache.spark.connect.proto
 
-class SparkConnectClientSuite extends AnyFunSuite { // scalastyle:ignore funsuite
+import org.apache.spark.connect.proto.{AnalyzePlanRequest, AnalyzePlanResponse, SparkConnectServiceGrpc}
+import org.apache.spark.sql.connect.common.config.ConnectCommon
+
+class SparkConnectClientSuite
+    extends AnyFunSuite // scalastyle:ignore funsuite
+    with BeforeAndAfterEach {
+
+  private var client: SparkConnectClient = _
+  private var server: Server = _
+
+  private def startDummyServer(port: Int): Unit = {
+    val sb = NettyServerBuilder
+      .forPort(port)
+      .addService(new DummySparkConnectService())
+
+    server = sb.build
+    server.start()
+  }
+  override def beforeEach(): Unit = {
+    super.beforeEach()
+    client = null
+    server = null
+  }
+
+  override def afterEach(): Unit = {
+    if (server != null) {
+      server.shutdownNow()
+      assert(server.awaitTermination(5, TimeUnit.SECONDS), "server failed to shutdown")
+    }
 
-  private def createClient = {
-    new SparkConnectClient(proto.UserContext.newBuilder().build())
+    if (client != null) {
+      client.shutdown()
+    }
   }
 
   test("Placeholder test: Create SparkConnectClient") {
-    val client = SparkConnectClient.builder().userId("abc123").build()
+    client = SparkConnectClient.builder().userId("abc123").build()
     assert(client.userId == "abc123")
   }
+
+  private def testClientConnection(
+      client: SparkConnectClient,
+      serverPort: Int = ConnectCommon.CONNECT_GRPC_BINDING_PORT): Unit = {
+    startDummyServer(serverPort)
+    val request = AnalyzePlanRequest
+      .newBuilder()
+      .setClientId("abc123")
+      .build()
+
+    val response = client.analyze(request)
+    assert(response.getClientId === "abc123")
+  }
+
+  test("Test connection") {
+    val testPort = 16000
+    client = SparkConnectClient.builder().port(testPort).build()
+    testClientConnection(client, testPort)
+  }
+
+  test("Test connection string") {
+    val testPort = 16000
+    client = SparkConnectClient.builder().connectionString("sc://localhost:16000").build()
+    testClientConnection(client, testPort)
+  }
+
+  private case class TestPackURI(
+      connectionString: String,
+      isCorrect: Boolean,
+      extraChecks: SparkConnectClient => Unit = _ => {})
+
+  private val URIs = Seq[TestPackURI](
+    TestPackURI("sc://host", isCorrect = true),
+    TestPackURI("sc://localhost/", isCorrect = true, client => 
testClientConnection(client)),
+    TestPackURI(
+      "sc://localhost:123/",
+      isCorrect = true,
+      client => testClientConnection(client, 123)),
+    TestPackURI("sc://localhost/;", isCorrect = true, client => 
testClientConnection(client)),
+    TestPackURI("sc://host:123", isCorrect = true),
+    TestPackURI(
+      "sc://host:123/;user_id=a94",
+      isCorrect = true,
+      client => assert(client.userId == "a94")),
+    TestPackURI("scc://host:12", isCorrect = false),
+    TestPackURI("http://host";, isCorrect = false),
+    TestPackURI("sc:/host:1234/path", isCorrect = false),
+    TestPackURI("sc://host/path", isCorrect = false),
+    TestPackURI("sc://host/;parm1;param2", isCorrect = false),
+    TestPackURI("sc://host:123;user_id=a94", isCorrect = false),
+    TestPackURI("sc:///user_id=123", isCorrect = false),
+    TestPackURI("sc://host:-4", isCorrect = false),
+    TestPackURI("sc://:123/", isCorrect = false))
+
+  private def checkTestPack(testPack: TestPackURI): Unit = {
+    val client = SparkConnectClient.builder().connectionString(testPack.connectionString).build()
+    testPack.extraChecks(client)
+  }
+
+  URIs.foreach { testPack =>
+    test(s"Check URI: ${testPack.connectionString}, isCorrect: 
${testPack.isCorrect}") {
+      if (!testPack.isCorrect) {
+        assertThrows[IllegalArgumentException](checkTestPack(testPack))
+      } else {
+        checkTestPack(testPack)
+      }
+    }
+  }
+
+  // TODO(SPARK-41917): Remove test once SSL and Auth tokens are supported.
+  test("Non user-id parameters throw unsupported errors") {
+    assertThrows[UnsupportedOperationException] {
+      SparkConnectClient.builder().connectionString("sc://host/;use_ssl=true").build()
+    }
+
+    assertThrows[UnsupportedOperationException] {
+      SparkConnectClient.builder().connectionString("sc://host/;token=abc").build()
+    }
+
+    assertThrows[UnsupportedOperationException] {
+      SparkConnectClient.builder().connectionString("sc://host/;xyz=abc").build()
+
+    }
+  }
+}
+
+class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectServiceImplBase {
+
+  override def analyzePlan(
+      request: AnalyzePlanRequest,
+      responseObserver: StreamObserver[AnalyzePlanResponse]): Unit = {
+    // Reply with a dummy response using the same client ID
+    val requestClientId = request.getClientId
+    val response = AnalyzePlanResponse
+      .newBuilder()
+      .setClientId(requestClientId)
+      .build()
+    responseObserver.onNext(response)
+    responseObserver.onCompleted()
+  }
 }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala
similarity index 61%
copy from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
copy to connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala
index e0265bb210f..48ae4d2d77f 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala
@@ -14,19 +14,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.spark.sql.connect.client
+package org.apache.spark.sql.connect.common.config
 
-import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite
-import org.apache.spark.connect.proto
-
-class SparkConnectClientSuite extends AnyFunSuite { // scalastyle:ignore funsuite
-
-  private def createClient = {
-    new SparkConnectClient(proto.UserContext.newBuilder().build())
-  }
-
-  test("Placeholder test: Create SparkConnectClient") {
-    val client = SparkConnectClient.builder().userId("abc123").build()
-    assert(client.userId == "abc123")
-  }
+private[connect] object ConnectCommon {
+  val CONNECT_GRPC_BINDING_PORT: Int = 15002
 }
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
index 3f255378b97..b38f0f6f6d1 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.connect.config
 
 import org.apache.spark.internal.config.ConfigBuilder
 import org.apache.spark.network.util.ByteUnit
+import org.apache.spark.sql.connect.common.config.ConnectCommon
 
 private[spark] object Connect {
 
@@ -25,7 +26,7 @@ private[spark] object Connect {
     ConfigBuilder("spark.connect.grpc.binding.port")
       .version("3.4.0")
       .intConf
-      .createWithDefault(15002)
+      .createWithDefault(ConnectCommon.CONNECT_GRPC_BINDING_PORT)
 
   val CONNECT_GRPC_INTERCEPTOR_CLASSES =
     ConfigBuilder("spark.connect.grpc.interceptor.classes")
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 9f427be1207..70bdfc9d8bb 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -834,13 +834,19 @@ object SparkConnectClient {
    // For some reason the resolution from the imported Maven build does not work for some
     // of these dependendencies that we need to shade later on.
     libraryDependencies ++= {
+      val guavaVersion =
+        SbtPomKeys.effectivePom.value.getProperties.get("guava.version").asInstanceOf[String]
       Seq(
+        "com.google.guava" % "guava" % guavaVersion,
         "com.google.protobuf" % "protobuf-java" % protoVersion % "protobuf"
       )
     },
 
     dependencyOverrides ++= {
+      val guavaVersion =
+        SbtPomKeys.effectivePom.value.getProperties.get("guava.version").asInstanceOf[String]
       Seq(
+        "com.google.guava" % "guava" % guavaVersion,
         "com.google.protobuf" % "protobuf-java" % protoVersion
       )
     },
@@ -865,6 +871,8 @@ object SparkConnectClient {
 
     (assembly / assemblyShadeRules) := Seq(
       ShadeRule.rename("com.google.protobuf.**" -> 
"org.sparkproject.connect.protobuf.@1").inAll,
+      ShadeRule.rename("com.google.common.**" -> 
"org.sparkproject.connect.client.guava.@1").inAll,
+      ShadeRule.rename("com.google.thirdparty.**" -> 
"org.sparkproject.connect.client.guava.@1").inAll,
     ),
 
     (assembly / assemblyMergeStrategy) := {

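For reference, a standalone sketch of the connection-string handling introduced in
SparkConnectClient.Builder above (illustrative values; a sketch of the approach, not
the verbatim implementation):

    import java.net.URI

    // java.net.URI treats everything after the authority as "path", so the
    // ;key=value parameters are recovered by splitting the path on ';'.
    val uri = new URI("sc://localhost:15002/;user_id=example_user")
    require(uri.getScheme == "sc", "Scheme for connection URI must be 'sc'.")
    require(uri.getHost != null, "Host for connection URI must be defined.")

    val segments = uri.getPath.split(';')
    // segments(0) is the real path and must be empty or "/".
    require(segments.isEmpty || segments(0) == "/" || segments(0) == "")
    val params = segments.drop(1).filter(_.nonEmpty).map { kv =>
      val Array(key, value) = kv.split('=')
      key -> value
    }.toMap

    assert(uri.getHost == "localhost")
    assert(uri.getPort == 15002)
    assert(params("user_id") == "example_user")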

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
