This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 24f0c45dc11 [SPARK-42560][CONNECT] Add ColumnName class
24f0c45dc11 is described below

commit 24f0c45dc11eb7ac1ee43ebd630cdb325da30326
Author: Herman van Hovell <her...@databricks.com>
AuthorDate: Sun Feb 26 22:38:10 2023 -0400

    [SPARK-42560][CONNECT] Add ColumnName class

    ### What changes were proposed in this pull request?
    This PR adds the ColumnName class for the Spark Connect Scala Client.
    This is a stepping stone towards implementing SQLImplicits.

    ### Why are the changes needed?
    API parity with the existing SQL API.

    ### Does this PR introduce _any_ user-facing change?
    Yes. It adds a new API.

    ### How was this patch tested?
    Added tests to `ColumnTestSuite`.

    Closes #40179 from hvanhovell/SPARK-42560.

    Authored-by: Herman van Hovell <her...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../main/scala/org/apache/spark/sql/Column.scala   | 134 ++++++++++++++++++++-
 .../org/apache/spark/sql/ColumnTestSuite.scala     |  34 +++++-
 .../sql/connect/client/CompatibilitySuite.scala    |   2 +
 3 files changed, 166 insertions(+), 4 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala
index 734535c2e14..9af0096fc1c 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.connect.common.DataTypeProtoConverter
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions.lit
-import org.apache.spark.sql.types.{DataType, Metadata}
+import org.apache.spark.sql.types._
 
 /**
  * A column that will be computed based on the data in a `DataFrame`.
@@ -51,6 +51,12 @@ import org.apache.spark.sql.types.{DataType, Metadata}
  */
 class Column private[sql] (private[sql] val expr: proto.Expression) extends Logging {
 
+  private[sql] def this(name: String, planId: Option[Long]) =
+    this(Column.nameToExpression(name, planId))
+
+  private[sql] def this(name: String) =
+    this(name, None)
+
   private def fn(name: String): Column = Column.fn(name, this)
   private def fn(name: String, other: Column): Column = Column.fn(name, this, other)
   private def fn(name: String, other: Any): Column = Column.fn(name, this, lit(other))
@@ -1270,9 +1276,12 @@ class Column private[sql] (private[sql] val expr: proto.Expression) extends Logg
 
 private[sql] object Column {
 
-  def apply(name: String): Column = Column(name, None)
+  def apply(name: String): Column = new Column(name)
+
+  def apply(name: String, planId: Option[Long]): Column = new Column(name, planId)
 
-  def apply(name: String, planId: Option[Long]): Column = Column { builder =>
+  def nameToExpression(name: String, planId: Option[Long] = None): proto.Expression = {
+    val builder = proto.Expression.newBuilder()
     name match {
       case "*" =>
         builder.getUnresolvedStarBuilder
@@ -1282,6 +1291,7 @@ private[sql] object Column {
         val attributeBuilder = builder.getUnresolvedAttributeBuilder.setUnparsedIdentifier(name)
         planId.foreach(attributeBuilder.setPlanId)
     }
+    builder.build()
   }
 
   private[sql] def apply(f: proto.Expression.Builder => Unit): Column = {
@@ -1302,3 +1312,121 @@ private[sql] object Column {
       .addAllArguments(inputs.map(_.expr).asJava)
   }
 }
+
+/**
+ * A convenient class used for constructing schema.
+ *
+ * @since 3.4.0
+ */
+class ColumnName(name: String) extends Column(name) {
+
+  /**
+   * Creates a new `StructField` of type boolean.
+   * @since 3.4.0
+   */
+  def boolean: StructField = StructField(name, BooleanType)
+
+  /**
+   * Creates a new `StructField` of type byte.
+   * @since 3.4.0
+   */
+  def byte: StructField = StructField(name, ByteType)
+
+  /**
+   * Creates a new `StructField` of type short.
+   * @since 3.4.0
+   */
+  def short: StructField = StructField(name, ShortType)
+
+  /**
+   * Creates a new `StructField` of type int.
+   * @since 3.4.0
+   */
+  def int: StructField = StructField(name, IntegerType)
+
+  /**
+   * Creates a new `StructField` of type long.
+   * @since 3.4.0
+   */
+  def long: StructField = StructField(name, LongType)
+
+  /**
+   * Creates a new `StructField` of type float.
+   * @since 3.4.0
+   */
+  def float: StructField = StructField(name, FloatType)
+
+  /**
+   * Creates a new `StructField` of type double.
+   * @since 3.4.0
+   */
+  def double: StructField = StructField(name, DoubleType)
+
+  /**
+   * Creates a new `StructField` of type string.
+   * @since 3.4.0
+   */
+  def string: StructField = StructField(name, StringType)
+
+  /**
+   * Creates a new `StructField` of type date.
+   * @since 3.4.0
+   */
+  def date: StructField = StructField(name, DateType)
+
+  /**
+   * Creates a new `StructField` of type decimal.
+   * @since 3.4.0
+   */
+  def decimal: StructField = StructField(name, DecimalType.USER_DEFAULT)
+
+  /**
+   * Creates a new `StructField` of type decimal.
+   * @since 3.4.0
+   */
+  def decimal(precision: Int, scale: Int): StructField =
+    StructField(name, DecimalType(precision, scale))
+
+  /**
+   * Creates a new `StructField` of type timestamp.
+   * @since 3.4.0
+   */
+  def timestamp: StructField = StructField(name, TimestampType)
+
+  /**
+   * Creates a new `StructField` of type binary.
+   * @since 3.4.0
+   */
+  def binary: StructField = StructField(name, BinaryType)
+
+  /**
+   * Creates a new `StructField` of type array.
+   * @since 3.4.0
+   */
+  def array(dataType: DataType): StructField = StructField(name, ArrayType(dataType))
+
+  /**
+   * Creates a new `StructField` of type map.
+   * @since 3.4.0
+   */
+  def map(keyType: DataType, valueType: DataType): StructField =
+    map(MapType(keyType, valueType))
+
+  /**
+   * Creates a new `StructField` of type map.
+   * @since 3.4.0
+   */
+  def map(mapType: MapType): StructField = StructField(name, mapType)
+
+  /**
+   * Creates a new `StructField` of type struct.
+   * @since 3.4.0
+   */
+  def struct(fields: StructField*): StructField = struct(StructType(fields))
+
+  /**
+   * Creates a new `StructField` of type struct.
+   * @since 3.4.0
+   */
+  def struct(structType: StructType): StructField = StructField(name, structType)
+}
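As an aside for readers of this commit: a minimal sketch of how the new ColumnName
class can be used to assemble a schema, relying only on the public methods added
above (the column names here are made up for illustration). The commit message
calls this a stepping stone towards SQLImplicits, which would presumably provide
the usual $"..." shorthand for constructing a ColumnName; until then the class is
instantiated directly:

    import org.apache.spark.sql.ColumnName
    import org.apache.spark.sql.types._

    // Each ColumnName method pairs the column's name with a DataType and
    // returns a StructField, so a schema can be declared field by field:
    val schema = StructType(
      new ColumnName("id").long ::
        new ColumnName("name").string ::
        new ColumnName("price").decimal(20, 10) ::
        new ColumnName("tags").array(StringType) :: Nil)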
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala
index 1b4926ac3d0..0d361fe1007 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala
@@ -22,7 +22,7 @@ import scala.collection.JavaConverters._
 
 import org.apache.spark.sql.{functions => fn}
 import org.apache.spark.sql.connect.client.util.ConnectFunSuite
-import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.types._
 
 /**
  * Tests for client local Column behavior.
@@ -176,4 +176,36 @@ class ColumnTestSuite extends ConnectFunSuite {
       assert(explain1.contains(fragment))
     }
   }
+
+  private def testColName(dataType: DataType, f: ColumnName => StructField): Unit = {
+    test("ColumnName " + dataType.catalogString) {
+      val actual = f(new ColumnName("col"))
+      val expected = StructField("col", dataType)
+      assert(actual === expected)
+    }
+  }
+
+  testColName(BooleanType, _.boolean)
+  testColName(ByteType, _.byte)
+  testColName(ShortType, _.short)
+  testColName(IntegerType, _.int)
+  testColName(LongType, _.long)
+  testColName(FloatType, _.float)
+  testColName(DoubleType, _.double)
+  testColName(DecimalType.USER_DEFAULT, _.decimal)
+  testColName(DecimalType(20, 10), _.decimal(20, 10))
+  testColName(DateType, _.date)
+  testColName(TimestampType, _.timestamp)
+  testColName(StringType, _.string)
+  testColName(BinaryType, _.binary)
+  testColName(ArrayType(IntegerType), _.array(IntegerType))
+
+  private val mapType = MapType(StringType, StringType)
+  testColName(mapType, _.map(mapType))
+  testColName(MapType(StringType, IntegerType), _.map(StringType, IntegerType))
+
+  private val structType1 = new StructType().add("a", "int").add("b", "string")
+  private val structType2 = structType1.add("c", "binary")
+  testColName(structType1, _.struct(structType1))
+  testColName(structType2, _.struct(structType2.fields: _*))
 }
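Another aside: the table-driven testColName helper above registers one test per
DataType, each named after the type's catalogString. Unrolled by hand, every
generated test amounts to a single equality check; for instance (illustrative,
mirroring the suite's assertions rather than quoting the diff):

    // Hand-written equivalents of two of the generated checks:
    assert(new ColumnName("col").decimal(20, 10) ===
      StructField("col", DecimalType(20, 10)))
    assert(new ColumnName("col").map(StringType, IntegerType) ===
      StructField("col", MapType(StringType, IntegerType)))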
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CompatibilitySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CompatibilitySuite.scala
index 35cecaa20d7..3f3ee7d04d4 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CompatibilitySuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CompatibilitySuite.scala
@@ -72,6 +72,7 @@ class CompatibilitySuite extends ConnectFunSuite {
     val allProblems = mima.collectProblems(sqlJar, clientJar, List.empty)
     val includedRules = Seq(
       IncludeByName("org.apache.spark.sql.Column.*"),
+      IncludeByName("org.apache.spark.sql.ColumnName.*"),
       IncludeByName("org.apache.spark.sql.DataFrame.*"),
       IncludeByName("org.apache.spark.sql.DataFrameReader.*"),
       IncludeByName("org.apache.spark.sql.DataFrameWriter.*"),
@@ -156,6 +157,7 @@ class CompatibilitySuite extends ConnectFunSuite {
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.typedLit"),
 
       // RelationalGroupedDataset
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.RelationalGroupedDataset.apply"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.RelationalGroupedDataset.as"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.RelationalGroupedDataset.pivot"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.RelationalGroupedDataset.this"),

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org