This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b34178781fdd [SPARK-47539][SQL] Make the return value of method `castToString` be `Any => UTF8String`
b34178781fdd is described below

commit b34178781fdd67dfcfc53e5a00eea0737ee01620
Author: panbingkun <panbing...@baidu.com>
AuthorDate: Mon Mar 25 17:26:50 2024 +0900

    [SPARK-47539][SQL] Make the return value of method `castToString` be `Any => UTF8String`
    
    ### What changes were proposed in this pull request?
    This PR aims to:
    - change the return type of `castToString(from: DataType)` in `ToStringBase`
      from `Any => Any` to `Any => UTF8String`;
    - add unit tests for `ToPrettyString` to improve its test coverage.
    
    ### Why are the changes needed?
    - Returning `UTF8String` instead of `Any` makes the method's contract explicit
      and removes the `asInstanceOf[UTF8String]` casts at its call sites.
    - `ToPrettyString` currently has no dedicated unit tests.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    - Added new unit tests (`ToPrettyStringSuite`).
    - Passed GitHub Actions (GA) CI.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #45688 from panbingkun/ToPrettyString_improve.
    
    Authored-by: panbingkun <panbing...@baidu.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../sql/catalyst/expressions/ToPrettyString.scala  |   2 +-
 .../sql/catalyst/expressions/ToStringBase.scala    |  24 ++--
 .../catalyst/expressions/ToPrettyStringSuite.scala | 128 +++++++++++++++++++++
 3 files changed, 141 insertions(+), 13 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyString.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyString.scala
index aea704d4b788..8db08dbbcb81 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyString.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyString.scala
@@ -51,7 +51,7 @@ case class ToPrettyString(child: Expression, timeZoneId: 
Option[String] = None)
 
   override protected def useHexFormatForBinary: Boolean = true
 
