This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new b34178781fdd [SPARK-47539][SQL] Make the return value of method
`castToString` be `Any => UTF8String`
b34178781fdd is described below
commit b34178781fdd67dfcfc53e5a00eea0737ee01620
Author: panbingkun <[email protected]>
AuthorDate: Mon Mar 25 17:26:50 2024 +0900
[SPARK-47539][SQL] Make the return value of method `castToString` be `Any
=> UTF8String`
### What changes were proposed in this pull request?
This PR aims to:
- Change the signature of the method `castToString(from: DataType): Any => Any` to
`castToString(from: DataType): Any => UTF8String` in `ToStringBase`.
- Add unit tests (UT) for `ToPrettyString` to improve test coverage.
### Why are the changes needed?
- Let the method return a `UTF8String` (`Any` -> `UTF8String`), which is more
intuitive and removes the need for `asInstanceOf[UTF8String]` casts at call sites.
- Currently, `ToPrettyString` lacks corresponding unit tests (UT).
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
- Added new unit tests (`ToPrettyStringSuite`).
- Passed GitHub Actions (GA).
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #45688 from panbingkun/ToPrettyString_improve.
Authored-by: panbingkun <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../sql/catalyst/expressions/ToPrettyString.scala | 2 +-
.../sql/catalyst/expressions/ToStringBase.scala | 24 ++--
.../catalyst/expressions/ToPrettyStringSuite.scala | 128 +++++++++++++++++++++
3 files changed, 141 insertions(+), 13 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyString.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyString.scala
index aea704d4b788..8db08dbbcb81 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyString.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyString.scala
@@ -51,7 +51,7 @@ case class ToPrettyString(child: Expression, timeZoneId:
Option[String] = None)
override protected def useHexFormatForBinary: Boolean = true
- private[this] lazy val castFunc: Any => Any = castToString(child.dataType)
+ private[this] lazy val castFunc: Any => UTF8String =
castToString(child.dataType)
override def eval(input: InternalRow): Any = {
val v = child.eval(input)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala
index 18b64fd21338..4f35072c4fc7 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala
@@ -47,10 +47,11 @@ trait ToStringBase { self: UnaryExpression with
TimeZoneAwareExpression =>
protected def useHexFormatForBinary: Boolean
// Makes the function accept Any type input by doing `asInstanceOf[T]`.
- @inline private def acceptAny[T](func: T => Any): Any => Any = i =>
func(i.asInstanceOf[T])
+ @inline private def acceptAny[T](func: T => UTF8String): Any => UTF8String =
+ i => func(i.asInstanceOf[T])
// Returns a function to convert a value to pretty string. The function
assumes input is not null.
- protected final def castToString(from: DataType): Any => Any = from match {
+ protected final def castToString(from: DataType): Any => UTF8String = from
match {
case CalendarIntervalType =>
acceptAny[CalendarInterval](i => UTF8String.fromString(i.toString))
case BinaryType if useHexFormatForBinary =>
@@ -72,7 +73,7 @@ trait ToStringBase { self: UnaryExpression with
TimeZoneAwareExpression =>
if (array.isNullAt(0)) {
if (nullString.nonEmpty) builder.append(nullString)
} else {
- builder.append(toUTF8String(array.get(0,
et)).asInstanceOf[UTF8String])
+ builder.append(toUTF8String(array.get(0, et)))
}
var i = 1
while (i < array.numElements()) {
@@ -81,7 +82,7 @@ trait ToStringBase { self: UnaryExpression with
TimeZoneAwareExpression =>
if (nullString.nonEmpty) builder.append(" " + nullString)
} else {
builder.append(" ")
- builder.append(toUTF8String(array.get(i,
et)).asInstanceOf[UTF8String])
+ builder.append(toUTF8String(array.get(i, et)))
}
i += 1
}
@@ -98,25 +99,24 @@ trait ToStringBase { self: UnaryExpression with
TimeZoneAwareExpression =>
val valueArray = map.valueArray()
val keyToUTF8String = castToString(kt)
val valueToUTF8String = castToString(vt)
- builder.append(keyToUTF8String(keyArray.get(0,
kt)).asInstanceOf[UTF8String])
+ builder.append(keyToUTF8String(keyArray.get(0, kt)))
builder.append(" ->")
if (valueArray.isNullAt(0)) {
if (nullString.nonEmpty) builder.append(" " + nullString)
} else {
builder.append(" ")
- builder.append(valueToUTF8String(valueArray.get(0,
vt)).asInstanceOf[UTF8String])
+ builder.append(valueToUTF8String(valueArray.get(0, vt)))
}
var i = 1
while (i < map.numElements()) {
builder.append(", ")
- builder.append(keyToUTF8String(keyArray.get(i,
kt)).asInstanceOf[UTF8String])
+ builder.append(keyToUTF8String(keyArray.get(i, kt)))
builder.append(" ->")
if (valueArray.isNullAt(i)) {
if (nullString.nonEmpty) builder.append(" " + nullString)
} else {
builder.append(" ")
- builder.append(valueToUTF8String(valueArray.get(i, vt))
- .asInstanceOf[UTF8String])
+ builder.append(valueToUTF8String(valueArray.get(i, vt)))
}
i += 1
}
@@ -134,7 +134,7 @@ trait ToStringBase { self: UnaryExpression with
TimeZoneAwareExpression =>
if (row.isNullAt(0)) {
if (nullString.nonEmpty) builder.append(nullString)
} else {
- builder.append(toUTF8StringFuncs(0)(row.get(0,
st(0))).asInstanceOf[UTF8String])
+ builder.append(toUTF8StringFuncs(0)(row.get(0, st(0))))
}
var i = 1
while (i < row.numFields) {
@@ -143,7 +143,7 @@ trait ToStringBase { self: UnaryExpression with
TimeZoneAwareExpression =>
if (nullString.nonEmpty) builder.append(" " + nullString)
} else {
builder.append(" ")
- builder.append(toUTF8StringFuncs(i)(row.get(i,
st(i))).asInstanceOf[UTF8String])
+ builder.append(toUTF8StringFuncs(i)(row.get(i, st(i))))
}
i += 1
}
@@ -162,7 +162,7 @@ trait ToStringBase { self: UnaryExpression with
TimeZoneAwareExpression =>
IntervalUtils.toDayTimeIntervalString(i, ANSI_STYLE, startField,
endField)))
case _: DecimalType if useDecimalPlainString =>
acceptAny[Decimal](d => UTF8String.fromString(d.toPlainString))
- case _: StringType => identity
+ case _: StringType => acceptAny[UTF8String](identity[UTF8String])
case _ => o => UTF8String.fromString(o.toString)
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyStringSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyStringSuite.scala
new file mode 100644
index 000000000000..4e043e72fad3
--- /dev/null
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyStringSuite.scala
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.UTC_OPT
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
+
+class ToPrettyStringSuite extends SparkFunSuite with ExpressionEvalHelper {
+
+ test("CalendarInterval as pretty strings") {
+ checkEvaluation(
+ ToPrettyString(Cast(Literal("interval -3 month 1 day 7 hours"),
CalendarIntervalType)),
+ "-3 months 1 days 7 hours")
+ }
+
+ test("Binary as pretty strings") {
+ checkEvaluation(ToPrettyString(Cast(Literal("abcdef"), BinaryType)), "[61
62 63 64 65 66]")
+ }
+
+ test("Date as pretty strings") {
+ checkEvaluation(ToPrettyString(Cast(Literal("1980-12-17"), DateType,
UTC_OPT)), "1980-12-17")
+ }
+
+ test("Timestamp as pretty strings") {
+ checkEvaluation(
+ ToPrettyString(Cast(Literal("2012-11-30 09:19:00"), TimestampType,
UTC_OPT)),
+ "2012-11-30 01:19:00")
+ }
+
+ test("TimestampNTZ as pretty strings") {
+ checkEvaluation(ToPrettyString(Literal(1L, TimestampNTZType)), "1970-01-01
00:00:00.000001")
+ }
+
+ test("Array as pretty strings") {
+ checkEvaluation(ToPrettyString(Literal.create(Array(1, 2, 3, 4, 5))), "[1,
2, 3, 4, 5]")
+ }
+
+ test("Map as pretty strings") {
+ checkEvaluation(
+ ToPrettyString(Literal.create(Map(1 -> "a", 2 -> "b", 3 -> "c"))),
+ "{1 -> a, 2 -> b, 3 -> c}")
+ }
+
+ test("Struct as pretty strings") {
+ checkEvaluation(ToPrettyString(Literal.create((1, "a", 0.1))), "{1, a,
0.1}")
+ checkEvaluation(
+ ToPrettyString(Literal.create(Tuple2[String, String](null, null))),
+ "{NULL, NULL}"
+ )
+ }
+
+ test("YearMonthInterval as pretty strings") {
+ checkEvaluation(
+ ToPrettyString(Cast(Literal("INTERVAL '1-0' YEAR TO MONTH"),
YearMonthIntervalType())),
+ "INTERVAL '1-0' YEAR TO MONTH")
+ }
+
+ test("DayTimeInterval as pretty strings") {
+ checkEvaluation(
+ ToPrettyString(Cast(Literal("INTERVAL '1 2:03:04' DAY TO SECOND"),
DayTimeIntervalType())),
+ "INTERVAL '1 02:03:04' DAY TO SECOND")
+ }
+
+ test("Decimal as pretty strings") {
+ checkEvaluation(
+ ToPrettyString(Cast(Literal(1234.65), DecimalType(6, 2))), "1234.65")
+ }
+
+ test("String as pretty strings") {
+ checkEvaluation(ToPrettyString(Literal("s")), "s")
+ }
+
+ test("Char as pretty strings") {
+ checkEvaluation(ToPrettyString(Literal.create('a', CharType(5))), "a")
+ }
+
+ test("Byte as pretty strings") {
+ checkEvaluation(ToPrettyString(Cast(Literal(8), ByteType)), "8")
+ }
+
+ test("Short as pretty strings") {
+ checkEvaluation(ToPrettyString(Cast(Literal(8), ShortType)), "8")
+ }
+
+ test("Int as pretty strings") {
+ checkEvaluation(ToPrettyString(Literal(1)), "1")
+ }
+
+ test("Long as pretty strings") {
+ checkEvaluation(ToPrettyString(Literal(1L)), "1")
+ }
+
+ test("Float as pretty strings") {
+ checkEvaluation(ToPrettyString(Cast(Literal(8), FloatType)), "8.0")
+ }
+
+ test("Double as pretty strings") {
+ checkEvaluation(ToPrettyString(Cast(Literal(8), DoubleType)), "8.0")
+ }
+
+ test("Boolean as pretty strings") {
+ checkEvaluation(ToPrettyString(Literal(false)), "false")
+ checkEvaluation(ToPrettyString(Literal(true)), "true")
+ }
+
+ test("Variant as pretty strings") {
+ checkEvaluation(
+ ToPrettyString(Literal(new VariantVal(Array[Byte](1, 2, 3),
Array[Byte](4, 5)))),
+ UTF8String.fromBytes(Array[Byte](1, 2, 3)).toString)
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]