This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new b34178781fdd [SPARK-47539][SQL] Make the return value of method `castToString` be `Any => UTF8String` b34178781fdd is described below commit b34178781fdd67dfcfc53e5a00eea0737ee01620 Author: panbingkun <panbing...@baidu.com> AuthorDate: Mon Mar 25 17:26:50 2024 +0900 [SPARK-47539][SQL] Make the return value of method `castToString` be `Any => UTF8String` ### What changes were proposed in this pull request? This PR: - Changes the signature of the method `castToString` in `ToStringBase` from `castToString(from: DataType): Any => Any` to `castToString(from: DataType): Any => UTF8String`. - Adds unit tests (UT) for `ToPrettyString` to improve test coverage. ### Why are the changes needed? - Having the function returned by `castToString` produce a `UTF8String` instead of `Any` is more intuitive and removes the need for `asInstanceOf[UTF8String]` casts at the call sites. - Currently, `ToPrettyString` has no corresponding unit tests. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? - Added new unit tests. - Passed GitHub Actions (GA). ### Was this patch authored or co-authored using generative AI tooling? No. Closes #45688 from panbingkun/ToPrettyString_improve. 
Authored-by: panbingkun <panbing...@baidu.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../sql/catalyst/expressions/ToPrettyString.scala | 2 +- .../sql/catalyst/expressions/ToStringBase.scala | 24 ++-- .../catalyst/expressions/ToPrettyStringSuite.scala | 128 +++++++++++++++++++++ 3 files changed, 141 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyString.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyString.scala index aea704d4b788..8db08dbbcb81 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyString.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyString.scala @@ -51,7 +51,7 @@ case class ToPrettyString(child: Expression, timeZoneId: Option[String] = None) override protected def useHexFormatForBinary: Boolean = true - private[this] lazy val castFunc: Any => Any = castToString(child.dataType) + private[this] lazy val castFunc: Any => UTF8String = castToString(child.dataType) override def eval(input: InternalRow): Any = { val v = child.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala index 18b64fd21338..4f35072c4fc7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala @@ -47,10 +47,11 @@ trait ToStringBase { self: UnaryExpression with TimeZoneAwareExpression => protected def useHexFormatForBinary: Boolean // Makes the function accept Any type input by doing `asInstanceOf[T]`. 
- @inline private def acceptAny[T](func: T => Any): Any => Any = i => func(i.asInstanceOf[T]) + @inline private def acceptAny[T](func: T => UTF8String): Any => UTF8String = + i => func(i.asInstanceOf[T]) // Returns a function to convert a value to pretty string. The function assumes input is not null. - protected final def castToString(from: DataType): Any => Any = from match { + protected final def castToString(from: DataType): Any => UTF8String = from match { case CalendarIntervalType => acceptAny[CalendarInterval](i => UTF8String.fromString(i.toString)) case BinaryType if useHexFormatForBinary => @@ -72,7 +73,7 @@ trait ToStringBase { self: UnaryExpression with TimeZoneAwareExpression => if (array.isNullAt(0)) { if (nullString.nonEmpty) builder.append(nullString) } else { - builder.append(toUTF8String(array.get(0, et)).asInstanceOf[UTF8String]) + builder.append(toUTF8String(array.get(0, et))) } var i = 1 while (i < array.numElements()) { @@ -81,7 +82,7 @@ trait ToStringBase { self: UnaryExpression with TimeZoneAwareExpression => if (nullString.nonEmpty) builder.append(" " + nullString) } else { builder.append(" ") - builder.append(toUTF8String(array.get(i, et)).asInstanceOf[UTF8String]) + builder.append(toUTF8String(array.get(i, et))) } i += 1 } @@ -98,25 +99,24 @@ trait ToStringBase { self: UnaryExpression with TimeZoneAwareExpression => val valueArray = map.valueArray() val keyToUTF8String = castToString(kt) val valueToUTF8String = castToString(vt) - builder.append(keyToUTF8String(keyArray.get(0, kt)).asInstanceOf[UTF8String]) + builder.append(keyToUTF8String(keyArray.get(0, kt))) builder.append(" ->") if (valueArray.isNullAt(0)) { if (nullString.nonEmpty) builder.append(" " + nullString) } else { builder.append(" ") - builder.append(valueToUTF8String(valueArray.get(0, vt)).asInstanceOf[UTF8String]) + builder.append(valueToUTF8String(valueArray.get(0, vt))) } var i = 1 while (i < map.numElements()) { builder.append(", ") - 
builder.append(keyToUTF8String(keyArray.get(i, kt)).asInstanceOf[UTF8String]) + builder.append(keyToUTF8String(keyArray.get(i, kt))) builder.append(" ->") if (valueArray.isNullAt(i)) { if (nullString.nonEmpty) builder.append(" " + nullString) } else { builder.append(" ") - builder.append(valueToUTF8String(valueArray.get(i, vt)) - .asInstanceOf[UTF8String]) + builder.append(valueToUTF8String(valueArray.get(i, vt))) } i += 1 } @@ -134,7 +134,7 @@ trait ToStringBase { self: UnaryExpression with TimeZoneAwareExpression => if (row.isNullAt(0)) { if (nullString.nonEmpty) builder.append(nullString) } else { - builder.append(toUTF8StringFuncs(0)(row.get(0, st(0))).asInstanceOf[UTF8String]) + builder.append(toUTF8StringFuncs(0)(row.get(0, st(0)))) } var i = 1 while (i < row.numFields) { @@ -143,7 +143,7 @@ trait ToStringBase { self: UnaryExpression with TimeZoneAwareExpression => if (nullString.nonEmpty) builder.append(" " + nullString) } else { builder.append(" ") - builder.append(toUTF8StringFuncs(i)(row.get(i, st(i))).asInstanceOf[UTF8String]) + builder.append(toUTF8StringFuncs(i)(row.get(i, st(i)))) } i += 1 } @@ -162,7 +162,7 @@ trait ToStringBase { self: UnaryExpression with TimeZoneAwareExpression => IntervalUtils.toDayTimeIntervalString(i, ANSI_STYLE, startField, endField))) case _: DecimalType if useDecimalPlainString => acceptAny[Decimal](d => UTF8String.fromString(d.toPlainString)) - case _: StringType => identity + case _: StringType => acceptAny[UTF8String](identity[UTF8String]) case _ => o => UTF8String.fromString(o.toString) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyStringSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyStringSuite.scala new file mode 100644 index 000000000000..4e043e72fad3 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyStringSuite.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.UTC_OPT +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.{UTF8String, VariantVal} + +class ToPrettyStringSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("CalendarInterval as pretty strings") { + checkEvaluation( + ToPrettyString(Cast(Literal("interval -3 month 1 day 7 hours"), CalendarIntervalType)), + "-3 months 1 days 7 hours") + } + + test("Binary as pretty strings") { + checkEvaluation(ToPrettyString(Cast(Literal("abcdef"), BinaryType)), "[61 62 63 64 65 66]") + } + + test("Date as pretty strings") { + checkEvaluation(ToPrettyString(Cast(Literal("1980-12-17"), DateType, UTC_OPT)), "1980-12-17") + } + + test("Timestamp as pretty strings") { + checkEvaluation( + ToPrettyString(Cast(Literal("2012-11-30 09:19:00"), TimestampType, UTC_OPT)), + "2012-11-30 01:19:00") + } + + test("TimestampNTZ as pretty strings") { + checkEvaluation(ToPrettyString(Literal(1L, TimestampNTZType)), "1970-01-01 00:00:00.000001") + } + + test("Array as pretty strings") { + checkEvaluation(ToPrettyString(Literal.create(Array(1, 2, 3, 4, 
5))), "[1, 2, 3, 4, 5]") + } + + test("Map as pretty strings") { + checkEvaluation( + ToPrettyString(Literal.create(Map(1 -> "a", 2 -> "b", 3 -> "c"))), + "{1 -> a, 2 -> b, 3 -> c}") + } + + test("Struct as pretty strings") { + checkEvaluation(ToPrettyString(Literal.create((1, "a", 0.1))), "{1, a, 0.1}") + checkEvaluation( + ToPrettyString(Literal.create(Tuple2[String, String](null, null))), + "{NULL, NULL}" + ) + } + + test("YearMonthInterval as pretty strings") { + checkEvaluation( + ToPrettyString(Cast(Literal("INTERVAL '1-0' YEAR TO MONTH"), YearMonthIntervalType())), + "INTERVAL '1-0' YEAR TO MONTH") + } + + test("DayTimeInterval as pretty strings") { + checkEvaluation( + ToPrettyString(Cast(Literal("INTERVAL '1 2:03:04' DAY TO SECOND"), DayTimeIntervalType())), + "INTERVAL '1 02:03:04' DAY TO SECOND") + } + + test("Decimal as pretty strings") { + checkEvaluation( + ToPrettyString(Cast(Literal(1234.65), DecimalType(6, 2))), "1234.65") + } + + test("String as pretty strings") { + checkEvaluation(ToPrettyString(Literal("s")), "s") + } + + test("Char as pretty strings") { + checkEvaluation(ToPrettyString(Literal.create('a', CharType(5))), "a") + } + + test("Byte as pretty strings") { + checkEvaluation(ToPrettyString(Cast(Literal(8), ByteType)), "8") + } + + test("Short as pretty strings") { + checkEvaluation(ToPrettyString(Cast(Literal(8), ShortType)), "8") + } + + test("Int as pretty strings") { + checkEvaluation(ToPrettyString(Literal(1)), "1") + } + + test("Long as pretty strings") { + checkEvaluation(ToPrettyString(Literal(1L)), "1") + } + + test("Float as pretty strings") { + checkEvaluation(ToPrettyString(Cast(Literal(8), FloatType)), "8.0") + } + + test("Double as pretty strings") { + checkEvaluation(ToPrettyString(Cast(Literal(8), DoubleType)), "8.0") + } + + test("Boolean as pretty strings") { + checkEvaluation(ToPrettyString(Literal(false)), "false") + checkEvaluation(ToPrettyString(Literal(true)), "true") + } + + test("Variant as pretty strings") { + 
checkEvaluation( + ToPrettyString(Literal(new VariantVal(Array[Byte](1, 2, 3), Array[Byte](4, 5)))), + UTF8String.fromBytes(Array[Byte](1, 2, 3)).toString) + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org