This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 934d14d73c1 [SPARK-42133] Add basic Dataset API methods to Spark Connect Scala Client
934d14d73c1 is described below
commit 934d14d73c14406242ccd11ba67e64fde1f3b955
Author: vicennial <[email protected]>
AuthorDate: Mon Jan 23 22:10:02 2023 -0400
[SPARK-42133] Add basic Dataset API methods to Spark Connect Scala Client
### What changes were proposed in this pull request?
Adds the following methods:
- Dataset API methods
- project
- filter
- limit
- SparkSession
- range (and its variations)
This PR also introduces `Column` and `functions` to support the above changes.
### Why are the changes needed?
Incremental development of Spark Connect Scala Client.
### Does this PR introduce _any_ user-facing change?
Yes, users may now use the proposed API methods.
Example: `val df = sparkSession.range(5).limit(3)`
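For illustration, a minimal sketch of how the added pieces compose (not part of the patch; it assumes a `SparkSession` named `spark` that is already connected to a Spark Connect server, and the column/alias names are invented):
```scala
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.lit

// range(start, end, step, numPartitions) returns a Dataset with a single LongType column "id".
val ds = spark.range(0, 100, 5, 2)

// Column expressions are built client-side as proto.Expression trees:
// Column("id") is an unresolved attribute, lit(10) a literal, and `+`
// composes them through an unresolved "+" function.
val shifted: Column = Column("id") + lit(10)

// select and limit only extend the relational plan; nothing executes until
// collectResult() sends the plan to the server.
val result = ds
  .select(shifted.name("shifted_id"))
  .limit(3)
  .collectResult()
```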
### How was this patch tested?
Unit tests + simple E2E test.
Closes #39672 from vicennial/SPARK-42133.
Authored-by: vicennial <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
---
.../main/scala/org/apache/spark/sql/Column.scala | 107 ++++++++++++++++++
.../main/scala/org/apache/spark/sql/Dataset.scala | 50 +++++++++
.../scala/org/apache/spark/sql/SparkSession.scala | 51 +++++++++
.../sql/connect/client/SparkConnectClient.scala | 22 ++++
.../client/package.scala} | 14 ++-
.../scala/org/apache/spark/sql/functions.scala | 83 ++++++++++++++
.../org/apache/spark/sql/ClientE2ETestSuite.scala | 10 ++
.../scala/org/apache/spark/sql/DatasetSuite.scala | 114 ++++++++++++++++++++
.../org/apache/spark/sql/SparkSessionSuite.scala | 120 +++++++++++++++++++++
.../connect/client/SparkConnectClientSuite.scala | 10 ++
10 files changed, 576 insertions(+), 5 deletions(-)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala
new file mode 100644
index 00000000000..f25d579d5c3
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.Column.fn
+import org.apache.spark.sql.connect.client.unsupported
+import org.apache.spark.sql.functions.lit
+
+/**
+ * A column that will be computed based on the data in a `DataFrame`.
+ *
+ * A new column can be constructed based on the input columns present in a DataFrame:
+ *
+ * {{{
+ * df("columnName") // On a specific `df` DataFrame.
+ * col("columnName") // A generic column not yet associated with a
DataFrame.
+ * col("columnName.field") // Extracting a struct field
+ * col("`a.column.with.dots`") // Escape `.` in column names.
+ * $"columnName" // Scala short hand for a named column.
+ * }}}
+ *
+ * [[Column]] objects can be composed to form complex expressions:
+ *
+ * {{{
+ * $"a" + 1
+ * }}}
+ *
+ * @since 3.4.0
+ */
+class Column private[sql] (private[sql] val expr: proto.Expression) {
+
+ /**
+ * Sum of this expression and another expression.
+ * {{{
+ * // Scala: The following selects the sum of a person's height and weight.
+ * people.select( people("height") + people("weight") )
+ *
+ * // Java:
+ * people.select( people.col("height").plus(people.col("weight")) );
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def +(other: Any): Column = fn("+", this, lit(other))
+
+ /**
+ * Gives the column a name (alias).
+ * {{{
+ * // Renames colA to colB in select output.
+ * df.select($"colA".name("colB"))
+ * }}}
+ *
+ * If the current column has metadata associated with it, this metadata will be propagated to
+ * the new column. If this is not desired, use the API `as(alias: String, metadata: Metadata)` with
+ * explicit metadata.
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def name(alias: String): Column = Column { builder =>
+ builder.getAliasBuilder.addName(alias).setExpr(expr)
+ }
+}
+
+object Column {
+
+ def apply(name: String): Column = Column { builder =>
+ name match {
+ case "*" =>
+ builder.getUnresolvedStarBuilder
+ case _ if name.endsWith(".*") =>
+ unsupported("* with prefix is not supported yet.")
+ case _ =>
+ builder.getUnresolvedAttributeBuilder.setUnparsedIdentifier(name)
+ }
+ }
+
+ private[sql] def apply(f: proto.Expression.Builder => Unit): Column = {
+ val builder = proto.Expression.newBuilder()
+ f(builder)
+ new Column(builder.build())
+ }
+
+ private[sql] def fn(name: String, inputs: Column*): Column = Column { builder =>
+ builder.getUnresolvedFunctionBuilder
+ .setFunctionName(name)
+ .addAllArguments(inputs.map(_.expr).asJava)
+ }
+}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index f7ed764a11e..6891b2f5bed 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -16,9 +16,59 @@
*/
package org.apache.spark.sql
+import scala.collection.JavaConverters._
+
import org.apache.spark.connect.proto
import org.apache.spark.sql.connect.client.SparkResult
class Dataset(val session: SparkSession, private[sql] val plan: proto.Plan) {
+
+ /**
+ * Selects a set of column-based expressions.
+ * {{{
+ * ds.select($"colA", $"colB" + 1)
+ * }}}
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def select(cols: Column*): Dataset = session.newDataset { builder =>
+ builder.getProjectBuilder
+ .setInput(plan.getRoot)
+ .addAllExpressions(cols.map(_.expr).asJava)
+ }
+
+ /**
+ * Filters rows using the given condition.
+ * {{{
+ * // The following are equivalent:
+ * peopleDs.filter($"age" > 15)
+ * peopleDs.where($"age" > 15)
+ * }}}
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ def filter(condition: Column): Dataset = session.newDataset { builder =>
+ builder.getFilterBuilder.setInput(plan.getRoot).setCondition(condition.expr)
+ }
+
+ /**
+ * Returns a new Dataset by taking the first `n` rows. The difference between this function and
+ * `head` is that `head` is an action and returns an array (by triggering query execution) while
+ * `limit` returns a new Dataset.
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ def limit(n: Int): Dataset = session.newDataset { builder =>
+ builder.getLimitBuilder
+ .setInput(plan.getRoot)
+ .setLimit(n)
+ }
+
+ private[sql] def analyze: proto.AnalyzePlanResponse = session.analyze(plan)
+
def collectResult(): SparkResult = session.execute(plan)
}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 21f4ebd75db..0c4f702ca34 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -57,6 +57,54 @@ class SparkSession(private val client: SparkConnectClient, private val cleaner:
builder.setSql(proto.SQL.newBuilder().setQuery(query))
}
+ /**
+ * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a
+ * range from 0 to `end` (exclusive) with step value 1.
+ *
+ * @since 3.4.0
+ */
+ def range(end: Long): Dataset = range(0, end)
+
+ /**
+ * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a
+ * range from `start` to `end` (exclusive) with step value 1.
+ *
+ * @since 3.4.0
+ */
+ def range(start: Long, end: Long): Dataset = {
+ range(start, end, step = 1)
+ }
+
+ /**
+ * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a
+ * range from `start` to `end` (exclusive) with a step value.
+ *
+ * @since 3.4.0
+ */
+ def range(start: Long, end: Long, step: Long): Dataset = {
+ range(start, end, step, None)
+ }
+
+ /**
+ * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a
+ * range from `start` to `end` (exclusive) with a step value, with partition number specified.
+ *
+ * @since 3.4.0
+ */
+ def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset = {
+ range(start, end, step, Option(numPartitions))
+ }
+
+ private def range(start: Long, end: Long, step: Long, numPartitions: Option[Int]): Dataset = {
+ newDataset { builder =>
+ val rangeBuilder = builder.getRangeBuilder
+ .setStart(start)
+ .setEnd(end)
+ .setStep(step)
+ numPartitions.foreach(rangeBuilder.setNumPartitions)
+ }
+ }
+
private[sql] def newDataset(f: proto.Relation.Builder => Unit): Dataset = {
val builder = proto.Relation.newBuilder()
f(builder)
@@ -64,6 +112,9 @@ class SparkSession(private val client: SparkConnectClient, private val cleaner:
new Dataset(this, plan)
}
+ private[sql] def analyze(plan: proto.Plan): proto.AnalyzePlanResponse =
+ client.analyze(plan)
+
private[sql] def execute(plan: proto.Plan): SparkResult = {
val value = client.execute(plan)
val result = new SparkResult(value, allocator)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index 8ad84631531..8252f8aef76 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -21,6 +21,7 @@ import scala.language.existentials
import io.grpc.{ManagedChannel, ManagedChannelBuilder}
import java.net.URI
+import java.util.UUID
import org.apache.spark.connect.proto
import org.apache.spark.sql.connect.common.config.ConnectCommon
@@ -41,6 +42,11 @@ class SparkConnectClient(
*/
def userId: String = userContext.getUserId()
+ // Generate a unique session ID for this client. This UUID must be unique to allow
+ // concurrent Spark sessions of the same user. If the channel is closed, creating
+ // a new client will create a new session ID.
+ private[client] val sessionId: String = UUID.randomUUID.toString
+
/**
* Dispatch the [[proto.AnalyzePlanRequest]] to the Spark Connect server.
* @return
@@ -58,6 +64,22 @@ class SparkConnectClient(
stub.executePlan(request)
}
+ /**
+ * Builds a [[proto.AnalyzePlanRequest]] from `plan` and dispatches it to the Spark Connect
+ * server.
+ * @return
+ * A [[proto.AnalyzePlanResponse]] from the Spark Connect server.
+ */
+ def analyze(plan: proto.Plan): proto.AnalyzePlanResponse = {
+ val request = proto.AnalyzePlanRequest
+ .newBuilder()
+ .setPlan(plan)
+ .setUserContext(userContext)
+ .setClientId(sessionId)
+ .build()
+ analyze(request)
+ }
+
/**
* Shutdown the client's connection to the server.
*/
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/package.scala
similarity index 74%
copy from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
copy to connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/package.scala
index f7ed764a11e..9c173076ab8 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/package.scala
@@ -14,11 +14,15 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.spark.sql
+package org.apache.spark.sql.connect
-import org.apache.spark.connect.proto
-import org.apache.spark.sql.connect.client.SparkResult
+package object client {
-class Dataset(val session: SparkSession, private[sql] val plan: proto.Plan) {
- def collectResult(): SparkResult = session.execute(plan)
+ private[sql] def unsupported(): Nothing = {
+ throw new UnsupportedOperationException
+ }
+
+ private[sql] def unsupported(message: String): Nothing = {
+ throw new UnsupportedOperationException(message)
+ }
}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
new file mode 100644
index 00000000000..bae394785be
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import java.math.{BigDecimal => JBigDecimal}
+import java.time.LocalDate
+
+import com.google.protobuf.ByteString
+
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.connect.client.unsupported
+
+/**
+ * Commonly used functions available for DataFrame operations.
+ *
+ * @since 3.4.0
+ */
+// scalastyle:off
+object functions {
+// scalastyle:on
+
+ private def createLiteral(f: proto.Expression.Literal.Builder => Unit): Column = Column {
+ builder =>
+ val literalBuilder = proto.Expression.Literal.newBuilder()
+ f(literalBuilder)
+ builder.setLiteral(literalBuilder)
+ }
+
+ private def createDecimalLiteral(precision: Int, scale: Int, value: String): Column =
+ createLiteral { builder =>
+ builder.getDecimalBuilder
+ .setPrecision(precision)
+ .setScale(scale)
+ .setValue(value)
+ }
+
+ /**
+ * Creates a [[Column]] of literal value.
+ *
+ * The passed in object is returned directly if it is already a [[Column]]. If the object is a
+ * Scala Symbol, it is converted into a [[Column]] also. Otherwise, a new [[Column]] is created
+ * to represent the literal value.
+ *
+ * @since 3.4.0
+ */
+ def lit(literal: Any): Column = {
+ literal match {
+ case c: Column => c
+ case s: Symbol => Column(s.name)
+ case v: Boolean => createLiteral(_.setBoolean(v))
+ case v: Byte => createLiteral(_.setByte(v))
+ case v: Short => createLiteral(_.setShort(v))
+ case v: Int => createLiteral(_.setInteger(v))
+ case v: Long => createLiteral(_.setLong(v))
+ case v: Float => createLiteral(_.setFloat(v))
+ case v: Double => createLiteral(_.setDouble(v))
+ case v: BigDecimal => createDecimalLiteral(v.precision, v.scale, v.toString)
+ case v: JBigDecimal => createDecimalLiteral(v.precision, v.scale, v.toString)
+ case v: String => createLiteral(_.setString(v))
+ case v: Char => createLiteral(_.setString(v.toString))
+ case v: Array[Char] => createLiteral(_.setString(String.valueOf(v)))
+ case v: Array[Byte] => createLiteral(_.setBinary(ByteString.copyFrom(v)))
+ case v: collection.mutable.WrappedArray[_] => lit(v.array)
+ case v: LocalDate => createLiteral(_.setDate(v.toEpochDay.toInt))
+ case null => unsupported("Null literals not supported yet.")
+ case _ => unsupported(s"literal $literal not supported (yet).")
+ }
+ }
+}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index 1427dc49f86..e31f121ca10 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -38,6 +38,16 @@ class ClientE2ETestSuite extends RemoteSparkSession {
assert(array(1).getString(0) == "World")
}
+ test("simple dataset test") {
+ val df = spark.range(10).limit(3)
+ val result = df.collectResult()
+ assert(result.length == 3)
+ val array = result.toArray
+ assert(array(0).getLong(0) == 0)
+ assert(array(1).getLong(0) == 1)
+ assert(array(2).getLong(0) == 2)
+ }
+
// TODO test large result when we can create table or view
// test("test spark large result")
}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
new file mode 100644
index 00000000000..5e7cc9f4b6d
--- /dev/null
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import scala.collection.JavaConverters._
+
+import io.grpc.Server
+import io.grpc.netty.NettyServerBuilder
+import java.util.concurrent.TimeUnit
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite
+
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.connect.client.{DummySparkConnectService, SparkConnectClient}
+
+class DatasetSuite
+ extends AnyFunSuite // scalastyle:ignore funsuite
+ with BeforeAndAfterEach {
+
+ private var server: Server = _
+ private var service: DummySparkConnectService = _
+ private var ss: SparkSession = _
+
+ private def getNewSparkSession(port: Int): SparkSession = {
+ assert(port != 0)
+ SparkSession
+ .builder()
+ .client(
+ SparkConnectClient
+ .builder()
+ .connectionString(s"sc://localhost:$port")
+ .build())
+ .build()
+ }
+
+ private def startDummyServer(): Unit = {
+ service = new DummySparkConnectService()
+ val sb = NettyServerBuilder
+ // Let server bind to any free port
+ .forPort(0)
+ .addService(service)
+
+ server = sb.build
+ server.start()
+ }
+
+ override def beforeEach(): Unit = {
+ super.beforeEach()
+ startDummyServer()
+ ss = getNewSparkSession(server.getPort)
+ }
+
+ override def afterEach(): Unit = {
+ if (server != null) {
+ server.shutdownNow()
+ assert(server.awaitTermination(5, TimeUnit.SECONDS), "server failed to shutdown")
+ }
+ }
+
+ test("limit") {
+ val df = ss.newDataset(_ => ())
+ val builder = proto.Relation.newBuilder()
+ builder.getLimitBuilder.setInput(df.plan.getRoot).setLimit(10)
+
+ val expectedPlan = proto.Plan.newBuilder().setRoot(builder).build()
+ df.limit(10).analyze
+ val actualPlan = service.getAndClearLatestInputPlan()
+ assert(actualPlan.equals(expectedPlan))
+ }
+
+ test("select") {
+ val df = ss.newDataset(_ => ())
+
+ val builder = proto.Relation.newBuilder()
+ val dummyCols = Seq[Column](Column("a"), Column("b"))
+ builder.getProjectBuilder
+ .setInput(df.plan.getRoot)
+ .addAllExpressions(dummyCols.map(_.expr).asJava)
+ val expectedPlan = proto.Plan.newBuilder().setRoot(builder).build()
+
+ df.select(dummyCols: _*).analyze
+ val actualPlan = service.getAndClearLatestInputPlan()
+ assert(actualPlan.equals(expectedPlan))
+ }
+
+ test("filter") {
+ val df = ss.newDataset(_ => ())
+
+ val builder = proto.Relation.newBuilder()
+ val dummyCondition = Column.fn("dummy func", Column("a"))
+ builder.getFilterBuilder
+ .setInput(df.plan.getRoot)
+ .setCondition(dummyCondition.expr)
+ val expectedPlan = proto.Plan.newBuilder().setRoot(builder).build()
+
+ df.filter(dummyCondition).analyze
+ val actualPlan = service.getAndClearLatestInputPlan()
+ assert(actualPlan.equals(expectedPlan))
+ }
+}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala
new file mode 100644
index 00000000000..760609a703f
--- /dev/null
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import io.grpc.Server
+import io.grpc.netty.NettyServerBuilder
+import java.util.concurrent.TimeUnit
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite
+
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.connect.client.{DummySparkConnectService, SparkConnectClient}
+
+class SparkSessionSuite
+ extends AnyFunSuite // scalastyle:ignore funsuite
+ with BeforeAndAfterEach {
+
+ private var server: Server = _
+ private var service: DummySparkConnectService = _
+ private val SERVER_PORT = 15250
+
+ private def startDummyServer(port: Int): Unit = {
+ service = new DummySparkConnectService()
+ val sb = NettyServerBuilder
+ .forPort(port)
+ .addService(service)
+
+ server = sb.build
+ server.start()
+ }
+
+ override def beforeEach(): Unit = {
+ super.beforeEach()
+ startDummyServer(SERVER_PORT)
+ }
+
+ override def afterEach(): Unit = {
+ if (server != null) {
+ server.shutdownNow()
+ assert(server.awaitTermination(5, TimeUnit.SECONDS), "server failed to shutdown")
+ }
+ }
+
+ test("SparkSession initialisation with connection string") {
+ val ss = SparkSession
+ .builder()
+ .client(
+ SparkConnectClient
+ .builder()
+ .connectionString(s"sc://localhost:$SERVER_PORT")
+ .build())
+ .build()
+ val plan = proto.Plan.newBuilder().build()
+ ss.analyze(plan)
+ assert(plan.equals(service.getAndClearLatestInputPlan()))
+ }
+
+ private def rangePlanCreator(
+ start: Long,
+ end: Long,
+ step: Long,
+ numPartitions: Option[Int]): proto.Plan = {
+ val builder = proto.Relation.newBuilder()
+ val rangeBuilder = builder.getRangeBuilder
+ .setStart(start)
+ .setEnd(end)
+ .setStep(step)
+ numPartitions.foreach(rangeBuilder.setNumPartitions)
+ proto.Plan.newBuilder().setRoot(builder).build()
+ }
+
+ private def testRange(
+ start: Long,
+ end: Long,
+ step: Long,
+ numPartitions: Option[Int],
+ failureHint: String): Unit = {
+ val expectedPlan = rangePlanCreator(start, end, step, numPartitions)
+ val actualPlan = service.getAndClearLatestInputPlan()
+ assert(actualPlan.equals(expectedPlan), failureHint)
+ }
+
+ test("range query") {
+ val ss = SparkSession
+ .builder()
+ .client(
+ SparkConnectClient
+ .builder()
+ .connectionString(s"sc://localhost:$SERVER_PORT")
+ .build())
+ .build()
+
+ ss.range(10).analyze
+ testRange(0, 10, 1, None, "Case: range(10)")
+
+ ss.range(0, 20).analyze
+ testRange(0, 20, 1, None, "Case: range(0, 20)")
+
+ ss.range(6, 20, 3).analyze
+ testRange(6, 20, 3, None, "Case: range(6, 20, 3)")
+
+ ss.range(10, 100, 5, 2).analyze
+ testRange(10, 100, 5, Some(2), "Case: range(10, 100, 5, 2)")
+ }
+
+}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
index 1229a91aa54..f3caba28ffd 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
@@ -24,6 +24,7 @@ import io.grpc.stub.StreamObserver
import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite
+import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.{AnalyzePlanRequest, AnalyzePlanResponse, SparkConnectServiceGrpc}
import org.apache.spark.sql.connect.common.config.ConnectCommon
@@ -151,11 +152,20 @@ class SparkConnectClientSuite
class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectServiceImplBase {
+ private var inputPlan: proto.Plan = _
+
+ private[sql] def getAndClearLatestInputPlan(): proto.Plan = {
+ val plan = inputPlan
+ inputPlan = null
+ plan
+ }
+
override def analyzePlan(
request: AnalyzePlanRequest,
responseObserver: StreamObserver[AnalyzePlanResponse]): Unit = {
// Reply with a dummy response using the same client ID
val requestClientId = request.getClientId
+ inputPlan = request.getPlan
val response = AnalyzePlanResponse
.newBuilder()
.setClientId(requestClientId)