This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 7a4114cdf994 [SPARK-50644][FOLLOWUP][SQL] Fix scalar cast in the shredded reader
7a4114cdf994 is described below
commit 7a4114cdf9948480d86ec7c870f7a2fa4939cc67
Author: Chenhao Li <[email protected]>
AuthorDate: Thu Dec 26 15:24:23 2024 +0800
[SPARK-50644][FOLLOWUP][SQL] Fix scalar cast in the shredded reader
### What changes were proposed in this pull request?
We mostly use the `Cast` expression to implement the cast, but we need some
custom logic for certain type combinations. We already have special handling
for `long/decimal -> timestamp` in `VariantGet.cast`, so we should do the same
thing in `ScalarCastHelper` to ensure consistency. `ScalarCastHelper` also
needs special handling for `decimal -> string` to strip any trailing zeros.
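For illustration only, here is a minimal standalone Scala sketch of the semantics these helpers enforce (using plain `java.math.BigDecimal` rather than Spark's `Decimal` so it runs on its own; the `MICROS_PER_SECOND` constant is assumed to match Spark's `DateTimeConstants.MICROS_PER_SECOND`, i.e. 1000000L):

```scala
import java.math.BigDecimal

val MICROS_PER_SECOND = 1000000L

// long -> timestamp: scale seconds to microseconds. Unlike `Cast`,
// Math.multiplyExact throws ArithmeticException on overflow.
def castLongToTimestamp(seconds: Long): Long =
  Math.multiplyExact(seconds, MICROS_PER_SECOND)

// decimal -> timestamp: longValueExact throws ArithmeticException when the
// microsecond value does not fit in a long.
def castDecimalToTimestamp(seconds: BigDecimal): Long =
  seconds.multiply(new BigDecimal(MICROS_PER_SECOND)).toBigInteger.longValueExact()

// decimal -> string: strip trailing zeros so a shredded decimal renders the
// same as an unshredded one, e.g. 0.0100 -> "0.01".
def castDecimalToString(d: BigDecimal): String =
  d.stripTrailingZeros.toPlainString
```

For example, `castDecimalToTimestamp(new BigDecimal("0.1"))` yields 100000 microseconds, while a value above `Long.MaxValue / MICROS_PER_SECOND` seconds fails instead of silently wrapping.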
### Why are the changes needed?
To ensure that a cast on a shredded variant has the same semantics as a cast on
an unshredded variant.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Unit tests. They would fail without the change.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #49293 from chenhao-db/fix_shredded_cast.
Authored-by: Chenhao Li <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../expressions/variant/variantExpressions.scala | 36 ++++++++++++++--------
.../datasources/parquet/SparkShreddingUtils.scala | 26 +++++++++++++++-
.../apache/spark/sql/VariantShreddingSuite.scala | 31 +++++++++++++++++++
3 files changed, 80 insertions(+), 13 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
index c19df82e6576..ba910b8c7e5f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
@@ -427,24 +427,15 @@ case object VariantGet {
messageParameters = Map("id" -> v.getTypeInfo.toString)
)
}
- // We mostly use the `Cast` expression to implement the cast. However, `Cast` silently
- // ignores the overflow in the long/decimal -> timestamp cast, and we want to enforce
- // strict overflow checks.
input.dataType match {
case LongType if dataType == TimestampType =>
- try Math.multiplyExact(input.value.asInstanceOf[Long], MICROS_PER_SECOND)
+ try castLongToTimestamp(input.value.asInstanceOf[Long])
catch {
case _: ArithmeticException => invalidCast()
}
case _: DecimalType if dataType == TimestampType =>
- try {
- input.value
- .asInstanceOf[Decimal]
- .toJavaBigDecimal
- .multiply(new java.math.BigDecimal(MICROS_PER_SECOND))
- .toBigInteger
- .longValueExact()
- } catch {
+ try castDecimalToTimestamp(input.value.asInstanceOf[Decimal])
+ catch {
case _: ArithmeticException => invalidCast()
}
case _ =>
@@ -497,6 +488,27 @@ case object VariantGet {
}
}
}
+
+ // We mostly use the `Cast` expression to implement the cast, but we need some custom logic for
+ // certain type combinations.
+ //
+ // `castLongToTimestamp/castDecimalToTimestamp`: `Cast` silently ignores the overflow in the
+ // long/decimal -> timestamp cast, and we want to enforce strict overflow checks. They both throw
+ // an `ArithmeticException` when overflow happens.
+ def castLongToTimestamp(input: Long): Long =
+ Math.multiplyExact(input, MICROS_PER_SECOND)
+
+ def castDecimalToTimestamp(input: Decimal): Long = {
+ val multiplier = new java.math.BigDecimal(MICROS_PER_SECOND)
+ input.toJavaBigDecimal.multiply(multiplier).toBigInteger.longValueExact()
+ }
+
+ // Cast decimal to string, but strip any trailing zeros. We don't have to call it if the decimal
+ // is returned by `Variant.getDecimal`, which already strips any trailing zeros. But we need it
+ // if the decimal is produced by Spark internally, e.g., on a shredded decimal produced by the
+ // Spark Parquet reader.
+ def castDecimalToString(input: Decimal): UTF8String =
+ UTF8String.fromString(input.toJavaBigDecimal.stripTrailingZeros.toPlainString)
}
abstract class ParseJsonExpressionBuilderBase(failOnError: Boolean) extends ExpressionBuilder {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
index a83ca78455fa..34c167aea363 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
@@ -111,7 +111,31 @@ case class ScalarCastHelper(
} else {
""
}
- if (cast != null) {
+ val customCast = (child.dataType, dataType) match {
+ case (_: LongType, _: TimestampType) => "castLongToTimestamp"
+ case (_: DecimalType, _: TimestampType) => "castDecimalToTimestamp"
+ case (_: DecimalType, _: StringType) => "castDecimalToString"
+ case _ => null
+ }
+ if (customCast != null) {
+ val childCode = child.genCode(ctx)
+ // We can avoid the try-catch block for decimal -> string, but the performance benefit is
+ // little. We can also be more specific in the exception type, like catching
+ // `ArithmeticException` instead of `Exception`, but it is unnecessary. The `try_cast` codegen
+ // also catches `Exception` instead of specific exceptions.
+ val code = code"""
+ ${childCode.code}
+ boolean ${ev.isNull} = false;
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+ try {
+ ${ev.value} = ${classOf[VariantGet].getName}.$customCast(${childCode.value});
+ } catch (Exception e) {
+ ${ev.isNull} = true;
+ $invalidCastCode
+ }
+ """
+ ev.copy(code = code)
+ } else if (cast != null) {
val castCode = cast.genCode(ctx)
val code = code"""
${castCode.code}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantShreddingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantShreddingSuite.scala
index b6623bb57a71..3443028ba45b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/VariantShreddingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantShreddingSuite.scala
@@ -24,6 +24,8 @@ import java.time.LocalDateTime
import org.apache.spark.SparkThrowable
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils
+import org.apache.spark.sql.catalyst.util.DateTimeConstants._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils._
import org.apache.spark.sql.execution.datasources.parquet.{ParquetTest, SparkShreddingUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
@@ -349,4 +351,33 @@ class VariantShreddingSuite extends QueryTest with SharedSparkSession with Parqu
checkExpr(path, "variant_get(v, '$.a')", null, parseJson("null"),
parseJson("1"), null,
parseJson("null"), parseJson("3"))
}
+
+ testWithTempPath("custom casts") { path =>
+ writeRows(path, writeSchema(LongType),
+ Row(metadata(Nil), null, Long.MaxValue / MICROS_PER_SECOND + 1),
+ Row(metadata(Nil), null, Long.MaxValue / MICROS_PER_SECOND))
+
+ // long -> timestamp
+ checkException(path, "cast(v as timestamp)", "INVALID_VARIANT_CAST")
+ checkExpr(path, "try_cast(v as timestamp)",
+ null, toJavaTimestamp(Long.MaxValue / MICROS_PER_SECOND * MICROS_PER_SECOND))
+
+ writeRows(path, writeSchema(DecimalType(38, 19)),
+ Row(metadata(Nil), null, Decimal("1E18")),
+ Row(metadata(Nil), null, Decimal("100")),
+ Row(metadata(Nil), null, Decimal("10")),
+ Row(metadata(Nil), null, Decimal("1")),
+ Row(metadata(Nil), null, Decimal("0")),
+ Row(metadata(Nil), null, Decimal("0.1")),
+ Row(metadata(Nil), null, Decimal("0.01")),
+ Row(metadata(Nil), null, Decimal("1E-18")))
+
+ checkException(path, "cast(v as timestamp)", "INVALID_VARIANT_CAST")
+ // decimal -> timestamp
+ checkExpr(path, "try_cast(v as timestamp)",
+ (null +: Seq(100000000, 10000000, 1000000, 0, 100000, 10000, 0).map(toJavaTimestamp(_))): _*)
+ // decimal -> string
+ checkExpr(path, "cast(v as string)",
+ "1000000000000000000", "100", "10", "1", "0", "0.1", "0.01",
"0.000000000000000001")
+ }
}
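For context, a hypothetical spark-shell session sketching the user-visible behavior the new test pins down (the path and data are illustrative, not from this patch; `v` is a shredded variant column holding a long one past the overflow threshold, as written by the test above):

```scala
// Hypothetical: read back a Parquet file written with a shredded variant
// column `v`, where v = Long.MaxValue / 1000000 + 1 (seconds).
val df = spark.read.parquet("/tmp/shredded_variant")
df.selectExpr("cast(v as timestamp)").collect()      // fails with INVALID_VARIANT_CAST
df.selectExpr("try_cast(v as timestamp)").collect()  // returns null instead
```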