This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1ee3d2ae3b63 [SPARK-45827][SQL] Variant fixes with codegen and vectorized reader disabled
1ee3d2ae3b63 is described below

commit 1ee3d2ae3b6305fbaec9b49789f8a8352cb0564d
Author: cashmand <david.cash...@databricks.com>
AuthorDate: Tue Nov 28 17:06:33 2023 +0100

    [SPARK-45827][SQL] Variant fixes with codegen and vectorized reader disabled
    
    ### What changes were proposed in this pull request?
    
    Fix two issues with the new Variant type:
    
    1) In `InterpretedUnsafeProjection`, define element size to be 8, since 
Variant has variable length, so it is categorized as a reference type, which 
always has size 8. This only manifests as an issue when codegen is disabled and 
an array or struct contains Variant values.
    
    2) Define and use a `ParquetGroupConverter` for Variant. The previous tests 
used the vectorized reader, so this issue didn't manifest.
    
    ### Why are the changes needed?
    
    Fixes crashes when Variant is used.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Added a unit test that writes and reads an array of Variant values with 
codegen and the vectorized reader disabled. Reverting either of the two fixes 
causes the test to fail.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #43984 from cashmand/SPARK-45827-fixes.
    
    Authored-by: cashmand <david.cash...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../expressions/InterpretedUnsafeProjection.scala  |  2 +-
 .../datasources/parquet/ParquetRowConverter.scala  | 39 ++++++++++++++-
 .../scala/org/apache/spark/sql/VariantSuite.scala  | 56 ++++++++++++++++++++++
 3 files changed, 95 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
index 50408b41c1a7..a53903a7c16d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
@@ -294,7 +294,7 @@ object InterpretedUnsafeProjection {
    */
   @scala.annotation.tailrec
   private def getElementSize(dataType: DataType): Int = dataType match {
-    case NullType | StringType | BinaryType | CalendarIntervalType |
+    case NullType | StringType | BinaryType | CalendarIntervalType | VariantType |
          _: DecimalType | _: StructType | _: ArrayType | _: MapType => 8
     case udt: UserDefinedType[_] =>
       getElementSize(udt.sqlType)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
index 89c7cae175ae..7bc98974226b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
@@ -42,7 +42,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.datasources.DataSourceUtils
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
 import org.apache.spark.util.collection.Utils
 
 /**
@@ -498,6 +498,9 @@ private[parquet] class ParquetRowConverter(
           int96RebaseSpec,
           wrappedUpdater)
 
+      case t: VariantType =>
+        new ParquetVariantConverter(parquetType.asGroupType(), updater)
+
       case t =>
        throw QueryExecutionErrors.cannotCreateParquetConverterForDataTypeError(
           t, parquetType.toString)
@@ -810,6 +813,40 @@ private[parquet] class ParquetRowConverter(
     }
   }
 
+  /** Parquet converter for Variant */
+  private final class ParquetVariantConverter(
+     parquetType: GroupType,
+     updater: ParentContainerUpdater)
+    extends ParquetGroupConverter(updater) {
+
+    private[this] var currentValue: Any = _
+    private[this] var currentMetadata: Any = _
+
+    private[this] val converters = Array(
+      // Converter for value
+      newConverter(parquetType.getType(0), BinaryType, new ParentContainerUpdater {
+        override def set(value: Any): Unit = currentValue = value
+      }),
+
+      // Converter for metadata
+      newConverter(parquetType.getType(1), BinaryType, new ParentContainerUpdater {
+        override def set(value: Any): Unit = currentMetadata = value
+      }))
+
+    override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex)
+
+    override def end(): Unit = {
+      updater.set(
+        new VariantVal(currentValue.asInstanceOf[Array[Byte]],
+            currentMetadata.asInstanceOf[Array[Byte]]))
+    }
+
+    override def start(): Unit = {
+      currentValue = null
+      currentMetadata = null
+    }
+  }
+
   private trait RepeatedConverter {
     private[this] val currentArray = ArrayBuffer.empty[Any]
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
index 35a1444f0e9d..98d106f05f0c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
@@ -19,8 +19,11 @@ package org.apache.spark.sql
 
 import java.io.File
 
+import scala.collection.mutable
 import scala.util.Random
 
+import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.unsafe.types.VariantVal
@@ -82,4 +85,57 @@ class VariantSuite extends QueryTest with SharedSparkSession {
      assert(prepareAnswer(input) == prepareAnswer(readResult.toImmutableArraySeq))
     }
   }
+
+  test("array of variant") {
+    val rand = new Random(42)
+    val input = Seq.fill(3) {
+      if (rand.nextInt(10) == 0) {
+        null
+      } else {
+        val value = new Array[Byte](rand.nextInt(50))
+        rand.nextBytes(value)
+        val metadata = new Array[Byte](rand.nextInt(50))
+        rand.nextBytes(metadata)
+        val numElements = 3 // rand.nextInt(10)
+        Seq.fill(numElements)(new VariantVal(value, metadata))
+      }
+    }
+
+    val df = spark.createDataFrame(
+      spark.sparkContext.parallelize(input.map { v =>
+        Row.fromSeq(Seq(v))
+      }),
+      StructType.fromDDL("v array<variant>")
+    )
+
+    def prepareAnswer(values: Seq[Row]): Seq[String] = {
+      values.map(_.get(0)).map { v =>
+        if (v == null) {
+          "null"
+        } else {
+          v.asInstanceOf[mutable.ArraySeq[Any]]
+           .map(_.asInstanceOf[VariantVal].debugString()).mkString(",")
+        }
+      }.sorted
+    }
+
+    // Test conversion to UnsafeRow in both codegen and interpreted code paths.
+    val codegenModes = Seq(CodegenObjectFactoryMode.NO_CODEGEN.toString,
+                           CodegenObjectFactoryMode.FALLBACK.toString)
+    codegenModes.foreach { codegen =>
+      withTempDir { dir =>
+        withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegen) {
+          val tempDir = new File(dir, "files").getCanonicalPath
+          df.write.parquet(tempDir)
+          Seq(false, true).foreach { vectorizedReader =>
+            withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key ->
+                vectorizedReader.toString) {
+              val readResult = spark.read.parquet(tempDir).collect().toSeq
+              assert(prepareAnswer(df.collect().toSeq) == prepareAnswer(readResult))
+            }
+          }
+        }
+      }
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to