This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 934d14d73c1 [SPARK-42133] Add basic Dataset API methods to Spark Connect Scala Client
934d14d73c1 is described below

commit 934d14d73c14406242ccd11ba67e64fde1f3b955
Author: vicennial <venkata.gud...@databricks.com>
AuthorDate: Mon Jan 23 22:10:02 2023 -0400

    [SPARK-42133] Add basic Dataset API methods to Spark Connect Scala Client
    
    ### What changes were proposed in this pull request?
    Adds the following methods:
    - Dataset API methods
      - select
      - filter
      - limit
    - SparkSession
      - range (and its variations)
    
    This PR also introduces `Column` and `functions` to support the above changes.
    
    ### Why are the changes needed?
    Incremental development of Spark Connect Scala Client.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, users may now use the proposed API methods.
    Example: `val df = sparkSession.range(5).limit(3)`
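    
    A slightly longer sketch of how the new pieces compose (hypothetical column names and
    values, shown for illustration only; execution is still triggered through the existing
    `collectResult()`):
    
        import org.apache.spark.sql.Column
        import org.apache.spark.sql.functions.lit
        
        // Build the plan lazily with the new relational operators.
        val ds = sparkSession
          .range(0, 20, 2)
          .select(Column("id").name("value"), Column("id") + lit(1))
          .limit(3)
        
        // Send the plan to the Spark Connect server and fetch the rows.
        val rows = ds.collectResult().toArray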
    
    ### How was this patch tested?
    
    Unit tests + simple E2E test.
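    
    The unit tests start an in-process dummy gRPC service and assert that the client sends
    the expected proto plan; roughly this pattern (a sketch mirroring the `DatasetSuite`
    added below):
    
        val df = ss.newDataset(_ => ())
        val builder = proto.Relation.newBuilder()
        builder.getLimitBuilder.setInput(df.plan.getRoot).setLimit(10)
        val expectedPlan = proto.Plan.newBuilder().setRoot(builder).build()
        
        df.limit(10).analyze  // issues an AnalyzePlanRequest against the dummy service
        assert(service.getAndClearLatestInputPlan().equals(expectedPlan))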
    
    Closes #39672 from vicennial/SPARK-42133.
    
    Authored-by: vicennial <venkata.gud...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../main/scala/org/apache/spark/sql/Column.scala   | 107 ++++++++++++++++++
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  50 +++++++++
 .../scala/org/apache/spark/sql/SparkSession.scala  |  51 +++++++++
 .../sql/connect/client/SparkConnectClient.scala    |  22 ++++
 .../client/package.scala}                          |  14 ++-
 .../scala/org/apache/spark/sql/functions.scala     |  83 ++++++++++++++
 .../org/apache/spark/sql/ClientE2ETestSuite.scala  |  10 ++
 .../scala/org/apache/spark/sql/DatasetSuite.scala  | 114 ++++++++++++++++++++
 .../org/apache/spark/sql/SparkSessionSuite.scala   | 120 +++++++++++++++++++++
 .../connect/client/SparkConnectClientSuite.scala   |  10 ++
 10 files changed, 576 insertions(+), 5 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala
new file mode 100644
index 00000000000..f25d579d5c3
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.Column.fn
+import org.apache.spark.sql.connect.client.unsupported
+import org.apache.spark.sql.functions.lit
+
+/**
+ * A column that will be computed based on the data in a `DataFrame`.
+ *
+ * A new column can be constructed based on the input columns present in a DataFrame:
+ *
+ * {{{
+ *   df("columnName")            // On a specific `df` DataFrame.
+ *   col("columnName")           // A generic column not yet associated with a 
DataFrame.
+ *   col("columnName.field")     // Extracting a struct field
+ *   col("`a.column.with.dots`") // Escape `.` in column names.
+ *   $"columnName"               // Scala short hand for a named column.
+ * }}}
+ *
+ * [[Column]] objects can be composed to form complex expressions:
+ *
+ * {{{
+ *   $"a" + 1
+ * }}}
+ *
+ * @since 3.4.0
+ */
+class Column private[sql] (private[sql] val expr: proto.Expression) {
+
+  /**
+   * Sum of this expression and another expression.
+   * {{{
+   *   // Scala: The following selects the sum of a person's height and weight.
+   *   people.select( people("height") + people("weight") )
+   *
+   *   // Java:
+   *   people.select( people.col("height").plus(people.col("weight")) );
+   * }}}
+   *
+   * @group expr_ops
+   * @since 3.4.0
+   */
+  def +(other: Any): Column = fn("+", this, lit(other))
+
+  /**
+   * Gives the column a name (alias).
+   * {{{
+   *   // Renames colA to colB in select output.
+   *   df.select($"colA".name("colB"))
+   * }}}
+   *
+   * If the current column has metadata associated with it, this metadata will be propagated to
+   * the new column. If this is not desired, use the API `as(alias: String, metadata: Metadata)`
+   * with explicit metadata.
+   *
+   * @group expr_ops
+   * @since 3.4.0
+   */
+  def name(alias: String): Column = Column { builder =>
+    builder.getAliasBuilder.addName(alias).setExpr(expr)
+  }
+}
+
+object Column {
+
+  def apply(name: String): Column = Column { builder =>
+    name match {
+      case "*" =>
+        builder.getUnresolvedStarBuilder
+      case _ if name.endsWith(".*") =>
+        unsupported("* with prefix is not supported yet.")
+      case _ =>
+        builder.getUnresolvedAttributeBuilder.setUnparsedIdentifier(name)
+    }
+  }
+
+  private[sql] def apply(f: proto.Expression.Builder => Unit): Column = {
+    val builder = proto.Expression.newBuilder()
+    f(builder)
+    new Column(builder.build())
+  }
+
+  private[sql] def fn(name: String, inputs: Column*): Column = Column { builder =>
+    builder.getUnresolvedFunctionBuilder
+      .setFunctionName(name)
+      .addAllArguments(inputs.map(_.expr).asJava)
+  }
+}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index f7ed764a11e..6891b2f5bed 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -16,9 +16,59 @@
  */
 package org.apache.spark.sql
 
+import scala.collection.JavaConverters._
+
 import org.apache.spark.connect.proto
 import org.apache.spark.sql.connect.client.SparkResult
 
 class Dataset(val session: SparkSession, private[sql] val plan: proto.Plan) {
+
+  /**
+   * Selects a set of column based expressions.
+   * {{{
+   *   ds.select($"colA", $"colB" + 1)
+   * }}}
+   *
+   * @group untypedrel
+   * @since 3.4.0
+   */
+  @scala.annotation.varargs
+  def select(cols: Column*): Dataset = session.newDataset { builder =>
+    builder.getProjectBuilder
+      .setInput(plan.getRoot)
+      .addAllExpressions(cols.map(_.expr).asJava)
+  }
+
+  /**
+   * Filters rows using the given condition.
+   * {{{
+   *   // The following are equivalent:
+   *   peopleDs.filter($"age" > 15)
+   *   peopleDs.where($"age" > 15)
+   * }}}
+   *
+   * @group typedrel
+   * @since 3.4.0
+   */
+  def filter(condition: Column): Dataset = session.newDataset { builder =>
+    builder.getFilterBuilder.setInput(plan.getRoot).setCondition(condition.expr)
+  }
+
+  /**
+   * Returns a new Dataset by taking the first `n` rows. The difference between this function and
+   * `head` is that `head` is an action and returns an array (by triggering query execution) while
+   * `limit` returns a new Dataset.
+   *
+   * @group typedrel
+   * @since 3.4.0
+   */
+  def limit(n: Int): Dataset = session.newDataset { builder =>
+    builder.getLimitBuilder
+      .setInput(plan.getRoot)
+      .setLimit(n)
+  }
+
+  private[sql] def analyze: proto.AnalyzePlanResponse = session.analyze(plan)
+
   def collectResult(): SparkResult = session.execute(plan)
 }
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 21f4ebd75db..0c4f702ca34 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -57,6 +57,54 @@ class SparkSession(private val client: SparkConnectClient, private val cleaner:
     builder.setSql(proto.SQL.newBuilder().setQuery(query))
   }
 
+  /**
+   * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a
+   * range from 0 to `end` (exclusive) with step value 1.
+   *
+   * @since 3.4.0
+   */
+  def range(end: Long): Dataset = range(0, end)
+
+  /**
+   * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a
+   * range from `start` to `end` (exclusive) with step value 1.
+   *
+   * @since 3.4.0
+   */
+  def range(start: Long, end: Long): Dataset = {
+    range(start, end, step = 1)
+  }
+
+  /**
+   * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a
+   * range from `start` to `end` (exclusive) with a step value.
+   *
+   * @since 3.4.0
+   */
+  def range(start: Long, end: Long, step: Long): Dataset = {
+    range(start, end, step, None)
+  }
+
+  /**
+   * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a
+   * range from `start` to `end` (exclusive) with a step value, with partition number specified.
+   *
+   * @since 3.4.0
+   */
+  def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset = {
+    range(start, end, step, Option(numPartitions))
+  }
+
+  private def range(start: Long, end: Long, step: Long, numPartitions: Option[Int]): Dataset = {
+    newDataset { builder =>
+      val rangeBuilder = builder.getRangeBuilder
+        .setStart(start)
+        .setEnd(end)
+        .setStep(step)
+      numPartitions.foreach(rangeBuilder.setNumPartitions)
+    }
+  }
+
   private[sql] def newDataset(f: proto.Relation.Builder => Unit): Dataset = {
     val builder = proto.Relation.newBuilder()
     f(builder)
@@ -64,6 +112,9 @@ class SparkSession(private val client: SparkConnectClient, private val cleaner:
     new Dataset(this, plan)
   }
 
+  private[sql] def analyze(plan: proto.Plan): proto.AnalyzePlanResponse =
+    client.analyze(plan)
+
   private[sql] def execute(plan: proto.Plan): SparkResult = {
     val value = client.execute(plan)
     val result = new SparkResult(value, allocator)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index 8ad84631531..8252f8aef76 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -21,6 +21,7 @@ import scala.language.existentials
 
 import io.grpc.{ManagedChannel, ManagedChannelBuilder}
 import java.net.URI
+import java.util.UUID
 
 import org.apache.spark.connect.proto
 import org.apache.spark.sql.connect.common.config.ConnectCommon
@@ -41,6 +42,11 @@ class SparkConnectClient(
    */
   def userId: String = userContext.getUserId()
 
+  // Generate a unique session ID for this client. This UUID must be unique to allow
+  // concurrent Spark sessions of the same user. If the channel is closed, creating
+  // a new client will create a new session ID.
+  private[client] val sessionId: String = UUID.randomUUID.toString
+
   /**
    * Dispatch the [[proto.AnalyzePlanRequest]] to the Spark Connect server.
    * @return
@@ -58,6 +64,22 @@ class SparkConnectClient(
     stub.executePlan(request)
   }
 
+  /**
+   * Builds a [[proto.AnalyzePlanRequest]] from `plan` and dispatches it to the Spark Connect
+   * server.
+   * @return
+   *   A [[proto.AnalyzePlanResponse]] from the Spark Connect server.
+   */
+  def analyze(plan: proto.Plan): proto.AnalyzePlanResponse = {
+    val request = proto.AnalyzePlanRequest
+      .newBuilder()
+      .setPlan(plan)
+      .setUserContext(userContext)
+      .setClientId(sessionId)
+      .build()
+    analyze(request)
+  }
+
   /**
    * Shutdown the client's connection to the server.
    */
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/package.scala
similarity index 74%
copy from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
copy to connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/package.scala
index f7ed764a11e..9c173076ab8 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/package.scala
@@ -14,11 +14,15 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.spark.sql
+package org.apache.spark.sql.connect
 
-import org.apache.spark.connect.proto
-import org.apache.spark.sql.connect.client.SparkResult
+package object client {
 
-class Dataset(val session: SparkSession, private[sql] val plan: proto.Plan) {
-  def collectResult(): SparkResult = session.execute(plan)
+  private[sql] def unsupported(): Nothing = {
+    throw new UnsupportedOperationException
+  }
+
+  private[sql] def unsupported(message: String): Nothing = {
+    throw new UnsupportedOperationException(message)
+  }
 }
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
new file mode 100644
index 00000000000..bae394785be
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import java.math.{BigDecimal => JBigDecimal}
+import java.time.LocalDate
+
+import com.google.protobuf.ByteString
+
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.connect.client.unsupported
+
+/**
+ * Commonly used functions available for DataFrame operations.
+ *
+ * @since 3.4.0
+ */
+// scalastyle:off
+object functions {
+// scalastyle:on
+
+  private def createLiteral(f: proto.Expression.Literal.Builder => Unit): Column = Column {
+    builder =>
+      val literalBuilder = proto.Expression.Literal.newBuilder()
+      f(literalBuilder)
+      builder.setLiteral(literalBuilder)
+  }
+
+  private def createDecimalLiteral(precision: Int, scale: Int, value: String): Column =
+    createLiteral { builder =>
+      builder.getDecimalBuilder
+        .setPrecision(precision)
+        .setScale(scale)
+        .setValue(value)
+    }
+
+  /**
+   * Creates a [[Column]] of literal value.
+   *
+   * The passed in object is returned directly if it is already a [[Column]]. If the object is a
+   * Scala Symbol, it is converted into a [[Column]] also. Otherwise, a new [[Column]] is created
+   * to represent the literal value.
+   *
+   * @since 3.4.0
+   */
+  def lit(literal: Any): Column = {
+    literal match {
+      case c: Column => c
+      case s: Symbol => Column(s.name)
+      case v: Boolean => createLiteral(_.setBoolean(v))
+      case v: Byte => createLiteral(_.setByte(v))
+      case v: Short => createLiteral(_.setShort(v))
+      case v: Int => createLiteral(_.setInteger(v))
+      case v: Long => createLiteral(_.setLong(v))
+      case v: Float => createLiteral(_.setFloat(v))
+      case v: Double => createLiteral(_.setDouble(v))
+      case v: BigDecimal => createDecimalLiteral(v.precision, v.scale, v.toString)
+      case v: JBigDecimal => createDecimalLiteral(v.precision, v.scale, v.toString)
+      case v: String => createLiteral(_.setString(v))
+      case v: Char => createLiteral(_.setString(v.toString))
+      case v: Array[Char] => createLiteral(_.setString(String.valueOf(v)))
+      case v: Array[Byte] => createLiteral(_.setBinary(ByteString.copyFrom(v)))
+      case v: collection.mutable.WrappedArray[_] => lit(v.array)
+      case v: LocalDate => createLiteral(_.setDate(v.toEpochDay.toInt))
+      case null => unsupported("Null literals not supported yet.")
+      case _ => unsupported(s"literal $literal not supported (yet).")
+    }
+  }
+}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index 1427dc49f86..e31f121ca10 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -38,6 +38,16 @@ class ClientE2ETestSuite extends RemoteSparkSession {
     assert(array(1).getString(0) == "World")
   }
 
+  test("simple dataset test") {
+    val df = spark.range(10).limit(3)
+    val result = df.collectResult()
+    assert(result.length == 3)
+    val array = result.toArray
+    assert(array(0).getLong(0) == 0)
+    assert(array(1).getLong(0) == 1)
+    assert(array(2).getLong(0) == 2)
+  }
+
   // TODO test large result when we can create table or view
   // test("test spark large result")
 }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
new file mode 100644
index 00000000000..5e7cc9f4b6d
--- /dev/null
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import scala.collection.JavaConverters._
+
+import io.grpc.Server
+import io.grpc.netty.NettyServerBuilder
+import java.util.concurrent.TimeUnit
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite
+
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.connect.client.{DummySparkConnectService, SparkConnectClient}
+
+class DatasetSuite
+    extends AnyFunSuite // scalastyle:ignore funsuite
+    with BeforeAndAfterEach {
+
+  private var server: Server = _
+  private var service: DummySparkConnectService = _
+  private var ss: SparkSession = _
+
+  private def getNewSparkSession(port: Int): SparkSession = {
+    assert(port != 0)
+    SparkSession
+      .builder()
+      .client(
+        SparkConnectClient
+          .builder()
+          .connectionString(s"sc://localhost:$port")
+          .build())
+      .build()
+  }
+
+  private def startDummyServer(): Unit = {
+    service = new DummySparkConnectService()
+    val sb = NettyServerBuilder
+      // Let server bind to any free port
+      .forPort(0)
+      .addService(service)
+
+    server = sb.build
+    server.start()
+  }
+
+  override def beforeEach(): Unit = {
+    super.beforeEach()
+    startDummyServer()
+    ss = getNewSparkSession(server.getPort)
+  }
+
+  override def afterEach(): Unit = {
+    if (server != null) {
+      server.shutdownNow()
+      assert(server.awaitTermination(5, TimeUnit.SECONDS), "server failed to shutdown")
+    }
+  }
+
+  test("limit") {
+    val df = ss.newDataset(_ => ())
+    val builder = proto.Relation.newBuilder()
+    builder.getLimitBuilder.setInput(df.plan.getRoot).setLimit(10)
+
+    val expectedPlan = proto.Plan.newBuilder().setRoot(builder).build()
+    df.limit(10).analyze
+    val actualPlan = service.getAndClearLatestInputPlan()
+    assert(actualPlan.equals(expectedPlan))
+  }
+
+  test("select") {
+    val df = ss.newDataset(_ => ())
+
+    val builder = proto.Relation.newBuilder()
+    val dummyCols = Seq[Column](Column("a"), Column("b"))
+    builder.getProjectBuilder
+      .setInput(df.plan.getRoot)
+      .addAllExpressions(dummyCols.map(_.expr).asJava)
+    val expectedPlan = proto.Plan.newBuilder().setRoot(builder).build()
+
+    df.select(dummyCols: _*).analyze
+    val actualPlan = service.getAndClearLatestInputPlan()
+    assert(actualPlan.equals(expectedPlan))
+  }
+
+  test("filter") {
+    val df = ss.newDataset(_ => ())
+
+    val builder = proto.Relation.newBuilder()
+    val dummyCondition = Column.fn("dummy func", Column("a"))
+    builder.getFilterBuilder
+      .setInput(df.plan.getRoot)
+      .setCondition(dummyCondition.expr)
+    val expectedPlan = proto.Plan.newBuilder().setRoot(builder).build()
+
+    df.filter(dummyCondition).analyze
+    val actualPlan = service.getAndClearLatestInputPlan()
+    assert(actualPlan.equals(expectedPlan))
+  }
+}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala
new file mode 100644
index 00000000000..760609a703f
--- /dev/null
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import io.grpc.Server
+import io.grpc.netty.NettyServerBuilder
+import java.util.concurrent.TimeUnit
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite
+
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.connect.client.{DummySparkConnectService, SparkConnectClient}
+
+class SparkSessionSuite
+    extends AnyFunSuite // scalastyle:ignore funsuite
+    with BeforeAndAfterEach {
+
+  private var server: Server = _
+  private var service: DummySparkConnectService = _
+  private val SERVER_PORT = 15250
+
+  private def startDummyServer(port: Int): Unit = {
+    service = new DummySparkConnectService()
+    val sb = NettyServerBuilder
+      .forPort(port)
+      .addService(service)
+
+    server = sb.build
+    server.start()
+  }
+
+  override def beforeEach(): Unit = {
+    super.beforeEach()
+    startDummyServer(SERVER_PORT)
+  }
+
+  override def afterEach(): Unit = {
+    if (server != null) {
+      server.shutdownNow()
+      assert(server.awaitTermination(5, TimeUnit.SECONDS), "server failed to shutdown")
+    }
+  }
+
+  test("SparkSession initialisation with connection string") {
+    val ss = SparkSession
+      .builder()
+      .client(
+        SparkConnectClient
+          .builder()
+          .connectionString(s"sc://localhost:$SERVER_PORT")
+          .build())
+      .build()
+    val plan = proto.Plan.newBuilder().build()
+    ss.analyze(plan)
+    assert(plan.equals(service.getAndClearLatestInputPlan()))
+  }
+
+  private def rangePlanCreator(
+      start: Long,
+      end: Long,
+      step: Long,
+      numPartitions: Option[Int]): proto.Plan = {
+    val builder = proto.Relation.newBuilder()
+    val rangeBuilder = builder.getRangeBuilder
+      .setStart(start)
+      .setEnd(end)
+      .setStep(step)
+    numPartitions.foreach(rangeBuilder.setNumPartitions)
+    proto.Plan.newBuilder().setRoot(builder).build()
+  }
+
+  private def testRange(
+      start: Long,
+      end: Long,
+      step: Long,
+      numPartitions: Option[Int],
+      failureHint: String): Unit = {
+    val expectedPlan = rangePlanCreator(start, end, step, numPartitions)
+    val actualPlan = service.getAndClearLatestInputPlan()
+    assert(actualPlan.equals(expectedPlan), failureHint)
+  }
+
+  test("range query") {
+    val ss = SparkSession
+      .builder()
+      .client(
+        SparkConnectClient
+          .builder()
+          .connectionString(s"sc://localhost:$SERVER_PORT")
+          .build())
+      .build()
+
+    ss.range(10).analyze
+    testRange(0, 10, 1, None, "Case: range(10)")
+
+    ss.range(0, 20).analyze
+    testRange(0, 20, 1, None, "Case: range(0, 20)")
+
+    ss.range(6, 20, 3).analyze
+    testRange(6, 20, 3, None, "Case: range(6, 20, 3)")
+
+    ss.range(10, 100, 5, 2).analyze
+    testRange(10, 100, 5, Some(2), "Case: range(10, 100, 5, 2)")
+  }
+
+}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
index 1229a91aa54..f3caba28ffd 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
@@ -24,6 +24,7 @@ import io.grpc.stub.StreamObserver
 import org.scalatest.BeforeAndAfterEach
 import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite
 
+import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.{AnalyzePlanRequest, AnalyzePlanResponse, SparkConnectServiceGrpc}
 import org.apache.spark.sql.connect.common.config.ConnectCommon
 
@@ -151,11 +152,20 @@ class SparkConnectClientSuite
 
 class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectServiceImplBase {
 
+  private var inputPlan: proto.Plan = _
+
+  private[sql] def getAndClearLatestInputPlan(): proto.Plan = {
+    val plan = inputPlan
+    inputPlan = null
+    plan
+  }
+
   override def analyzePlan(
       request: AnalyzePlanRequest,
       responseObserver: StreamObserver[AnalyzePlanResponse]): Unit = {
     // Reply with a dummy response using the same client ID
     val requestClientId = request.getClientId
+    inputPlan = request.getPlan
     val response = AnalyzePlanResponse
       .newBuilder()
       .setClientId(requestClientId)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
