This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 103eca60102 [SPARK-42605][CONNECT] Add TypedColumn
103eca60102 is described below

commit 103eca60102d943fd083bf029db1d0e7f22d67ff
Author: Herman van Hovell <her...@databricks.com>
AuthorDate: Mon Feb 27 15:13:36 2023 -0400

    [SPARK-42605][CONNECT] Add TypedColumn
    
    ### What changes were proposed in this pull request?
    This PR adds TypedColumn to the Spark Connect Scala Client. We also add one of the typed select methods for Dataset, and a typed count function.
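    
    A minimal usage sketch, based on the examples added in this change (it assumes a Spark Connect `SparkSession` named `spark` with `spark.implicits._` in scope; the value names are illustrative only):
    
    ```scala
    import org.apache.spark.sql.functions._
    
    // Column.as[U] turns a Column into a TypedColumn[Any, U].
    val typed = expr("value + 1").as[Int]
    
    // The typed single-column select returns a Dataset of the encoder's type.
    val incremented: Dataset[Int] = Seq(1, 2, 3).toDS().select(typed)
    
    // count(columnName) now yields a TypedColumn[Any, Long], so select returns Dataset[Long].
    val numRows: Long = spark.range(1000).select(count("id")).first()
    ```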
    
    ### Why are the changes needed?
    API Parity.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes.
    
    ### How was this patch tested?
    Added tests to PlanGenerationTestSuite and ClientE2ETestSuite.
    
    Closes #40197 from hvanhovell/SPARK-42605.
    
    Authored-by: Herman van Hovell <her...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
    (cherry picked from commit 7f64ec302420652932ff515c325ba37938f0b175)
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../main/scala/org/apache/spark/sql/Column.scala   |  40 +++++++++++++++++++++
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  25 +++++++++++++
 .../scala/org/apache/spark/sql/functions.scala     |  10 ++++++
 .../org/apache/spark/sql/ClientE2ETestSuite.scala  |  40 +++++++++++++++------
 .../apache/spark/sql/PlanGenerationTestSuite.scala |  10 ++++++
 .../sql/connect/client/CompatibilitySuite.scala    |  10 +++---
 .../explain-results/function_count_typed.explain   |   2 ++
 .../explain-results/select_typed_1-arg.explain     |   3 ++
 .../query-tests/queries/function_count_typed.json  |  25 +++++++++++++
 .../queries/function_count_typed.proto.bin         | Bin 0 -> 66 bytes
 .../query-tests/queries/select_typed_1-arg.json    |  39 ++++++++++++++++++++
 .../queries/select_typed_1-arg.proto.bin           | Bin 0 -> 98 bytes
 12 files changed, 189 insertions(+), 15 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala
index 9af0096fc1c..c39d5c9757e 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala
@@ -22,6 +22,7 @@ import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.Expression.SortOrder.NullOrdering
 import org.apache.spark.connect.proto.Expression.SortOrder.SortDirection
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.connect.common.DataTypeProtoConverter
 import org.apache.spark.sql.expressions.Window
@@ -70,6 +71,17 @@ class Column private[sql] (private[sql] val expr: proto.Expression) extends Logg
 
   override def hashCode: Int = expr.hashCode()
 
+  /**
+   * Provides a type hint about the expected return value of this column. This information can be
+   * used by operations such as `select` on a [[Dataset]] to automatically convert the results
+   * into the correct JVM types.
+   * @since 3.4.0
+   */
+  def as[U: Encoder]: TypedColumn[Any, U] = {
+    val encoder = implicitly[Encoder[U]].asInstanceOf[AgnosticEncoder[U]]
+    new TypedColumn[Any, U](expr, encoder)
+  }
+
   /**
   * Extracts a value or values from a complex type. The following types of extraction are
    * supported:
@@ -1430,3 +1442,31 @@ class ColumnName(name: String) extends Column(name) {
    */
  def struct(structType: StructType): StructField = StructField(name, structType)
 }
+
+/**
+ * A [[Column]] where an [[Encoder]] has been given for the expected input and return type. To
+ * create a [[TypedColumn]], use the `as` function on a [[Column]].
+ *
+ * @tparam T
+ *   The input type expected for this expression. Can be `Any` if the expression is type checked
+ *   by the analyzer instead of the compiler (i.e. `expr("sum(...)")`).
+ * @tparam U
+ *   The output type of this column.
+ *
+ * @since 3.4.0
+ */
+class TypedColumn[-T, U] private[sql] (
+    expr: proto.Expression,
+    private[sql] val encoder: AgnosticEncoder[U])
+    extends Column(expr) {
+
+  /**
+   * Gives the [[TypedColumn]] a name (alias). If the current `TypedColumn` has metadata
+   * associated with it, this metadata will be propagated to the new column.
+   *
+   * @group expr_ops
+   * @since 3.4.0
+   */
+  override def name(alias: String): TypedColumn[T, U] =
+    new TypedColumn[T, U](super.name(alias).expr, encoder)
+}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index 73de35456fc..1015d61a9c2 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1012,6 +1012,31 @@ class Dataset[T] private[sql] (
     select(exprs.map(functions.expr): _*)
   }
 
+  /**
+   * Returns a new Dataset by computing the given [[Column]] expression for each element.
+   *
+   * {{{
+   *   val ds = Seq(1, 2, 3).toDS()
+   *   val newDS = ds.select(expr("value + 1").as[Int])
+   * }}}
+   *
+   * @group typedrel
+   * @since 3.4.0
+   */
+  def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = {
+    val encoder = c1.encoder
+    val expr = if (encoder.schema == encoder.dataType) {
+      functions.inline(functions.array(c1)).expr
+    } else {
+      c1.expr
+    }
+    sparkSession.newDataset(encoder) { builder =>
+      builder.getProjectBuilder
+        .setInput(plan.getRoot)
+        .addExpressions(expr)
+    }
+  }
+
   /**
    * Filters rows using the given condition.
    * {{{
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
index 94882087eee..386219a699c 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
@@ -27,6 +27,7 @@ import scala.reflect.runtime.universe.{typeTag, TypeTag}
 import com.google.protobuf.ByteString
 
 import org.apache.spark.connect.proto
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.PrimitiveLongEncoder
 import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils}
 import org.apache.spark.sql.connect.client.unsupported
 import org.apache.spark.sql.expressions.{ScalarUserDefinedFunction, UserDefinedFunction}
@@ -401,6 +402,15 @@ object functions {
    */
   def count(e: Column): Column = Column.fn("count", e)
 
+  /**
+   * Aggregate function: returns the number of items in a group.
+   *
+   * @group agg_funcs
+   * @since 3.4.0
+   */
+  def count(columnName: String): TypedColumn[Any, Long] =
+    count(Column(columnName)).as(PrimitiveLongEncoder)
+
   /**
    * Aggregate function: returns the number of distinct items in a group.
    *
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index debb314f8c3..3f00f7c9c36 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -30,7 +30,7 @@ import org.scalactic.TolerantNumerics
 import org.apache.spark.SPARK_VERSION
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.connect.client.util.{IntegrationTestUtils, RemoteSparkSession}
-import org.apache.spark.sql.functions.{aggregate, array, col, lit, rand, sequence, shuffle, transform, udf}
+import org.apache.spark.sql.functions.{aggregate, array, col, count, lit, rand, sequence, shuffle, struct, transform, udf}
 import org.apache.spark.sql.types._
 
 class ClientE2ETestSuite extends RemoteSparkSession {
@@ -412,16 +412,13 @@ class ClientE2ETestSuite extends RemoteSparkSession {
     }
   }
 
-  test("Dataset collect complex type") {
-    val result = spark
-      .range(3)
-      .select(
-        (col("id") / lit(10.0d)).as("b"),
-        col("id"),
-        lit("world").as("d"),
-        (col("id") % 2).cast("int").as("a"))
-      .as[MyType]
-      .collect()
+  private val generateMyTypeColumns = Seq(
+    (col("id") / lit(10.0d)).as("b"),
+    col("id"),
+    lit("world").as("d"),
+    (col("id") % 2).cast("int").as("a"))
+
+  private def validateMyTypeResult(result: Array[MyType]): Unit = {
     result.zipWithIndex.foreach { case (MyType(id, a, b), i) =>
       assert(id == i)
       assert(a == id % 2)
@@ -429,6 +426,27 @@ class ClientE2ETestSuite extends RemoteSparkSession {
     }
   }
 
+  test("Dataset collect complex type") {
+    val result = spark
+      .range(3)
+      .select(generateMyTypeColumns: _*)
+      .as[MyType]
+      .collect()
+    validateMyTypeResult(result)
+  }
+
+  test("Dataset typed select - simple column") {
+    val numRows = spark.range(1000).select(count("id")).first()
+    assert(numRows === 1000)
+  }
+
+  test("Dataset typed select - complex column") {
+    val ds = spark
+      .range(3)
+      .select(struct(generateMyTypeColumns: _*).as[MyType])
+    validateMyTypeResult(ds.collect())
+  }
+
   test("lambda functions") {
     // This test is mostly to validate lambda variables are properly resolved.
     val result = spark
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index 67ea148cb87..52e5d892012 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -30,6 +30,7 @@ import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
 import org.apache.spark.connect.proto
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{functions => fn}
+import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.connect.client.SparkConnectClient
 import org.apache.spark.sql.connect.client.util.ConnectFunSuite
 import org.apache.spark.sql.expressions.Window
@@ -287,6 +288,11 @@ class PlanGenerationTestSuite
     simple.select(fn.col("id"))
   }
 
+  test("select typed 1-arg") {
+    val encoder = ScalaReflection.encoderFor[(Long, Int)]
+    simple.select(fn.struct(fn.col("id"), fn.col("a")).as(encoder))
+  }
+
   test("limit") {
     simple.limit(10)
   }
@@ -876,6 +882,10 @@ class PlanGenerationTestSuite
     fn.count(fn.col("a"))
   }
 
+  test("function count typed") {
+    simple.select(fn.count("a"))
+  }
+
   functionTest("countDistinct") {
     fn.countDistinct("a", "g")
   }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CompatibilitySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CompatibilitySuite.scala
index bb480e0ee08..ccee3b550eb 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CompatibilitySuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CompatibilitySuite.scala
@@ -81,7 +81,8 @@ class CompatibilitySuite extends ConnectFunSuite {
       IncludeByName("org.apache.spark.sql.functions.*"),
       IncludeByName("org.apache.spark.sql.RelationalGroupedDataset.*"),
       IncludeByName("org.apache.spark.sql.SparkSession.*"),
-      IncludeByName("org.apache.spark.sql.RuntimeConfig.*"))
+      IncludeByName("org.apache.spark.sql.RuntimeConfig.*"),
+      IncludeByName("org.apache.spark.sql.TypedColumn.*"))
     val excludeRules = Seq(
       // Filter unsupported rules:
      // Note when muting errors for a method, checks on all overloading methods are also muted.
@@ -136,7 +137,6 @@ class CompatibilitySuite extends ConnectFunSuite {
      ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.unwrap_udt"),
      ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.udaf"),
      ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.broadcast"),
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.count"),
      ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.typedlit"),
      ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.typedLit"),
 
@@ -178,10 +178,12 @@ class CompatibilitySuite extends ConnectFunSuite {
      ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.clearDefaultSession"),
      ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.getActiveSession"),
      ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.getDefaultSession"),
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.range"),
 
      // RuntimeConfig
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.RuntimeConfig.this"))
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.RuntimeConfig.this"),
+
+      // TypedColumn
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.TypedColumn.this"))
     val problems = allProblems
       .filter { p =>
         includedRules.exists(rule => rule(p))
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_typed.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_typed.explain
new file mode 100644
index 00000000000..200513a1181
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_typed.explain
@@ -0,0 +1,2 @@
+Aggregate [count(a#0) AS count(a)#0L]
++- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/select_typed_1-arg.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/select_typed_1-arg.explain
new file mode 100644
index 00000000000..64017a5e073
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/select_typed_1-arg.explain
@@ -0,0 +1,3 @@
+Project [id#0L, a#0]
++- Generate inline(array(struct(id, id#0L, a, a#0))), false, [id#0L, a#0]
+   +- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_count_typed.json b/connector/connect/common/src/test/resources/query-tests/queries/function_count_typed.json
new file mode 100644
index 00000000000..1c5df90b79c
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/function_count_typed.json
@@ -0,0 +1,25 @@
+{
+  "common": {
+    "planId": "1"
+  },
+  "project": {
+    "input": {
+      "common": {
+        "planId": "0"
+      },
+      "localRelation": {
+        "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
+      }
+    },
+    "expressions": [{
+      "unresolvedFunction": {
+        "functionName": "count",
+        "arguments": [{
+          "unresolvedAttribute": {
+            "unparsedIdentifier": "a"
+          }
+        }]
+      }
+    }]
+  }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_count_typed.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_count_typed.proto.bin
new file mode 100644
index 00000000000..44b613eb40c
Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/function_count_typed.proto.bin differ
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/select_typed_1-arg.json b/connector/connect/common/src/test/resources/query-tests/queries/select_typed_1-arg.json
new file mode 100644
index 00000000000..90ef62c5f41
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/select_typed_1-arg.json
@@ -0,0 +1,39 @@
+{
+  "common": {
+    "planId": "1"
+  },
+  "project": {
+    "input": {
+      "common": {
+        "planId": "0"
+      },
+      "localRelation": {
+        "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
+      }
+    },
+    "expressions": [{
+      "unresolvedFunction": {
+        "functionName": "inline",
+        "arguments": [{
+          "unresolvedFunction": {
+            "functionName": "array",
+            "arguments": [{
+              "unresolvedFunction": {
+                "functionName": "struct",
+                "arguments": [{
+                  "unresolvedAttribute": {
+                    "unparsedIdentifier": "id"
+                  }
+                }, {
+                  "unresolvedAttribute": {
+                    "unparsedIdentifier": "a"
+                  }
+                }]
+              }
+            }]
+          }
+        }]
+      }
+    }]
+  }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/select_typed_1-arg.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/select_typed_1-arg.proto.bin
new file mode 100644
index 00000000000..2273a16d4e6
Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/select_typed_1-arg.proto.bin differ

