This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 7a4114cdf994 [SPARK-50644][FOLLOWUP][SQL] Fix scalar cast in the shredded reader
7a4114cdf994 is described below

commit 7a4114cdf9948480d86ec7c870f7a2fa4939cc67
Author: Chenhao Li <[email protected]>
AuthorDate: Thu Dec 26 15:24:23 2024 +0800

    [SPARK-50644][FOLLOWUP][SQL] Fix scalar cast in the shredded reader
    
    ### What changes were proposed in this pull request?
    
    We mostly use the `Cast` expression to implement the cast, but we need
    some custom logic for certain type combinations. We already have special
    handling for `long/decimal -> timestamp` in `VariantGet.cast`, so we
    should do the same thing in `ScalarCastHelper` to ensure consistency.
    `ScalarCastHelper` also needs special handling for `decimal -> string` to
    strip any trailing zeros.
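    
    As a rough sketch of the intended semantics (standalone code for
    illustration only; the real helpers live on `VariantGet` and appear in
    the diff below):
    
    ```scala
    import java.math.BigDecimal
    
    val MICROS_PER_SECOND = 1000000L
    
    // Strict long -> timestamp: throw ArithmeticException on overflow instead
    // of silently wrapping the way the plain `Cast` expression does.
    def castLongToTimestamp(seconds: Long): Long =
      Math.multiplyExact(seconds, MICROS_PER_SECOND)
    
    // Strict decimal -> timestamp: `longValueExact` throws on overflow.
    def castDecimalToTimestamp(d: BigDecimal): Long =
      d.multiply(new BigDecimal(MICROS_PER_SECOND)).toBigInteger.longValueExact()
    
    // decimal -> string with trailing zeros stripped, so a shredded 1.00
    // renders as "1", matching what `Variant.getDecimal` already produces.
    def castDecimalToString(d: BigDecimal): String =
      d.stripTrailingZeros.toPlainString
    ```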
    
    ### Why are the changes needed?
    
    To ensure that a cast on a shredded variant has the same semantics as a
    cast on an unshredded variant.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Unit tests. They would fail without the change.
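    
    For reference, the boundary values in the long -> timestamp test follow
    from the microsecond scaling (a standalone sketch; `MICROS_PER_SECOND` is
    1000000L in `DateTimeConstants`):
    
    ```scala
    val MICROS_PER_SECOND = 1000000L
    // The largest whole-second value that still fits in a Long after scaling
    // to microseconds; one more second overflows.
    val maxSeconds = Long.MaxValue / MICROS_PER_SECOND
    Math.multiplyExact(maxSeconds, MICROS_PER_SECOND)     // succeeds
    Math.multiplyExact(maxSeconds + 1, MICROS_PER_SECOND) // throws ArithmeticException
    ```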
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #49293 from chenhao-db/fix_shredded_cast.
    
    Authored-by: Chenhao Li <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../expressions/variant/variantExpressions.scala   | 36 ++++++++++++++--------
 .../datasources/parquet/SparkShreddingUtils.scala  | 26 +++++++++++++++-
 .../apache/spark/sql/VariantShreddingSuite.scala   | 31 +++++++++++++++++++
 3 files changed, 80 insertions(+), 13 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
index c19df82e6576..ba910b8c7e5f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
@@ -427,24 +427,15 @@ case object VariantGet {
             messageParameters = Map("id" -> v.getTypeInfo.toString)
           )
         }
-        // We mostly use the `Cast` expression to implement the cast. However, `Cast` silently
-        // ignores the overflow in the long/decimal -> timestamp cast, and we want to enforce
-        // strict overflow checks.
         input.dataType match {
           case LongType if dataType == TimestampType =>
-            try Math.multiplyExact(input.value.asInstanceOf[Long], MICROS_PER_SECOND)
+            try castLongToTimestamp(input.value.asInstanceOf[Long])
             catch {
               case _: ArithmeticException => invalidCast()
             }
           case _: DecimalType if dataType == TimestampType =>
-            try {
-              input.value
-                .asInstanceOf[Decimal]
-                .toJavaBigDecimal
-                .multiply(new java.math.BigDecimal(MICROS_PER_SECOND))
-                .toBigInteger
-                .longValueExact()
-            } catch {
+            try castDecimalToTimestamp(input.value.asInstanceOf[Decimal])
+            catch {
               case _: ArithmeticException => invalidCast()
             }
           case _ =>
@@ -497,6 +488,27 @@ case object VariantGet {
         }
     }
   }
+
+  // We mostly use the `Cast` expression to implement the cast, but we need some custom logic for
+  // certain type combinations.
+  //
+  // `castLongToTimestamp/castDecimalToTimestamp`: `Cast` silently ignores the overflow in the
+  // long/decimal -> timestamp cast, and we want to enforce strict overflow checks. They both throw
+  // an `ArithmeticException` when overflow happens.
+  def castLongToTimestamp(input: Long): Long =
+    Math.multiplyExact(input, MICROS_PER_SECOND)
+
+  def castDecimalToTimestamp(input: Decimal): Long = {
+    val multiplier = new java.math.BigDecimal(MICROS_PER_SECOND)
+    input.toJavaBigDecimal.multiply(multiplier).toBigInteger.longValueExact()
+  }
+
+  // Cast decimal to string, but strip any trailing zeros. We don't have to call it if the decimal
+  // is returned by `Variant.getDecimal`, which already strips any trailing zeros. But we need it
+  // if the decimal is produced by Spark internally, e.g., on a shredded decimal produced by the
+  // Spark Parquet reader.
+  def castDecimalToString(input: Decimal): UTF8String =
+    UTF8String.fromString(input.toJavaBigDecimal.stripTrailingZeros.toPlainString)
 }
 
 abstract class ParseJsonExpressionBuilderBase(failOnError: Boolean) extends ExpressionBuilder {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
index a83ca78455fa..34c167aea363 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
@@ -111,7 +111,31 @@ case class ScalarCastHelper(
     } else {
       ""
     }
-    if (cast != null) {
+    val customCast = (child.dataType, dataType) match {
+      case (_: LongType, _: TimestampType) => "castLongToTimestamp"
+      case (_: DecimalType, _: TimestampType) => "castDecimalToTimestamp"
+      case (_: DecimalType, _: StringType) => "castDecimalToString"
+      case _ => null
+    }
+    if (customCast != null) {
+      val childCode = child.genCode(ctx)
+      // We can avoid the try-catch block for decimal -> string, but the performance benefit is
+      // little. We can also be more specific in the exception type, like catching
+      // `ArithmeticException` instead of `Exception`, but it is unnecessary. The `try_cast` codegen
+      // also catches `Exception` instead of specific exceptions.
+      val code = code"""
+        ${childCode.code}
+        boolean ${ev.isNull} = false;
+        ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+        try {
+          ${ev.value} = ${classOf[VariantGet].getName}.$customCast(${childCode.value});
+        } catch (Exception e) {
+          ${ev.isNull} = true;
+          $invalidCastCode
+        }
+      """
+      ev.copy(code = code)
+    } else if (cast != null) {
       val castCode = cast.genCode(ctx)
       val code = code"""
         ${castCode.code}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantShreddingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantShreddingSuite.scala
index b6623bb57a71..3443028ba45b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/VariantShreddingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantShreddingSuite.scala
@@ -24,6 +24,8 @@ import java.time.LocalDateTime
 import org.apache.spark.SparkThrowable
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils
+import org.apache.spark.sql.catalyst.util.DateTimeConstants._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils._
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetTest, SparkShreddingUtils}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
@@ -349,4 +351,33 @@ class VariantShreddingSuite extends QueryTest with SharedSparkSession with Parqu
     checkExpr(path, "variant_get(v, '$.a')", null, parseJson("null"), parseJson("1"), null,
       parseJson("null"), parseJson("3"))
   }
+
+  testWithTempPath("custom casts") { path =>
+    writeRows(path, writeSchema(LongType),
+      Row(metadata(Nil), null, Long.MaxValue / MICROS_PER_SECOND + 1),
+      Row(metadata(Nil), null, Long.MaxValue / MICROS_PER_SECOND))
+
+    // long -> timestamp
+    checkException(path, "cast(v as timestamp)", "INVALID_VARIANT_CAST")
+    checkExpr(path, "try_cast(v as timestamp)",
+      null, toJavaTimestamp(Long.MaxValue / MICROS_PER_SECOND * MICROS_PER_SECOND))
+
+    writeRows(path, writeSchema(DecimalType(38, 19)),
+      Row(metadata(Nil), null, Decimal("1E18")),
+      Row(metadata(Nil), null, Decimal("100")),
+      Row(metadata(Nil), null, Decimal("10")),
+      Row(metadata(Nil), null, Decimal("1")),
+      Row(metadata(Nil), null, Decimal("0")),
+      Row(metadata(Nil), null, Decimal("0.1")),
+      Row(metadata(Nil), null, Decimal("0.01")),
+      Row(metadata(Nil), null, Decimal("1E-18")))
+
+    checkException(path, "cast(v as timestamp)", "INVALID_VARIANT_CAST")
+    // decimal -> timestamp
+    checkExpr(path, "try_cast(v as timestamp)",
+      (null +: Seq(100000000, 10000000, 1000000, 0, 100000, 10000, 0).map(toJavaTimestamp(_))): _*)
+    // decimal -> string
+    checkExpr(path, "cast(v as string)",
+      "1000000000000000000", "100", "10", "1", "0", "0.1", "0.01", 
"0.000000000000000001")
+  }
 }

