This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 1ee3d2ae3b63 [SPARK-45827][SQL] Variant fixes with codegen and vectorized reader disabled 1ee3d2ae3b63 is described below commit 1ee3d2ae3b6305fbaec9b49789f8a8352cb0564d Author: cashmand <david.cash...@databricks.com> AuthorDate: Tue Nov 28 17:06:33 2023 +0100 [SPARK-45827][SQL] Variant fixes with codegen and vectorized reader disabled ### What changes were proposed in this pull request? Fix two issues with the new Variant type: 1) In `InterpretedUnsafeProjection`, define the element size of Variant as 8. Variant values have variable length, so Variant is treated as a reference type, and reference types always have an element size of 8. This only manifests as an issue when codegen is disabled and an array or struct contains Variant values. 2) Define and use a `ParquetGroupConverter` for Variant, so that the non-vectorized Parquet reader can read Variant columns. The previous tests used the vectorized reader, so this issue didn't manifest. ### Why are the changes needed? Fixes crashes when Variant values are nested in arrays or structs with codegen disabled, and when Variant columns are read with the non-vectorized Parquet reader. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added a unit test that writes and reads an array of Variant values. It covers both the interpreted and codegen (fallback) projection paths, and both the vectorized and non-vectorized Parquet readers. Reverting either of the two fixes causes the test to fail. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43984 from cashmand/SPARK-45827-fixes. 
Authored-by: cashmand <david.cash...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../expressions/InterpretedUnsafeProjection.scala | 2 +- .../datasources/parquet/ParquetRowConverter.scala | 39 ++++++++++++++- .../scala/org/apache/spark/sql/VariantSuite.scala | 56 ++++++++++++++++++++++ 3 files changed, 95 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index 50408b41c1a7..a53903a7c16d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -294,7 +294,7 @@ object InterpretedUnsafeProjection { */ @scala.annotation.tailrec private def getElementSize(dataType: DataType): Int = dataType match { - case NullType | StringType | BinaryType | CalendarIntervalType | + case NullType | StringType | BinaryType | CalendarIntervalType | VariantType | _: DecimalType | _: StructType | _: ArrayType | _: MapType => 8 case udt: UserDefinedType[_] => getElementSize(udt.sqlType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala index 89c7cae175ae..7bc98974226b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.datasources.DataSourceUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import 
org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types.{UTF8String, VariantVal} import org.apache.spark.util.collection.Utils /** @@ -498,6 +498,9 @@ private[parquet] class ParquetRowConverter( int96RebaseSpec, wrappedUpdater) + case t: VariantType => + new ParquetVariantConverter(parquetType.asGroupType(), updater) + case t => throw QueryExecutionErrors.cannotCreateParquetConverterForDataTypeError( t, parquetType.toString) @@ -810,6 +813,40 @@ private[parquet] class ParquetRowConverter( } } + /** Parquet converter for Variant */ + private final class ParquetVariantConverter( + parquetType: GroupType, + updater: ParentContainerUpdater) + extends ParquetGroupConverter(updater) { + + private[this] var currentValue: Any = _ + private[this] var currentMetadata: Any = _ + + private[this] val converters = Array( + // Converter for value + newConverter(parquetType.getType(0), BinaryType, new ParentContainerUpdater { + override def set(value: Any): Unit = currentValue = value + }), + + // Converter for metadata + newConverter(parquetType.getType(1), BinaryType, new ParentContainerUpdater { + override def set(value: Any): Unit = currentMetadata = value + })) + + override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex) + + override def end(): Unit = { + updater.set( + new VariantVal(currentValue.asInstanceOf[Array[Byte]], + currentMetadata.asInstanceOf[Array[Byte]])) + } + + override def start(): Unit = { + currentValue = null + currentMetadata = null + } + } + private trait RepeatedConverter { private[this] val currentArray = ArrayBuffer.empty[Any] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala index 35a1444f0e9d..98d106f05f0c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala @@ -19,8 +19,11 @@ package org.apache.spark.sql import 
java.io.File +import scala.collection.mutable import scala.util.Random +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.types.VariantVal @@ -82,4 +85,57 @@ class VariantSuite extends QueryTest with SharedSparkSession { assert(prepareAnswer(input) == prepareAnswer(readResult.toImmutableArraySeq)) } } + + test("array of variant") { + val rand = new Random(42) + val input = Seq.fill(3) { + if (rand.nextInt(10) == 0) { + null + } else { + val value = new Array[Byte](rand.nextInt(50)) + rand.nextBytes(value) + val metadata = new Array[Byte](rand.nextInt(50)) + rand.nextBytes(metadata) + val numElements = 3 // rand.nextInt(10) + Seq.fill(numElements)(new VariantVal(value, metadata)) + } + } + + val df = spark.createDataFrame( + spark.sparkContext.parallelize(input.map { v => + Row.fromSeq(Seq(v)) + }), + StructType.fromDDL("v array<variant>") + ) + + def prepareAnswer(values: Seq[Row]): Seq[String] = { + values.map(_.get(0)).map { v => + if (v == null) { + "null" + } else { + v.asInstanceOf[mutable.ArraySeq[Any]] + .map(_.asInstanceOf[VariantVal].debugString()).mkString(",") + } + }.sorted + } + + // Test conversion to UnsafeRow in both codegen and interpreted code paths. 
+ val codegenModes = Seq(CodegenObjectFactoryMode.NO_CODEGEN.toString, + CodegenObjectFactoryMode.FALLBACK.toString) + codegenModes.foreach { codegen => + withTempDir { dir => + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegen) { + val tempDir = new File(dir, "files").getCanonicalPath + df.write.parquet(tempDir) + Seq(false, true).foreach { vectorizedReader => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> + vectorizedReader.toString) { + val readResult = spark.read.parquet(tempDir).collect().toSeq + assert(prepareAnswer(df.collect().toSeq) == prepareAnswer(readResult)) + } + } + } + } + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org