This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 934d14d73c1 [SPARK-42133] Add basic Dataset API methods to Spark Connect Scala Client 934d14d73c1 is described below commit 934d14d73c14406242ccd11ba67e64fde1f3b955 Author: vicennial <venkata.gud...@databricks.com> AuthorDate: Mon Jan 23 22:10:02 2023 -0400 [SPARK-42133] Add basic Dataset API methods to Spark Connect Scala Client ### What changes were proposed in this pull request? Adds the following methods: - Dataset API methods - project - filter - limit - SparkSession - range (and its variations) This PR also introduces `Column` and `functions` to support the above changes. ### Why are the changes needed? Incremental development of Spark Connect Scala Client. ### Does this PR introduce _any_ user-facing change? Yes, users may now use the proposed API methods. Example: `val df = sparkSession.range(5).limit(3)` ### How was this patch tested? Unit tests + simple E2E test. Closes #39672 from vicennial/SPARK-42133. Authored-by: vicennial <venkata.gud...@databricks.com> Signed-off-by: Herman van Hovell <her...@databricks.com> --- .../main/scala/org/apache/spark/sql/Column.scala | 107 ++++++++++++++++++ .../main/scala/org/apache/spark/sql/Dataset.scala | 50 +++++++++ .../scala/org/apache/spark/sql/SparkSession.scala | 51 +++++++++ .../sql/connect/client/SparkConnectClient.scala | 22 ++++ .../client/package.scala} | 14 ++- .../scala/org/apache/spark/sql/functions.scala | 83 ++++++++++++++ .../org/apache/spark/sql/ClientE2ETestSuite.scala | 10 ++ .../scala/org/apache/spark/sql/DatasetSuite.scala | 114 ++++++++++++++++++++ .../org/apache/spark/sql/SparkSessionSuite.scala | 120 +++++++++++++++++++++ .../connect/client/SparkConnectClientSuite.scala | 10 ++ 10 files changed, 576 insertions(+), 5 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala new file mode 100644 index 00000000000..f25d579d5c3 --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +import scala.collection.JavaConverters._ + +import org.apache.spark.connect.proto +import org.apache.spark.sql.Column.fn +import org.apache.spark.sql.connect.client.unsupported +import org.apache.spark.sql.functions.lit + +/** + * A column that will be computed based on the data in a `DataFrame`. + * + * A new column can be constructed based on the input columns present in a DataFrame: + * + * {{{ + * df("columnName") // On a specific `df` DataFrame. + * col("columnName") // A generic column not yet associated with a DataFrame. 
+ * col("columnName.field") // Extracting a struct field + * col("`a.column.with.dots`") // Escape `.` in column names. + * $"columnName" // Scala short hand for a named column. + * }}} + * + * [[Column]] objects can be composed to form complex expressions: + * + * {{{ + * $"a" + 1 + * }}} + * + * @since 3.4.0 + */ +class Column private[sql] (private[sql] val expr: proto.Expression) { + + /** + * Sum of this expression and another expression. + * {{{ + * // Scala: The following selects the sum of a person's height and weight. + * people.select( people("height") + people("weight") ) + * + * // Java: + * people.select( people.col("height").plus(people.col("weight")) ); + * }}} + * + * @group expr_ops + * @since 3.4.0 + */ + def +(other: Any): Column = fn("+", this, lit(other)) + + /** + * Gives the column a name (alias). + * {{{ + * // Renames colA to colB in select output. + * df.select($"colA".name("colB")) + * }}} + * + * If the current column has metadata associated with it, this metadata will be propagated to + * the new column. If this not desired, use the API `as(alias: String, metadata: Metadata)` with + * explicit metadata. + * + * @group expr_ops + * @since 3.4.0 + */ + def name(alias: String): Column = Column { builder => + builder.getAliasBuilder.addName(alias).setExpr(expr) + } +} + +object Column { + + def apply(name: String): Column = Column { builder => + name match { + case "*" => + builder.getUnresolvedStarBuilder + case _ if name.endsWith(".*") => + unsupported("* with prefix is not supported yet.") + case _ => + builder.getUnresolvedAttributeBuilder.setUnparsedIdentifier(name) + } + } + + private[sql] def apply(f: proto.Expression.Builder => Unit): Column = { + val builder = proto.Expression.newBuilder() + f(builder) + new Column(builder.build()) + } + + private[sql] def fn(name: String, inputs: Column*): Column = Column { builder => + builder.getUnresolvedFunctionBuilder + .setFunctionName(name) + .addAllArguments(inputs.map(_.expr).asJava) + } +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index f7ed764a11e..6891b2f5bed 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -16,9 +16,59 @@ */ package org.apache.spark.sql +import scala.collection.JavaConverters._ + import org.apache.spark.connect.proto import org.apache.spark.sql.connect.client.SparkResult class Dataset(val session: SparkSession, private[sql] val plan: proto.Plan) { + + /** + * Selects a set of column based expressions. + * {{{ + * ds.select($"colA", $"colB" + 1) + * }}} + * + * @group untypedrel + * @since 3.4.0 + */ + @scala.annotation.varargs + def select(cols: Column*): Dataset = session.newDataset { builder => + builder.getProjectBuilder + .setInput(plan.getRoot) + .addAllExpressions(cols.map(_.expr).asJava) + } + + /** + * Filters rows using the given condition. + * {{{ + * // The following are equivalent: + * peopleDs.filter($"age" > 15) + * peopleDs.where($"age" > 15) + * }}} + * + * @group typedrel + * @since 3.4.0 + */ + def filter(condition: Column): Dataset = session.newDataset { builder => + builder.getFilterBuilder.setInput(plan.getRoot).setCondition(condition.expr) + } + + /** + * Returns a new Dataset by taking the first `n` rows. 
The difference between this function and + * `head` is that `head` is an action and returns an array (by triggering query execution) while + * `limit` returns a new Dataset. + * + * @group typedrel + * @since 3.4.0 + */ + def limit(n: Int): Dataset = session.newDataset { builder => + builder.getLimitBuilder + .setInput(plan.getRoot) + .setLimit(n) + } + + private[sql] def analyze: proto.AnalyzePlanResponse = session.analyze(plan) + def collectResult(): SparkResult = session.execute(plan) } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 21f4ebd75db..0c4f702ca34 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -57,6 +57,54 @@ class SparkSession(private val client: SparkConnectClient, private val cleaner: builder.setSql(proto.SQL.newBuilder().setQuery(query)) } + /** + * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a + * range from 0 to `end` (exclusive) with step value 1. + * + * @since 3.4.0 + */ + def range(end: Long): Dataset = range(0, end) + + /** + * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a + * range from `start` to `end` (exclusive) with step value 1. + * + * @since 3.4.0 + */ + def range(start: Long, end: Long): Dataset = { + range(start, end, step = 1) + } + + /** + * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a + * range from `start` to `end` (exclusive) with a step value. + * + * @since 3.4.0 + */ + def range(start: Long, end: Long, step: Long): Dataset = { + range(start, end, step, None) + } + + /** + * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a + * range from `start` to `end` (exclusive) with a step value, with partition number specified. 
+ * + * @since 3.4.0 + */ + def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset = { + range(start, end, step, Option(numPartitions)) + } + + private def range(start: Long, end: Long, step: Long, numPartitions: Option[Int]): Dataset = { + newDataset { builder => + val rangeBuilder = builder.getRangeBuilder + .setStart(start) + .setEnd(end) + .setStep(step) + numPartitions.foreach(rangeBuilder.setNumPartitions) + } + } + private[sql] def newDataset(f: proto.Relation.Builder => Unit): Dataset = { val builder = proto.Relation.newBuilder() f(builder) @@ -64,6 +112,9 @@ class SparkSession(private val client: SparkConnectClient, private val cleaner: new Dataset(this, plan) } + private[sql] def analyze(plan: proto.Plan): proto.AnalyzePlanResponse = + client.analyze(plan) + private[sql] def execute(plan: proto.Plan): SparkResult = { val value = client.execute(plan) val result = new SparkResult(value, allocator) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala index 8ad84631531..8252f8aef76 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala @@ -21,6 +21,7 @@ import scala.language.existentials import io.grpc.{ManagedChannel, ManagedChannelBuilder} import java.net.URI +import java.util.UUID import org.apache.spark.connect.proto import org.apache.spark.sql.connect.common.config.ConnectCommon @@ -41,6 +42,11 @@ class SparkConnectClient( */ def userId: String = userContext.getUserId() + // Generate a unique session ID for this client. This UUID must be unique to allow + // concurrent Spark sessions of the same user. If the channel is closed, creating + // a new client will create a new session ID. + private[client] val sessionId: String = UUID.randomUUID.toString + /** * Dispatch the [[proto.AnalyzePlanRequest]] to the Spark Connect server. * @return @@ -58,6 +64,22 @@ class SparkConnectClient( stub.executePlan(request) } + /** + * Builds a [[proto.AnalyzePlanRequest]] from `plan` and dispatched it to the Spark Connect + * server. + * @return + * A [[proto.AnalyzePlanResponse]] from the Spark Connect server. + */ + def analyze(plan: proto.Plan): proto.AnalyzePlanResponse = { + val request = proto.AnalyzePlanRequest + .newBuilder() + .setPlan(plan) + .setUserContext(userContext) + .setClientId(sessionId) + .build() + analyze(request) + } + /** * Shutdown the client's connection to the server. */ diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/package.scala similarity index 74% copy from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala copy to connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/package.scala index f7ed764a11e..9c173076ab8 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/package.scala @@ -14,11 +14,15 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.connect -import org.apache.spark.connect.proto -import org.apache.spark.sql.connect.client.SparkResult +package object client { -class Dataset(val session: SparkSession, private[sql] val plan: proto.Plan) { - def collectResult(): SparkResult = session.execute(plan) + private[sql] def unsupported(): Nothing = { + throw new UnsupportedOperationException + } + + private[sql] def unsupported(message: String): Nothing = { + throw new UnsupportedOperationException(message) + } } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala new file mode 100644 index 00000000000..bae394785be --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +import java.math.{BigDecimal => JBigDecimal} +import java.time.LocalDate + +import com.google.protobuf.ByteString + +import org.apache.spark.connect.proto +import org.apache.spark.sql.connect.client.unsupported + +/** + * Commonly used functions available for DataFrame operations. + * + * @since 3.4.0 + */ +// scalastyle:off +object functions { +// scalastyle:on + + private def createLiteral(f: proto.Expression.Literal.Builder => Unit): Column = Column { + builder => + val literalBuilder = proto.Expression.Literal.newBuilder() + f(literalBuilder) + builder.setLiteral(literalBuilder) + } + + private def createDecimalLiteral(precision: Int, scale: Int, value: String): Column = + createLiteral { builder => + builder.getDecimalBuilder + .setPrecision(precision) + .setScale(scale) + .setValue(value) + } + + /** + * Creates a [[Column]] of literal value. + * + * The passed in object is returned directly if it is already a [[Column]]. If the object is a + * Scala Symbol, it is converted into a [[Column]] also. Otherwise, a new [[Column]] is created + * to represent the literal value. 
+ * + * @since 3.4.0 + */ + def lit(literal: Any): Column = { + literal match { + case c: Column => c + case s: Symbol => Column(s.name) + case v: Boolean => createLiteral(_.setBoolean(v)) + case v: Byte => createLiteral(_.setByte(v)) + case v: Short => createLiteral(_.setShort(v)) + case v: Int => createLiteral(_.setInteger(v)) + case v: Long => createLiteral(_.setLong(v)) + case v: Float => createLiteral(_.setFloat(v)) + case v: Double => createLiteral(_.setDouble(v)) + case v: BigDecimal => createDecimalLiteral(v.precision, v.scale, v.toString) + case v: JBigDecimal => createDecimalLiteral(v.precision, v.scale, v.toString) + case v: String => createLiteral(_.setString(v)) + case v: Char => createLiteral(_.setString(v.toString)) + case v: Array[Char] => createLiteral(_.setString(String.valueOf(v))) + case v: Array[Byte] => createLiteral(_.setBinary(ByteString.copyFrom(v))) + case v: collection.mutable.WrappedArray[_] => lit(v.array) + case v: LocalDate => createLiteral(_.setDate(v.toEpochDay.toInt)) + case null => unsupported("Null literals not supported yet.") + case _ => unsupported(s"literal $literal not supported (yet).") + } + } +} diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index 1427dc49f86..e31f121ca10 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -38,6 +38,16 @@ class ClientE2ETestSuite extends RemoteSparkSession { assert(array(1).getString(0) == "World") } + test("simple dataset test") { + val df = spark.range(10).limit(3) + val result = df.collectResult() + assert(result.length == 3) + val array = result.toArray + assert(array(0).getLong(0) == 0) + assert(array(1).getLong(0) == 1) + assert(array(2).getLong(0) == 2) + } + // TODO test large result when we can create table or view // test("test spark large result") } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala new file mode 100644 index 00000000000..5e7cc9f4b6d --- /dev/null +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql + +import scala.collection.JavaConverters._ + +import io.grpc.Server +import io.grpc.netty.NettyServerBuilder +import java.util.concurrent.TimeUnit +import org.scalatest.BeforeAndAfterEach +import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite + +import org.apache.spark.connect.proto +import org.apache.spark.sql.connect.client.{DummySparkConnectService, SparkConnectClient} + +class DatasetSuite + extends AnyFunSuite // scalastyle:ignore funsuite + with BeforeAndAfterEach { + + private var server: Server = _ + private var service: DummySparkConnectService = _ + private var ss: SparkSession = _ + + private def getNewSparkSession(port: Int): SparkSession = { + assert(port != 0) + SparkSession + .builder() + .client( + SparkConnectClient + .builder() + .connectionString(s"sc://localhost:$port") + .build()) + .build() + } + + private def startDummyServer(): Unit = { + service = new DummySparkConnectService() + val sb = NettyServerBuilder + // Let server bind to any free port + .forPort(0) + .addService(service) + + server = sb.build + server.start() + } + + override def beforeEach(): Unit = { + super.beforeEach() + startDummyServer() + ss = getNewSparkSession(server.getPort) + } + + override def afterEach(): Unit = { + if (server != null) { + server.shutdownNow() + assert(server.awaitTermination(5, TimeUnit.SECONDS), "server failed to shutdown") + } + } + + test("limit") { + val df = ss.newDataset(_ => ()) + val builder = proto.Relation.newBuilder() + builder.getLimitBuilder.setInput(df.plan.getRoot).setLimit(10) + + val expectedPlan = proto.Plan.newBuilder().setRoot(builder).build() + df.limit(10).analyze + val actualPlan = service.getAndClearLatestInputPlan() + assert(actualPlan.equals(expectedPlan)) + } + + test("select") { + val df = ss.newDataset(_ => ()) + + val builder = proto.Relation.newBuilder() + val dummyCols = Seq[Column](Column("a"), Column("b")) + builder.getProjectBuilder + .setInput(df.plan.getRoot) + .addAllExpressions(dummyCols.map(_.expr).asJava) + val expectedPlan = proto.Plan.newBuilder().setRoot(builder).build() + + df.select(dummyCols: _*).analyze + val actualPlan = service.getAndClearLatestInputPlan() + assert(actualPlan.equals(expectedPlan)) + } + + test("filter") { + val df = ss.newDataset(_ => ()) + + val builder = proto.Relation.newBuilder() + val dummyCondition = Column.fn("dummy func", Column("a")) + builder.getFilterBuilder + .setInput(df.plan.getRoot) + .setCondition(dummyCondition.expr) + val expectedPlan = proto.Plan.newBuilder().setRoot(builder).build() + + df.filter(dummyCondition).analyze + val actualPlan = service.getAndClearLatestInputPlan() + assert(actualPlan.equals(expectedPlan)) + } +} diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala new file mode 100644 index 00000000000..760609a703f --- /dev/null +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +import io.grpc.Server +import io.grpc.netty.NettyServerBuilder +import java.util.concurrent.TimeUnit +import org.scalatest.BeforeAndAfterEach +import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite + +import org.apache.spark.connect.proto +import org.apache.spark.sql.connect.client.{DummySparkConnectService, SparkConnectClient} + +class SparkSessionSuite + extends AnyFunSuite // scalastyle:ignore funsuite + with BeforeAndAfterEach { + + private var server: Server = _ + private var service: DummySparkConnectService = _ + private val SERVER_PORT = 15250 + + private def startDummyServer(port: Int): Unit = { + service = new DummySparkConnectService() + val sb = NettyServerBuilder + .forPort(port) + .addService(service) + + server = sb.build + server.start() + } + + override def beforeEach(): Unit = { + super.beforeEach() + startDummyServer(SERVER_PORT) + } + + override def afterEach(): Unit = { + if (server != null) { + server.shutdownNow() + assert(server.awaitTermination(5, TimeUnit.SECONDS), "server failed to shutdown") + } + } + + test("SparkSession initialisation with connection string") { + val ss = SparkSession + .builder() + .client( + SparkConnectClient + .builder() + .connectionString(s"sc://localhost:$SERVER_PORT") + .build()) + .build() + val plan = proto.Plan.newBuilder().build() + ss.analyze(plan) + assert(plan.equals(service.getAndClearLatestInputPlan())) + } + + private def rangePlanCreator( + start: Long, + end: Long, + step: Long, + numPartitions: Option[Int]): proto.Plan = { + val builder = proto.Relation.newBuilder() + val rangeBuilder = builder.getRangeBuilder + .setStart(start) + .setEnd(end) + .setStep(step) + numPartitions.foreach(rangeBuilder.setNumPartitions) + proto.Plan.newBuilder().setRoot(builder).build() + } + + private def testRange( + start: Long, + end: Long, + step: Long, + numPartitions: Option[Int], + failureHint: String): Unit = { + val expectedPlan = rangePlanCreator(start, end, step, numPartitions) + val actualPlan = service.getAndClearLatestInputPlan() + assert(actualPlan.equals(expectedPlan), failureHint) + } + + test("range query") { + val ss = SparkSession + .builder() + .client( + SparkConnectClient + .builder() + .connectionString(s"sc://localhost:$SERVER_PORT") + .build()) + .build() + + ss.range(10).analyze + testRange(0, 10, 1, None, "Case: range(10)") + + ss.range(0, 20).analyze + testRange(0, 20, 1, None, "Case: range(0, 20)") + + ss.range(6, 20, 3).analyze + testRange(6, 20, 3, None, "Case: range(6, 20, 3)") + + ss.range(10, 100, 5, 2).analyze + testRange(10, 100, 5, Some(2), "Case: range(6, 20, 3, Some(2))") + } + +} diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala index 1229a91aa54..f3caba28ffd 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala +++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala @@ -24,6 +24,7 @@ import io.grpc.stub.StreamObserver import org.scalatest.BeforeAndAfterEach import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite +import org.apache.spark.connect.proto import org.apache.spark.connect.proto.{AnalyzePlanRequest, AnalyzePlanResponse, SparkConnectServiceGrpc} import org.apache.spark.sql.connect.common.config.ConnectCommon @@ -151,11 +152,20 @@ class SparkConnectClientSuite class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectServiceImplBase { + private var inputPlan: proto.Plan = _ + + private[sql] def getAndClearLatestInputPlan(): proto.Plan = { + val plan = inputPlan + inputPlan = null + plan + } + override def analyzePlan( request: AnalyzePlanRequest, responseObserver: StreamObserver[AnalyzePlanResponse]): Unit = { // Reply with a dummy response using the same client ID val requestClientId = request.getClientId + inputPlan = request.getPlan val response = AnalyzePlanResponse .newBuilder() .setClientId(requestClientId)
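For readers trying out the new Scala client, the sketch below pulls the pieces of this change together end to end. It is illustrative only: the `ConnectClientExample` object, the `sc://localhost:15002` connection string, and the printed output are assumptions rather than part of this commit, and running it requires a Spark Connect server built from this revision.

```scala
import org.apache.spark.sql.{Column, SparkSession}
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.connect.client.SparkConnectClient

object ConnectClientExample {
  def main(args: Array[String]): Unit = {
    // Connect to a Spark Connect server; the endpoint is a placeholder.
    val spark = SparkSession
      .builder()
      .client(
        SparkConnectClient
          .builder()
          .connectionString("sc://localhost:15002")
          .build())
      .build()

    // range(start, end, step, numPartitions): a single LongType column `id`
    // holding 0, 5, 10, ..., 95, split across 2 partitions.
    val ids = spark.range(0, 100, 5, 2)

    // Column expressions: Column("id") resolves an attribute, lit() wraps a
    // literal, `+` combines them, and name() aliases the result.
    val projected = ids
      .select(Column("id"), (Column("id") + lit(100L)).name("shifted"))
      .limit(3)

    // collectResult() sends the plan to the server and materialises the rows.
    val result = projected.collectResult()
    result.toArray.foreach(row => println(s"${row.getLong(0)} -> ${row.getLong(1)}"))
    println(s"collected ${result.length} rows")
  }
}
```

Note that `Column` only exposes `+` and `name` at this point and `functions` only exposes `lit`, so the sketch sticks to projections; building a useful `filter` predicate would need comparison operators that are not part of this change.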