This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3fd38d4c07f6 [SPARK-47803][FOLLOWUP] Check nulls when casting nested type to variant
3fd38d4c07f6 is described below

commit 3fd38d4c07f6c998ec8bb234796f83a6aecfc0d2
Author: Chenhao Li <chenhao...@databricks.com>
AuthorDate: Thu May 9 22:45:10 2024 +0800

    [SPARK-47803][FOLLOWUP] Check nulls when casting nested type to variant
    
    ### What changes were proposed in this pull request?
    
    It adds null checks when accessing a nested element while casting a nested type to variant. This is necessary because the `get` API doesn't guarantee that it returns null when the slot is null. For example, `ColumnarArray.get` may return the default value of a primitive type if the slot is null.
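
    For illustration, a minimal sketch of the pitfall, using the same APIs as the new unit test below (the exact value returned for a null slot is an assumption about `OnHeapColumnVector`, not something the API promises):

        import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
        import org.apache.spark.sql.types.LongType
        import org.apache.spark.sql.vectorized.ColumnarArray

        val vector = new OnHeapColumnVector(2, LongType)
        vector.appendNull()
        vector.appendLong(123)
        val array = new ColumnarArray(vector, 0, 2)
        // `get` on the null slot may return the primitive default (e.g. 0L) rather than null.
        array.get(0, LongType)
        // `isNullAt` is the reliable null signal, hence the checks added in this PR:
        val element = if (array.isNullAt(0)) null else array.get(0, LongType)
        vector.close()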
    
    ### Why are the changes needed?
    
    It is a bug fix that is necessary for the cast-to-variant expression to work correctly.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Two new unit tests. One directly uses `ColumnarArray` as the input of the cast. The other creates a real-world situation where `ColumnarArray` is the input of the cast (scan). Both of them would fail without the code change in this PR.
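
    For context, a rough spark-shell sketch of the scan scenario (the path, values, and the exact buggy output are illustrative assumptions, not taken from this patch):

        // A vectorized Parquet scan can feed a ColumnarArray into the cast,
        // so a null slot may be encoded as the primitive default value.
        spark.range(1).selectExpr("array(1L, null) as a").write.parquet("/tmp/t")
        spark.read.parquet("/tmp/t").selectExpr("cast(a as variant)").show()
        // Without the fix, this would likely show [1,0] instead of [1,null].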
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #46486 from chenhao-db/fix_cast_nested_to_variant.
    
    Authored-by: Chenhao Li <chenhao...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../variant/VariantExpressionEvalUtils.scala       |  9 ++++--
 .../apache/spark/sql/VariantEndToEndSuite.scala    | 33 ++++++++++++++++++++--
 2 files changed, 37 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala
index eb235eb854e0..f7f7097173bb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala
@@ -103,7 +103,8 @@ object VariantExpressionEvalUtils {
         val offsets = new java.util.ArrayList[java.lang.Integer](data.numElements())
         for (i <- 0 until data.numElements()) {
           offsets.add(builder.getWritePos - start)
-          buildVariant(builder, data.get(i, elementType), elementType)
+          val element = if (data.isNullAt(i)) null else data.get(i, elementType)
+          buildVariant(builder, element, elementType)
         }
         builder.finishWritingArray(start, offsets)
       case MapType(StringType, valueType, _) =>
@@ -116,7 +117,8 @@ object VariantExpressionEvalUtils {
           val key = keys.getUTF8String(i).toString
           val id = builder.addKey(key)
           fields.add(new VariantBuilder.FieldEntry(key, id, builder.getWritePos - start))
-          buildVariant(builder, values.get(i, valueType), valueType)
+          val value = if (values.isNullAt(i)) null else values.get(i, valueType)
+          buildVariant(builder, value, valueType)
         }
         builder.finishWritingObject(start, fields)
       case StructType(structFields) =>
@@ -127,7 +129,8 @@ object VariantExpressionEvalUtils {
           val key = structFields(i).name
           val id = builder.addKey(key)
           fields.add(new VariantBuilder.FieldEntry(key, id, builder.getWritePos - start))
-          buildVariant(builder, data.get(i, structFields(i).dataType), structFields(i).dataType)
+          val value = if (data.isNullAt(i)) null else data.get(i, structFields(i).dataType)
+          buildVariant(builder, value, structFields(i).dataType)
         }
         builder.finishWritingObject(start, fields)
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
index 3964bf3aedec..53be9d50d351 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
@@ -16,11 +16,13 @@
  */
 package org.apache.spark.sql
 
-import org.apache.spark.sql.catalyst.expressions.{CreateArray, CreateNamedStruct, JsonToStructs, Literal, StructsToJson}
+import org.apache.spark.sql.catalyst.expressions.{Cast, CreateArray, CreateNamedStruct, JsonToStructs, Literal, StructsToJson}
 import org.apache.spark.sql.catalyst.expressions.variant.ParseJson
 import org.apache.spark.sql.execution.WholeStageCodegenExec
+import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
 import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.VariantType
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.vectorized.ColumnarArray
 import org.apache.spark.types.variant.VariantBuilder
 import org.apache.spark.unsafe.types.VariantVal
 
@@ -250,4 +252,31 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession {
         Seq.fill(3)(Row("STRUCT<a: ARRAY<STRING>>")) ++ Seq(Row("STRUCT<a: ARRAY<BIGINT>>")))
     }
   }
+
+  test("cast to variant with ColumnarArray input") {
+    val dataVector = new OnHeapColumnVector(4, LongType)
+    dataVector.appendNull()
+    dataVector.appendLong(123)
+    dataVector.appendNull()
+    dataVector.appendLong(456)
+    val array = new ColumnarArray(dataVector, 0, 4)
+    val variant = Cast(Literal(array, ArrayType(LongType)), VariantType).eval()
+    assert(variant.toString == "[null,123,null,456]")
+    dataVector.close()
+  }
+
+  test("cast to variant with scan input") {
+    withTempPath { dir =>
+      val path = dir.getAbsolutePath
+      val input = Seq(Row(Array(1, null), Map("k1" -> null, "k2" -> false), Row(null, "str")))
+      val schema = StructType.fromDDL(
+        "a array<int>, m map<string, boolean>, s struct<f1 string, f2 string>")
+      spark.createDataFrame(spark.sparkContext.parallelize(input), schema).write.parquet(path)
+      val df = spark.read.parquet(path).selectExpr(
+        s"cast(cast(a as variant) as ${schema(0).dataType.sql})",
+        s"cast(cast(m as variant) as ${schema(1).dataType.sql})",
+        s"cast(cast(s as variant) as ${schema(2).dataType.sql})")
+      checkAnswer(df, input)
+    }
+  }
 }

