This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3a61eeff340a [SPARK-50619][SQL] Refactor VariantGet.cast to pack the 
cast arguments
3a61eeff340a is described below

commit 3a61eeff340a0f979dfb2929aeff128c60f18a2c
Author: Chenhao Li <[email protected]>
AuthorDate: Thu Dec 19 17:04:34 2024 +0800

    [SPARK-50619][SQL] Refactor VariantGet.cast to pack the cast arguments
    
    ### What changes were proposed in this pull request?
    
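    As the title says: `VariantGet.cast` and `VariantGet.variantGet` now take a single `VariantCastArgs` object (packing `failOnError`, `zoneStr`, and `zoneId`) instead of three separate parameters. This is a pure refactor that simplifies parameter passing.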
    
    ### Why are the changes needed?
    
    The refactor will make it simpler for code that works with shredded variants to call `VariantGet.cast`.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #49239 from chenhao-db/VariantCastArgs.
    
    Authored-by: Chenhao Li <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../spark/sql/catalyst/expressions/Cast.scala      | 13 +++--
 .../expressions/variant/variantExpressions.scala   | 61 +++++++++-------------
 2 files changed, 34 insertions(+), 40 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index d4ebdf10ef11..abd635e22f26 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -572,6 +572,11 @@ case class Cast(
     }
   }
 
+  private lazy val castArgs = variant.VariantCastArgs(
+    evalMode != EvalMode.TRY,
+    timeZoneId,
+    zoneId)
+
   def needsTimeZone: Boolean = Cast.needsTimeZone(child.dataType, dataType)
 
   // [[func]] assumes the input is no longer null because eval already does 
the null check.
@@ -1127,7 +1132,7 @@ case class Cast(
       _ => throw QueryExecutionErrors.cannotCastFromNullTypeError(to)
     } else if (from.isInstanceOf[VariantType]) {
       buildCast[VariantVal](_, v => {
-        variant.VariantGet.cast(v, to, evalMode != EvalMode.TRY, timeZoneId, 
zoneId)
+        variant.VariantGet.cast(v, to, castArgs)
       })
     } else {
       to match {
@@ -1225,12 +1230,10 @@ case class Cast(
     case _ if from.isInstanceOf[VariantType] => (c, evPrim, evNull) =>
       val tmp = ctx.freshVariable("tmp", classOf[Object])
       val dataTypeArg = ctx.addReferenceObj("dataType", to)
-      val zoneStrArg = ctx.addReferenceObj("zoneStr", timeZoneId)
-      val zoneIdArg = ctx.addReferenceObj("zoneId", zoneId, 
classOf[ZoneId].getName)
-      val failOnError = evalMode != EvalMode.TRY
+      val castArgsArg = ctx.addReferenceObj("castArgs", castArgs)
       val cls = classOf[variant.VariantGet].getName
       code"""
-        Object $tmp = $cls.cast($c, $dataTypeArg, $failOnError, $zoneStrArg, 
$zoneIdArg);
+        Object $tmp = $cls.cast($c, $dataTypeArg, $castArgsArg);
         if ($tmp == null) {
           $evNull = true;
         } else {
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
index 2fa0ce0f570c..c19df82e6576 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
@@ -278,14 +278,13 @@ case class VariantGet(
   override def nullable: Boolean = true
   override def nullIntolerant: Boolean = true
 
+  private lazy val castArgs = VariantCastArgs(
+    failOnError,
+    timeZoneId,
+    zoneId)
+
   protected override def nullSafeEval(input: Any, path: Any): Any = {
-    VariantGet.variantGet(
-      input.asInstanceOf[VariantVal],
-      parsedPath,
-      dataType,
-      failOnError,
-      timeZoneId,
-      zoneId)
+    VariantGet.variantGet(input.asInstanceOf[VariantVal], parsedPath, 
dataType, castArgs)
   }
 
   protected override def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode = {
@@ -293,15 +292,14 @@ case class VariantGet(
     val tmp = ctx.freshVariable("tmp", classOf[Object])
     val parsedPathArg = ctx.addReferenceObj("parsedPath", parsedPath)
     val dataTypeArg = ctx.addReferenceObj("dataType", dataType)
-    val zoneStrArg = ctx.addReferenceObj("zoneStr", timeZoneId)
-    val zoneIdArg = ctx.addReferenceObj("zoneId", zoneId, 
classOf[ZoneId].getName)
+    val castArgsArg = ctx.addReferenceObj("castArgs", castArgs)
     val code = code"""
       ${childCode.code}
       boolean ${ev.isNull} = ${childCode.isNull};
       ${CodeGenerator.javaType(dataType)} ${ev.value} = 
${CodeGenerator.defaultValue(dataType)};
       if (!${ev.isNull}) {
         Object $tmp = 
org.apache.spark.sql.catalyst.expressions.variant.VariantGet.variantGet(
-          ${childCode.value}, $parsedPathArg, $dataTypeArg, $failOnError, 
$zoneStrArg, $zoneIdArg);
+          ${childCode.value}, $parsedPathArg, $dataTypeArg, $castArgsArg);
         if ($tmp == null) {
           ${ev.isNull} = true;
         } else {
@@ -323,6 +321,12 @@ case class VariantGet(
   override def withTimeZone(timeZoneId: String): VariantGet = copy(timeZoneId 
= Option(timeZoneId))
 }
 
+// Several parameters used by `VariantGet.cast`. Packed together to simplify 
parameter passing.
+case class VariantCastArgs(
+    failOnError: Boolean,
+    zoneStr: Option[String],
+    zoneId: ZoneId)
+
 case object VariantGet {
   /**
    * Returns whether a data type can be cast into/from variant. For scalar 
types, we allow a subset
@@ -347,9 +351,7 @@ case object VariantGet {
       input: VariantVal,
       parsedPath: Array[VariantPathParser.PathSegment],
       dataType: DataType,
-      failOnError: Boolean,
-      zoneStr: Option[String],
-      zoneId: ZoneId): Any = {
+      castArgs: VariantCastArgs): Any = {
     var v = new Variant(input.getValue, input.getMetadata)
     for (path <- parsedPath) {
       v = path match {
@@ -359,21 +361,16 @@ case object VariantGet {
       }
       if (v == null) return null
     }
-    VariantGet.cast(v, dataType, failOnError, zoneStr, zoneId)
+    VariantGet.cast(v, dataType, castArgs)
   }
 
   /**
    * A simple wrapper of the `cast` function that takes `Variant` rather than 
`VariantVal`. The
    * `Cast` expression uses it and makes the implementation simpler.
    */
-  def cast(
-      input: VariantVal,
-      dataType: DataType,
-      failOnError: Boolean,
-      zoneStr: Option[String],
-      zoneId: ZoneId): Any = {
+  def cast(input: VariantVal, dataType: DataType, castArgs: VariantCastArgs): 
Any = {
     val v = new Variant(input.getValue, input.getMetadata)
-    VariantGet.cast(v, dataType, failOnError, zoneStr, zoneId)
+    VariantGet.cast(v, dataType, castArgs)
   }
 
   /**
@@ -383,15 +380,10 @@ case object VariantGet {
    * "hello" to int). If the cast fails, throw an exception when `failOnError` 
is true, or return a
    * SQL NULL when it is false.
    */
-  def cast(
-      v: Variant,
-      dataType: DataType,
-      failOnError: Boolean,
-      zoneStr: Option[String],
-      zoneId: ZoneId): Any = {
+  def cast(v: Variant, dataType: DataType, castArgs: VariantCastArgs): Any = {
     def invalidCast(): Any = {
-      if (failOnError) {
-        throw QueryExecutionErrors.invalidVariantCast(v.toJson(zoneId), 
dataType)
+      if (castArgs.failOnError) {
+        throw 
QueryExecutionErrors.invalidVariantCast(v.toJson(castArgs.zoneId), dataType)
       } else {
         null
       }
@@ -411,7 +403,7 @@ case object VariantGet {
         val input = variantType match {
           case Type.OBJECT | Type.ARRAY =>
             return if (dataType.isInstanceOf[StringType]) {
-              UTF8String.fromString(v.toJson(zoneId))
+              UTF8String.fromString(v.toJson(castArgs.zoneId))
             } else {
               invalidCast()
             }
@@ -457,7 +449,7 @@ case object VariantGet {
             }
           case _ =>
             if (Cast.canAnsiCast(input.dataType, dataType)) {
-              val result = Cast(input, dataType, zoneStr, EvalMode.TRY).eval()
+              val result = Cast(input, dataType, castArgs.zoneStr, 
EvalMode.TRY).eval()
               if (result == null) invalidCast() else result
             } else {
               invalidCast()
@@ -468,7 +460,7 @@ case object VariantGet {
           val size = v.arraySize()
           val array = new Array[Any](size)
           for (i <- 0 until size) {
-            array(i) = cast(v.getElementAtIndex(i), elementType, failOnError, 
zoneStr, zoneId)
+            array(i) = cast(v.getElementAtIndex(i), elementType, castArgs)
           }
           new GenericArrayData(array)
         } else {
@@ -482,7 +474,7 @@ case object VariantGet {
           for (i <- 0 until size) {
             val field = v.getFieldAtIndex(i)
             keyArray(i) = UTF8String.fromString(field.key)
-            valueArray(i) = cast(field.value, valueType, failOnError, zoneStr, 
zoneId)
+            valueArray(i) = cast(field.value, valueType, castArgs)
           }
           ArrayBasedMapData(keyArray, valueArray)
         } else {
@@ -495,8 +487,7 @@ case object VariantGet {
             val field = v.getFieldAtIndex(i)
             st.getFieldIndex(field.key) match {
               case Some(idx) =>
-                row.update(idx,
-                  cast(field.value, fields(idx).dataType, failOnError, 
zoneStr, zoneId))
+                row.update(idx, cast(field.value, fields(idx).dataType, 
castArgs))
               case _ =>
             }
           }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to