This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 24f0c45dc11 [SPARK-42560][CONNECT] Add ColumnName class
24f0c45dc11 is described below

commit 24f0c45dc11eb7ac1ee43ebd630cdb325da30326
Author: Herman van Hovell <her...@databricks.com>
AuthorDate: Sun Feb 26 22:38:10 2023 -0400

    [SPARK-42560][CONNECT] Add ColumnName class
    
    ### What changes were proposed in this pull request?
    This PR adds the ColumnName class for the Spark Connect Scala Client. This is
    a stepping stone towards implementing SQLImplicits.
    
    ### Why are the changes needed?
    API parity with the existing (non-Connect) Scala API.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. It adds the new `ColumnName` class.
    
    ### How was this patch tested?
    Added tests to the existing `ColumnTestSuite`.
    
    Closes #40179 from hvanhovell/SPARK-42560.
    
    Authored-by: Herman van Hovell <her...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../main/scala/org/apache/spark/sql/Column.scala   | 134 ++++++++++++++++++++-
 .../org/apache/spark/sql/ColumnTestSuite.scala     |  34 +++++-
 .../sql/connect/client/CompatibilitySuite.scala    |   2 +
 3 files changed, 166 insertions(+), 4 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala
index 734535c2e14..9af0096fc1c 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.connect.common.DataTypeProtoConverter
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions.lit
-import org.apache.spark.sql.types.{DataType, Metadata}
+import org.apache.spark.sql.types._
 
 /**
  * A column that will be computed based on the data in a `DataFrame`.
@@ -51,6 +51,12 @@ import org.apache.spark.sql.types.{DataType, Metadata}
  */
 class Column private[sql] (private[sql] val expr: proto.Expression) extends Logging {
 
+  private[sql] def this(name: String, planId: Option[Long]) =
+    this(Column.nameToExpression(name, planId))
+
+  private[sql] def this(name: String) =
+    this(name, None)
+
   private def fn(name: String): Column = Column.fn(name, this)
   private def fn(name: String, other: Column): Column = Column.fn(name, this, other)
   private def fn(name: String, other: Any): Column = Column.fn(name, this, lit(other))
@@ -1270,9 +1276,12 @@ class Column private[sql] (private[sql] val expr: proto.Expression) extends Logg
 
 private[sql] object Column {
 
-  def apply(name: String): Column = Column(name, None)
+  def apply(name: String): Column = new Column(name)
+
+  def apply(name: String, planId: Option[Long]): Column = new Column(name, planId)
 
-  def apply(name: String, planId: Option[Long]): Column = Column { builder =>
+  def nameToExpression(name: String, planId: Option[Long] = None): proto.Expression = {
+    val builder = proto.Expression.newBuilder()
     name match {
       case "*" =>
         builder.getUnresolvedStarBuilder
@@ -1282,6 +1291,7 @@ private[sql] object Column {
         val attributeBuilder = builder.getUnresolvedAttributeBuilder.setUnparsedIdentifier(name)
         planId.foreach(attributeBuilder.setPlanId)
     }
+    builder.build()
   }
 
   private[sql] def apply(f: proto.Expression.Builder => Unit): Column = {
@@ -1302,3 +1312,121 @@ private[sql] object Column {
         .addAllArguments(inputs.map(_.expr).asJava)
   }
 }
+
+/**
+ * A convenient class used for constructing schema.
+ *
+ * @since 3.4.0
+ */
+class ColumnName(name: String) extends Column(name) {
+
+  /**
+   * Creates a new `StructField` of type boolean.
+   * @since 3.4.0
+   */
+  def boolean: StructField = StructField(name, BooleanType)
+
+  /**
+   * Creates a new `StructField` of type byte.
+   * @since 3.4.0
+   */
+  def byte: StructField = StructField(name, ByteType)
+
+  /**
+   * Creates a new `StructField` of type short.
+   * @since 3.4.0
+   */
+  def short: StructField = StructField(name, ShortType)
+
+  /**
+   * Creates a new `StructField` of type int.
+   * @since 3.4.0
+   */
+  def int: StructField = StructField(name, IntegerType)
+
+  /**
+   * Creates a new `StructField` of type long.
+   * @since 3.4.0
+   */
+  def long: StructField = StructField(name, LongType)
+
+  /**
+   * Creates a new `StructField` of type float.
+   * @since 3.4.0
+   */
+  def float: StructField = StructField(name, FloatType)
+
+  /**
+   * Creates a new `StructField` of type double.
+   * @since 3.4.0
+   */
+  def double: StructField = StructField(name, DoubleType)
+
+  /**
+   * Creates a new `StructField` of type string.
+   * @since 3.4.0
+   */
+  def string: StructField = StructField(name, StringType)
+
+  /**
+   * Creates a new `StructField` of type date.
+   * @since 3.4.0
+   */
+  def date: StructField = StructField(name, DateType)
+
+  /**
+   * Creates a new `StructField` of type decimal.
+   * @since 3.4.0
+   */
+  def decimal: StructField = StructField(name, DecimalType.USER_DEFAULT)
+
+  /**
+   * Creates a new `StructField` of type decimal.
+   * @since 3.4.0
+   */
+  def decimal(precision: Int, scale: Int): StructField =
+    StructField(name, DecimalType(precision, scale))
+
+  /**
+   * Creates a new `StructField` of type timestamp.
+   * @since 3.4.0
+   */
+  def timestamp: StructField = StructField(name, TimestampType)
+
+  /**
+   * Creates a new `StructField` of type binary.
+   * @since 3.4.0
+   */
+  def binary: StructField = StructField(name, BinaryType)
+
+  /**
+   * Creates a new `StructField` of type array.
+   * @since 3.4.0
+   */
+  def array(dataType: DataType): StructField = StructField(name, ArrayType(dataType))
+
+  /**
+   * Creates a new `StructField` of type map.
+   * @since 3.4.0
+   */
+  def map(keyType: DataType, valueType: DataType): StructField =
+    map(MapType(keyType, valueType))
+
+  /**
+   * Creates a new `StructField` of type map.
+   * @since 3.4.0
+   */
+  def map(mapType: MapType): StructField = StructField(name, mapType)
+
+  /**
+   * Creates a new `StructField` of type struct.
+   * @since 3.4.0
+   */
+  def struct(fields: StructField*): StructField = struct(StructType(fields))
+
+  /**
+   * Creates a new `StructField` of type struct.
+   * @since 3.4.0
+   */
+  def struct(structType: StructType): StructField = StructField(name, structType)
+}
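
For illustration only (not part of the patch), a minimal sketch of how the ColumnName class added above can be used to build a schema; the field names and types here are made-up examples:

    import org.apache.spark.sql.ColumnName
    import org.apache.spark.sql.types.{StructField, StructType}

    // Each ColumnName method yields a StructField carrying the column's name and type.
    val idField: StructField = new ColumnName("id").long
    val nameField: StructField = new ColumnName("name").string
    val priceField: StructField = new ColumnName("price").decimal(20, 10)

    // The fields can then be assembled into a schema.
    val schema: StructType = StructType(Seq(idField, nameField, priceField))

Since ColumnName extends Column, the same values can also be used wherever a Column is expected.
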
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala
index 1b4926ac3d0..0d361fe1007 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala
@@ -22,7 +22,7 @@ import scala.collection.JavaConverters._
 
 import org.apache.spark.sql.{functions => fn}
 import org.apache.spark.sql.connect.client.util.ConnectFunSuite
-import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.types._
 
 /**
  * Tests for client local Column behavior.
@@ -176,4 +176,36 @@ class ColumnTestSuite extends ConnectFunSuite {
       assert(explain1.contains(fragment))
     }
   }
+
+  private def testColName(dataType: DataType, f: ColumnName => StructField): Unit = {
+    test("ColumnName " + dataType.catalogString) {
+      val actual = f(new ColumnName("col"))
+      val expected = StructField("col", dataType)
+      assert(actual === expected)
+    }
+  }
+
+  testColName(BooleanType, _.boolean)
+  testColName(ByteType, _.byte)
+  testColName(ShortType, _.short)
+  testColName(IntegerType, _.int)
+  testColName(LongType, _.long)
+  testColName(FloatType, _.float)
+  testColName(DoubleType, _.double)
+  testColName(DecimalType.USER_DEFAULT, _.decimal)
+  testColName(DecimalType(20, 10), _.decimal(20, 10))
+  testColName(DateType, _.date)
+  testColName(TimestampType, _.timestamp)
+  testColName(StringType, _.string)
+  testColName(BinaryType, _.binary)
+  testColName(ArrayType(IntegerType), _.array(IntegerType))
+
+  private val mapType = MapType(StringType, StringType)
+  testColName(mapType, _.map(mapType))
+  testColName(MapType(StringType, IntegerType), _.map(StringType, IntegerType))
+
+  private val structType1 = new StructType().add("a", "int").add("b", "string")
+  private val structType2 = structType1.add("c", "binary")
+  testColName(structType1, _.struct(structType1))
+  testColName(structType2, _.struct(structType2.fields: _*))
 }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CompatibilitySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CompatibilitySuite.scala
index 35cecaa20d7..3f3ee7d04d4 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CompatibilitySuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CompatibilitySuite.scala
@@ -72,6 +72,7 @@ class CompatibilitySuite extends ConnectFunSuite {
     val allProblems = mima.collectProblems(sqlJar, clientJar, List.empty)
     val includedRules = Seq(
       IncludeByName("org.apache.spark.sql.Column.*"),
+      IncludeByName("org.apache.spark.sql.ColumnName.*"),
       IncludeByName("org.apache.spark.sql.DataFrame.*"),
       IncludeByName("org.apache.spark.sql.DataFrameReader.*"),
       IncludeByName("org.apache.spark.sql.DataFrameWriter.*"),
@@ -156,6 +157,7 @@ class CompatibilitySuite extends ConnectFunSuite {
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.typedLit"),
 
       // RelationalGroupedDataset
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.RelationalGroupedDataset.apply"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.RelationalGroupedDataset.as"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.RelationalGroupedDataset.pivot"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.RelationalGroupedDataset.this"),


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