-  private[this] lazy val castFunc: Any => Any = castToString(child.dataType)
+  private[this] lazy val castFunc: Any => UTF8String = 
castToString(child.dataType)
 
   override def eval(input: InternalRow): Any = {
     val v = child.eval(input)
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala
index 18b64fd21338..4f35072c4fc7 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala
@@ -47,10 +47,11 @@ trait ToStringBase { self: UnaryExpression with 
TimeZoneAwareExpression =>
   protected def useHexFormatForBinary: Boolean
 
   // Makes the function accept Any type input by doing `asInstanceOf[T]`.
-  @inline private def acceptAny[T](func: T => Any): Any => Any = i => 
func(i.asInstanceOf[T])
+  @inline private def acceptAny[T](func: T => UTF8String): Any => UTF8String =
+    i => func(i.asInstanceOf[T])
 
   // Returns a function to convert a value to pretty string. The function 
assumes input is not null.
-  protected final def castToString(from: DataType): Any => Any = from match {
+  protected final def castToString(from: DataType): Any => UTF8String = from 
match {
     case CalendarIntervalType =>
       acceptAny[CalendarInterval](i => UTF8String.fromString(i.toString))
     case BinaryType if useHexFormatForBinary =>
@@ -72,7 +73,7 @@ trait ToStringBase { self: UnaryExpression with 
TimeZoneAwareExpression =>
           if (array.isNullAt(0)) {
             if (nullString.nonEmpty) builder.append(nullString)
           } else {
-            builder.append(toUTF8String(array.get(0, 
et)).asInstanceOf[UTF8String])
+            builder.append(toUTF8String(array.get(0, et)))
           }
           var i = 1
           while (i < array.numElements()) {
@@ -81,7 +82,7 @@ trait ToStringBase { self: UnaryExpression with 
TimeZoneAwareExpression =>
               if (nullString.nonEmpty) builder.append(" " + nullString)
             } else {
               builder.append(" ")
-              builder.append(toUTF8String(array.get(i, 
et)).asInstanceOf[UTF8String])
+              builder.append(toUTF8String(array.get(i, et)))
             }
             i += 1
           }
@@ -98,25 +99,24 @@ trait ToStringBase { self: UnaryExpression with 
TimeZoneAwareExpression =>
           val valueArray = map.valueArray()
           val keyToUTF8String = castToString(kt)
           val valueToUTF8String = castToString(vt)
-          builder.append(keyToUTF8String(keyArray.get(0, 
kt)).asInstanceOf[UTF8String])
+          builder.append(keyToUTF8String(keyArray.get(0, kt)))
           builder.append(" ->")
           if (valueArray.isNullAt(0)) {
             if (nullString.nonEmpty) builder.append(" " + nullString)
           } else {
             builder.append(" ")
-            builder.append(valueToUTF8String(valueArray.get(0, 
vt)).asInstanceOf[UTF8String])
+            builder.append(valueToUTF8String(valueArray.get(0, vt)))
           }
           var i = 1
           while (i < map.numElements()) {
             builder.append(", ")
-            builder.append(keyToUTF8String(keyArray.get(i, 
kt)).asInstanceOf[UTF8String])
+            builder.append(keyToUTF8String(keyArray.get(i, kt)))
             builder.append(" ->")
             if (valueArray.isNullAt(i)) {
               if (nullString.nonEmpty) builder.append(" " + nullString)
             } else {
               builder.append(" ")
-              builder.append(valueToUTF8String(valueArray.get(i, vt))
-                .asInstanceOf[UTF8String])
+              builder.append(valueToUTF8String(valueArray.get(i, vt)))
             }
             i += 1
           }
@@ -134,7 +134,7 @@ trait ToStringBase { self: UnaryExpression with 
TimeZoneAwareExpression =>
           if (row.isNullAt(0)) {
             if (nullString.nonEmpty) builder.append(nullString)
           } else {
-            builder.append(toUTF8StringFuncs(0)(row.get(0, 
st(0))).asInstanceOf[UTF8String])
+            builder.append(toUTF8StringFuncs(0)(row.get(0, st(0))))
           }
           var i = 1
           while (i < row.numFields) {
@@ -143,7 +143,7 @@ trait ToStringBase { self: UnaryExpression with 
TimeZoneAwareExpression =>
               if (nullString.nonEmpty) builder.append(" " + nullString)
             } else {
               builder.append(" ")
-              builder.append(toUTF8StringFuncs(i)(row.get(i, 
st(i))).asInstanceOf[UTF8String])
+              builder.append(toUTF8StringFuncs(i)(row.get(i, st(i))))
             }
             i += 1
           }
@@ -162,7 +162,7 @@ trait ToStringBase { self: UnaryExpression with 
TimeZoneAwareExpression =>
         IntervalUtils.toDayTimeIntervalString(i, ANSI_STYLE, startField, 
endField)))
     case _: DecimalType if useDecimalPlainString =>
       acceptAny[Decimal](d => UTF8String.fromString(d.toPlainString))
-    case _: StringType => identity
+    case _: StringType => acceptAny[UTF8String](identity[UTF8String])
     case _ => o => UTF8String.fromString(o.toString)
   }
 
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyStringSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyStringSuite.scala
new file mode 100644
index 000000000000..4e043e72fad3
--- /dev/null
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyStringSuite.scala
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.UTC_OPT
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
+
+class ToPrettyStringSuite extends SparkFunSuite with ExpressionEvalHelper {
+
+  test("CalendarInterval as pretty strings") {
+    checkEvaluation(
+      ToPrettyString(Cast(Literal("interval -3 month 1 day 7 hours"), 
CalendarIntervalType)),
+      "-3 months 1 days 7 hours")
+  }
+
+  test("Binary as pretty strings") {
+    checkEvaluation(ToPrettyString(Cast(Literal("abcdef"), BinaryType)), "[61 
62 63 64 65 66]")
+  }
+
+  test("Date as pretty strings") {
+    checkEvaluation(ToPrettyString(Cast(Literal("1980-12-17"), DateType, 
UTC_OPT)), "1980-12-17")
+  }
+
+  test("Timestamp as pretty strings") {
+    checkEvaluation(
+      ToPrettyString(Cast(Literal("2012-11-30 09:19:00"), TimestampType, 
UTC_OPT)),
+      "2012-11-30 01:19:00")
+  }
+
+  test("TimestampNTZ as pretty strings") {
+    checkEvaluation(ToPrettyString(Literal(1L, TimestampNTZType)), "1970-01-01 
00:00:00.000001")
+  }
+
+  test("Array as pretty strings") {
+    checkEvaluation(ToPrettyString(Literal.create(Array(1, 2, 3, 4, 5))), "[1, 
2, 3, 4, 5]")
+  }
+
+  test("Map as pretty strings") {
+    checkEvaluation(
+      ToPrettyString(Literal.create(Map(1 -> "a", 2 -> "b", 3 -> "c"))),
+      "{1 -> a, 2 -> b, 3 -> c}")
+  }
+
+  test("Struct as pretty strings") {
+    checkEvaluation(ToPrettyString(Literal.create((1, "a", 0.1))), "{1, a, 
0.1}")
+    checkEvaluation(
+      ToPrettyString(Literal.create(Tuple2[String, String](null, null))),
+      "{NULL, NULL}"
+    )
+  }
+
+  test("YearMonthInterval as pretty strings") {
+    checkEvaluation(
+      ToPrettyString(Cast(Literal("INTERVAL '1-0' YEAR TO MONTH"), 
YearMonthIntervalType())),
+      "INTERVAL '1-0' YEAR TO MONTH")
+  }
+
+  test("DayTimeInterval as pretty strings") {
+    checkEvaluation(
+      ToPrettyString(Cast(Literal("INTERVAL '1 2:03:04' DAY TO SECOND"), 
DayTimeIntervalType())),
+      "INTERVAL '1 02:03:04' DAY TO SECOND")
+  }
+
+  test("Decimal as pretty strings") {
+    checkEvaluation(
+      ToPrettyString(Cast(Literal(1234.65), DecimalType(6, 2))), "1234.65")
+  }
+
+  test("String as pretty strings") {
+    checkEvaluation(ToPrettyString(Literal("s")), "s")
+  }
+
+  test("Char as pretty strings") {
+    checkEvaluation(ToPrettyString(Literal.create('a', CharType(5))), "a")
+  }
+
+  test("Byte as pretty strings") {
+    checkEvaluation(ToPrettyString(Cast(Literal(8), ByteType)), "8")
+  }
+
+  test("Short as pretty strings") {
+    checkEvaluation(ToPrettyString(Cast(Literal(8), ShortType)), "8")
+  }
+
+  test("Int as pretty strings") {
+    checkEvaluation(ToPrettyString(Literal(1)), "1")
+  }
+
+  test("Long as pretty strings") {
+    checkEvaluation(ToPrettyString(Literal(1L)), "1")
+  }
+
+  test("Float as pretty strings") {
+    checkEvaluation(ToPrettyString(Cast(Literal(8), FloatType)), "8.0")
+  }
+
+  test("Double as pretty strings") {
+    checkEvaluation(ToPrettyString(Cast(Literal(8), DoubleType)), "8.0")
+  }
+
+  test("Boolean as pretty strings") {
+    checkEvaluation(ToPrettyString(Literal(false)), "false")
+    checkEvaluation(ToPrettyString(Literal(true)), "true")
+  }
+
+  test("Variant as pretty strings") {
+    checkEvaluation(
+      ToPrettyString(Literal(new VariantVal(Array[Byte](1, 2, 3), 
Array[Byte](4, 5)))),
+      UTF8String.fromBytes(Array[Byte](1, 2, 3)).toString)
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to