This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 2c1c4d2614ae [SPARK-50644][SQL] Read variant struct in Parquet reader
2c1c4d2614ae is described below
commit 2c1c4d2614ae1ff902c244209f7ec3c79102d3e0
Author: Chenhao Li <[email protected]>
AuthorDate: Tue Dec 24 14:54:02 2024 +0800
[SPARK-50644][SQL] Read variant struct in Parquet reader
### What changes were proposed in this pull request?
This PR adds support for reading variant structs in the Parquet reader. The concept of a
variant struct was introduced in https://github.com/apache/spark/pull/49235: it
includes all the fields that the query requests to extract from a variant column.
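For illustration only (the exact variant struct layout and its per-field metadata are defined in https://github.com/apache/spark/pull/49235; the table path, column name `v`, and extraction paths below are hypothetical), a scan such as the following requests only two extracted fields, so the reader can produce a struct of exactly those fields instead of the full variant:
```scala
// Hypothetical example: `v` is a (possibly shredded) variant column.
// With this change, the Parquet reader can return a "variant struct" holding only
// the two requested extractions, rather than assembling the whole variant value.
spark.read.parquet("/path/to/table")
  .selectExpr(
    "variant_get(v, '$.a', 'int')",     // requested field 1
    "variant_get(v, '$.b', 'string')")  // requested field 2
```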
### Why are the changes needed?
By producing the variant struct directly in the Parquet reader, we can avoid
reading and rebuilding the full variant, enabling more efficient variant
processing.
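As a rough sketch of why this helps (simplified; modeled on the schemas used in `VariantShreddingSuite` below, with the authoritative layout defined by the Parquet VariantShredding spec), a shredded variant column stores extracted fields in `typed_value` next to the leftover binary `value`. To answer `variant_get(v, '$.a', 'int')`, the reader can consult `typed_value.a` and fall back to `value` only when needed, instead of rebuilding the full variant for every row:
```scala
import org.apache.spark.sql.types.StructType

// Simplified physical schema of a shredded variant column `v` with one shredded field `a`.
// `value` holds the leftover (unshredded) part of the variant in binary form.
val shreddedFileSchema: StructType = StructType.fromDDL(
  "v struct<metadata binary, value binary, " +
    "typed_value struct<a struct<value binary, typed_value int>>>")
```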
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Unit test.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #49263 from chenhao-db/spark_variant_struct_reader.
Authored-by: Chenhao Li <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../apache/spark/types/variant/ShreddingUtils.java | 9 +-
.../apache/spark/types/variant/VariantSchema.java | 6 +
.../datasources/parquet/ParquetColumnVector.java | 24 +-
.../datasources/parquet/ParquetReadSupport.scala | 9 +
.../datasources/parquet/ParquetRowConverter.scala | 26 +-
.../parquet/ParquetSchemaConverter.scala | 4 +
.../datasources/parquet/SparkShreddingUtils.scala | 597 ++++++++++++++++++++-
.../apache/spark/sql/VariantShreddingSuite.scala | 185 ++++++-
8 files changed, 820 insertions(+), 40 deletions(-)
diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/ShreddingUtils.java b/common/variant/src/main/java/org/apache/spark/types/variant/ShreddingUtils.java
index 59e16b77ab01..6a04bf9a2b25 100644
--- a/common/variant/src/main/java/org/apache/spark/types/variant/ShreddingUtils.java
+++ b/common/variant/src/main/java/org/apache/spark/types/variant/ShreddingUtils.java
@@ -49,9 +49,8 @@ public class ShreddingUtils {
throw malformedVariant();
}
byte[] metadata = row.getBinary(schema.topLevelMetadataIdx);
- if (schema.variantIdx >= 0 && schema.typedIdx < 0) {
- // The variant is unshredded. We are not required to do anything special, but we can have an
- // optimization to avoid `rebuild`.
+ if (schema.isUnshredded()) {
+ // `rebuild` is unnecessary for unshredded variant.
if (row.isNullAt(schema.variantIdx)) {
throw malformedVariant();
}
@@ -65,8 +64,8 @@ public class ShreddingUtils {
// Rebuild a variant value from the shredded data according to the reconstruction algorithm in
// https://github.com/apache/parquet-format/blob/master/VariantShredding.md.
// Append the result to `builder`.
- private static void rebuild(ShreddedRow row, byte[] metadata, VariantSchema schema,
- VariantBuilder builder) {
+ public static void rebuild(ShreddedRow row, byte[] metadata, VariantSchema schema,
+ VariantBuilder builder) {
int typedIdx = schema.typedIdx;
int variantIdx = schema.variantIdx;
if (typedIdx >= 0 && !row.isNullAt(typedIdx)) {
diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/VariantSchema.java b/common/variant/src/main/java/org/apache/spark/types/variant/VariantSchema.java
index 551e46214859..d1e6cc3a727f 100644
--- a/common/variant/src/main/java/org/apache/spark/types/variant/VariantSchema.java
+++ b/common/variant/src/main/java/org/apache/spark/types/variant/VariantSchema.java
@@ -138,6 +138,12 @@ public class VariantSchema {
this.arraySchema = arraySchema;
}
+ // Return whether the variant column is unshredded. The user is not required to do anything
+ // special, but can have certain optimizations for unshredded variant.
+ public boolean isUnshredded() {
+ return topLevelMetadataIdx >= 0 && variantIdx >= 0 && typedIdx < 0;
+ }
+
@Override
public String toString() {
return "VariantSchema{" +
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnVector.java
index 0b9a25fc46a0..7fb8be7caf28 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnVector.java
@@ -35,7 +35,6 @@ import org.apache.spark.sql.types.MapType;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.VariantType;
import org.apache.spark.types.variant.VariantSchema;
-import org.apache.spark.unsafe.types.VariantVal;
/**
* Contains necessary information representing a Parquet column, either of primitive or nested type.
@@ -49,6 +48,9 @@ final class ParquetColumnVector {
// contains only one child that reads the underlying file content. This `ParquetColumnVector`
// should assemble Spark variant values from the file content.
private VariantSchema variantSchema;
+ // Only meaningful if `variantSchema` is not null. See `SparkShreddingUtils.getFieldsToExtract`
+ // for its meaning.
+ private FieldToExtract[] fieldsToExtract;
/**
* Repetition & Definition levels
@@ -117,6 +119,7 @@ final class ParquetColumnVector {
fileContent, capacity, memoryMode, missingColumns, false, null);
children.add(contentVector);
variantSchema = SparkShreddingUtils.buildVariantSchema(fileContentCol.sparkType());
+ fieldsToExtract = SparkShreddingUtils.getFieldsToExtract(column.sparkType(), variantSchema);
repetitionLevels = contentVector.repetitionLevels;
definitionLevels = contentVector.definitionLevels;
} else if (isPrimitive) {
@@ -188,20 +191,11 @@ final class ParquetColumnVector {
if (variantSchema != null) {
children.get(0).assemble();
WritableColumnVector fileContent = children.get(0).getValueVector();
- int numRows = fileContent.getElementsAppended();
- vector.reset();
- vector.reserve(numRows);
- WritableColumnVector valueChild = vector.getChild(0);
- WritableColumnVector metadataChild = vector.getChild(1);
- for (int i = 0; i < numRows; ++i) {
- if (fileContent.isNullAt(i)) {
- vector.appendStruct(true);
- } else {
- vector.appendStruct(false);
- VariantVal v = SparkShreddingUtils.rebuild(fileContent.getStruct(i), variantSchema);
- valueChild.appendByteArray(v.getValue(), 0, v.getValue().length);
- metadataChild.appendByteArray(v.getMetadata(), 0, v.getMetadata().length);
- }
+ if (fieldsToExtract == null) {
+ SparkShreddingUtils.assembleVariantBatch(fileContent, vector, variantSchema);
+ } else {
+ SparkShreddingUtils.assembleVariantStructBatch(fileContent, vector, variantSchema,
+ fieldsToExtract);
}
return;
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala
index 8dde02a4673f..af0bf0d51f07 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala
@@ -35,6 +35,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.execution.datasources.VariantMetadata
import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
import org.apache.spark.sql.types._
@@ -221,6 +222,9 @@ object ParquetReadSupport extends Logging {
clipParquetMapType(
parquetType.asGroupType(), t.keyType, t.valueType, caseSensitive, useFieldId)
+ case t: StructType if VariantMetadata.isVariantStruct(t) =>
+ clipVariantSchema(parquetType.asGroupType(), t)
+
case t: StructType =>
clipParquetGroup(parquetType.asGroupType(), t, caseSensitive, useFieldId)
@@ -390,6 +394,11 @@ object ParquetReadSupport extends Logging {
.named(parquetRecord.getName)
}
+ private def clipVariantSchema(parquetType: GroupType, variantStruct: StructType): GroupType = {
+ // TODO(SHREDDING): clip `parquetType` to retain the necessary columns.
+ parquetType
+ }
+
/**
* Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[StructType]].
*
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
index 3ed7fe37ccd9..550c2af43a70 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
@@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.execution.datasources.DataSourceUtils
-import org.apache.spark.sql.execution.datasources.DataSourceUtils
+import org.apache.spark.sql.execution.datasources.{DataSourceUtils, VariantMetadata}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
@@ -498,6 +498,9 @@ private[parquet] class ParquetRowConverter(
case t: MapType =>
new ParquetMapConverter(parquetType.asGroupType(), t, updater)
+ case t: StructType if VariantMetadata.isVariantStruct(t) =>
+ new ParquetVariantConverter(t, parquetType.asGroupType(), updater)
+
case t: StructType =>
val wrappedUpdater = {
// SPARK-30338: avoid unnecessary InternalRow copying for nested structs:
@@ -536,12 +539,7 @@ private[parquet] class ParquetRowConverter(
case t: VariantType =>
if (SQLConf.get.getConf(SQLConf.VARIANT_ALLOW_READING_SHREDDED)) {
- // Infer a Spark type from `parquetType`. This piece of code is copied from
- // `ParquetArrayConverter`.
- val messageType = Types.buildMessage().addField(parquetType).named("foo")
- val column = new ColumnIOFactory().getColumnIO(messageType)
- val parquetSparkType = schemaConverter.convertField(column.getChild(0)).sparkType
- new ParquetVariantConverter(parquetType.asGroupType(), parquetSparkType, updater)
+ new ParquetVariantConverter(t, parquetType.asGroupType(), updater)
} else {
new ParquetUnshreddedVariantConverter(parquetType.asGroupType(), updater)
}
@@ -909,13 +907,14 @@ private[parquet] class ParquetRowConverter(
/** Parquet converter for Variant (shredded or unshredded) */
private final class ParquetVariantConverter(
- parquetType: GroupType,
- parquetSparkType: DataType,
- updater: ParentContainerUpdater)
+ targetType: DataType, parquetType: GroupType, updater: ParentContainerUpdater)
extends ParquetGroupConverter(updater) {
private[this] var currentRow: Any = _
+ private[this] val parquetSparkType = SparkShreddingUtils.parquetTypeToSparkType(parquetType)
private[this] val variantSchema = SparkShreddingUtils.buildVariantSchema(parquetSparkType)
+ private[this] val fieldsToExtract =
+ SparkShreddingUtils.getFieldsToExtract(targetType, variantSchema)
// A struct converter that reads the underlying file data.
private[this] val fileConverter = new ParquetRowConverter(
schemaConverter,
@@ -932,7 +931,12 @@ private[parquet] class ParquetRowConverter(
override def end(): Unit = {
fileConverter.end()
- val v = SparkShreddingUtils.rebuild(currentRow.asInstanceOf[InternalRow], variantSchema)
+ val row = currentRow.asInstanceOf[InternalRow]
+ val v = if (fieldsToExtract == null) {
+ SparkShreddingUtils.assembleVariant(row, variantSchema)
+ } else {
+ SparkShreddingUtils.assembleVariantStruct(row, variantSchema, fieldsToExtract)
+ }
updater.set(v)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
index 7f1b49e73790..64c2a3126ca9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
@@ -28,6 +28,7 @@ import org.apache.parquet.schema.Type.Repetition._
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.execution.datasources.VariantMetadata
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -185,6 +186,9 @@ class ParquetToSparkSchemaConverter(
} else {
convertVariantField(groupColumn)
}
+ case groupColumn: GroupColumnIO if targetType.exists(VariantMetadata.isVariantStruct) =>
+ val col = convertGroupField(groupColumn)
+ col.copy(sparkType = targetType.get, variantFileType = Some(col))
case groupColumn: GroupColumnIO => convertGroupField(groupColumn, targetType)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
index f38e188ed042..a83ca78455fa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
@@ -17,12 +17,23 @@
package org.apache.spark.sql.execution.datasources.parquet
+import org.apache.parquet.io.ColumnIOFactory
+import org.apache.parquet.schema.{Type => ParquetType, Types => ParquetTypes}
+
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
-import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.expressions.variant._
+import org.apache.spark.sql.catalyst.expressions.variant.VariantPathParser.PathSegment
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
+import org.apache.spark.sql.execution.RowToColumnConverter
+import org.apache.spark.sql.execution.datasources.VariantMetadata
+import org.apache.spark.sql.execution.vectorized.WritableColumnVector
import org.apache.spark.sql.types._
import org.apache.spark.types.variant._
+import org.apache.spark.types.variant.VariantUtil.Type
import org.apache.spark.unsafe.types._
case class SparkShreddedRow(row: SpecializedGetters) extends ShreddingUtils.ShreddedRow {
@@ -45,6 +56,369 @@ case class SparkShreddedRow(row: SpecializedGetters) extends ShreddingUtils.Shre
override def numElements(): Int = row.asInstanceOf[ArrayData].numElements()
}
+// The search result of a `PathSegment` in a `VariantSchema`.
+case class SchemaPathSegment(
+ rawPath: PathSegment,
+ // Whether this path segment is an object or array extraction.
+ isObject: Boolean,
+ // `schema.typedIdx`, if the path exists in the schema (for object extraction, the schema
+ // should contain an object `typed_value` containing the requested field; similar for array
+ // extraction). Negative otherwise.
+ typedIdx: Int,
+ // For object extraction, it is the index of the desired field in `schema.objectSchema`. If the
+ // requested field doesn't exist, both `extractionIdx/typedIdx` are set to negative.
+ // For array extraction, it is the array index. The information is already stored in `rawPath`,
+ // but accessing a raw int should be more efficient than `rawPath`, which is an `Either`.
+ extractionIdx: Int)
+
+// Represent a single field in a variant struct (see `VariantMetadata` for definition), that is, a
+// single requested field that the scan should produce by extracting from the variant column.
+case class FieldToExtract(path: Array[SchemaPathSegment], reader: ParquetVariantReader)
+
+// A helper class to cast from scalar `typed_value` into a scalar `dataType`. Need a custom
+// expression because it has different error reporting code than `Cast`.
+case class ScalarCastHelper(
+ child: Expression,
+ dataType: DataType,
+ castArgs: VariantCastArgs) extends UnaryExpression {
+ // The expression is only for the internal use of `ScalarReader`, which can guarantee the child
+ // is not nullable.
+ assert(!child.nullable)
+
+ // If `cast` is null, it means the cast always fails because the type combination is not allowed.
+ private val cast = if (Cast.canAnsiCast(child.dataType, dataType)) {
+ Cast(child, dataType, castArgs.zoneStr, EvalMode.TRY)
+ } else {
+ null
+ }
+ // Cast the input to string. Only used for reporting an invalid cast.
+ private val castToString = Cast(child, StringType, castArgs.zoneStr, EvalMode.ANSI)
+
+ override def nullable: Boolean = !castArgs.failOnError
+ override def withNewChildInternal(newChild: Expression): UnaryExpression = copy(child = newChild)
+
+ // No need to define the interpreted version of `eval`: the codegen must succeed.
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ // Throw an error or do nothing, depending on `castArgs.failOnError`.
+ val invalidCastCode = if (castArgs.failOnError) {
+ val castToStringCode = castToString.genCode(ctx)
+ val typeObj = ctx.addReferenceObj("dataType", dataType)
+ val cls = classOf[ScalarCastHelper].getName
+ s"""
+ ${castToStringCode.code}
+ $cls.throwInvalidVariantCast(${castToStringCode.value}, $typeObj);
+ """
+ } else {
+ ""
+ }
+ if (cast != null) {
+ val castCode = cast.genCode(ctx)
+ val code = code"""
+ ${castCode.code}
+ boolean ${ev.isNull} = ${castCode.isNull};
+ ${CodeGenerator.javaType(dataType)} ${ev.value} = ${castCode.value};
+ if (${ev.isNull}) { $invalidCastCode }
+ """
+ ev.copy(code = code)
+ } else {
+ val code = code"""
+ boolean ${ev.isNull} = true;
+ ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+ if (${ev.isNull}) { $invalidCastCode }
+ """
+ ev.copy(code = code)
+ }
+ }
+}
+
+object ScalarCastHelper {
+ // A helper function for codegen. The java compiler doesn't allow throwing a `Throwable` in a
+ // method without `throws` annotation.
+ def throwInvalidVariantCast(value: UTF8String, dataType: DataType): Any =
+ throw QueryExecutionErrors.invalidVariantCast(value.toString, dataType)
+}
+
+// The base class to read Parquet variant values into a Spark type.
+// For convenience, we also allow creating an instance of the base class itself. None of its
+// functions can be used, but it can serve as a container of `targetType` and `castArgs`.
+class ParquetVariantReader(
+ val schema: VariantSchema, val targetType: DataType, val castArgs: VariantCastArgs) {
+ // Read from a row containing a Parquet variant value (shredded or unshredded) and return a value
+ // of `targetType`. The row schema is described by `schema`.
+ // This function throws MALFORMED_VARIANT if the variant is missing. If the variant can be
+ // legally missing (the only possible situation is struct fields in object `typed_value`), the
+ // caller should check for it and avoid calling this function if the variant is missing.
+ def read(row: InternalRow, topLevelMetadata: Array[Byte]): Any = {
+ if (schema.typedIdx < 0 || row.isNullAt(schema.typedIdx)) {
+ if (schema.variantIdx < 0 || row.isNullAt(schema.variantIdx)) {
+ // Both `typed_value` and `value` are null, meaning the variant is missing.
+ throw QueryExecutionErrors.malformedVariant()
+ }
+ val v = new Variant(row.getBinary(schema.variantIdx), topLevelMetadata)
+ VariantGet.cast(v, targetType, castArgs)
+ } else {
+ readFromTyped(row, topLevelMetadata)
+ }
+ }
+
+ // Subclasses should override it to produce the read result when `typed_value` is not null.
+ protected def readFromTyped(row: InternalRow, topLevelMetadata: Array[Byte]): Any =
+ throw QueryExecutionErrors.unreachableError()
+
+ // A util function to rebuild the variant in binary format from a Parquet
variant value.
+ protected final def rebuildVariant(row: InternalRow, topLevelMetadata:
Array[Byte]): Variant = {
+ val builder = new VariantBuilder(false)
+ ShreddingUtils.rebuild(SparkShreddedRow(row), topLevelMetadata, schema,
builder)
+ builder.result()
+ }
+
+ // A util function to throw error or return null when an invalid cast
happens.
+ protected final def invalidCast(row: InternalRow, topLevelMetadata:
Array[Byte]): Any = {
+ if (castArgs.failOnError) {
+ throw QueryExecutionErrors.invalidVariantCast(
+ rebuildVariant(row, topLevelMetadata).toJson(castArgs.zoneId),
targetType)
+ } else {
+ null
+ }
+ }
+}
+
+object ParquetVariantReader {
+ // Create a reader for `targetType`. If `schema` is null, meaning that the extraction path doesn't
+ // exist in `typed_value`, it returns an instance of `ParquetVariantReader`. As described in the
+ // class comment, the reader is only a container of `targetType` and `castArgs` in this case.
+ def apply(schema: VariantSchema, targetType: DataType, castArgs:
VariantCastArgs,
+ isTopLevelUnshredded: Boolean = false): ParquetVariantReader =
targetType match {
+ case _ if schema == null => new ParquetVariantReader(schema, targetType,
castArgs)
+ case s: StructType => new StructReader(schema, s, castArgs)
+ case a: ArrayType => new ArrayReader(schema, a, castArgs)
+ case m@MapType(_: StringType, _, _) => new MapReader(schema, m, castArgs)
+ case v: VariantType => new VariantReader(schema, v, castArgs,
isTopLevelUnshredded)
+ case s: AtomicType => new ScalarReader(schema, s, castArgs)
+ case _ =>
+ // Type check should have rejected map with non-string type.
+ throw QueryExecutionErrors.unreachableError(s"Invalid target type:
`${targetType.sql}`")
+ }
+}
+
+// Read Parquet variant values into a Spark struct type. It reads unshredded fields (fields that are
+// not in the typed object) from the `value`, and reads the shredded fields from the object
+// `typed_value`.
+// `value` must not contain any shredded field according to the shredding spec, but this requirement
+// is not enforced. If `value` does contain a shredded field, no error will occur, and the field in
+// object `typed_value` will be the final result.
+private[this] final class StructReader(
+ schema: VariantSchema, targetType: StructType, castArgs: VariantCastArgs)
+ extends ParquetVariantReader(schema, targetType, castArgs) {
+ // For each field in `targetType`, store the index of the field with the
same name in object
+ // `typed_value`, or -1 if it doesn't exist in object `typed_value`.
+ private[this] val fieldInputIndices: Array[Int] = targetType.fields.map { f
=>
+ val inputIdx = if (schema.objectSchemaMap != null)
schema.objectSchemaMap.get(f.name) else null
+ if (inputIdx != null) inputIdx.intValue() else -1
+ }
+ // For each field in `targetType`, store the reader from the corresponding
field in object
+ // `typed_value`, or null if it doesn't exist in object `typed_value`.
+ private[this] val fieldReaders: Array[ParquetVariantReader] =
+ targetType.fields.zip(fieldInputIndices).map { case (f, inputIdx) =>
+ if (inputIdx >= 0) {
+ val fieldSchema = schema.objectSchema(inputIdx).schema
+ ParquetVariantReader(fieldSchema, f.dataType, castArgs)
+ } else {
+ null
+ }
+ }
+ // If all fields in `targetType` can be found in object `typed_value`, then
the reader doesn't
+ // need to read from `value`.
+ private[this] val needUnshreddedObject: Boolean = fieldInputIndices.exists(_
< 0)
+
+ override def readFromTyped(row: InternalRow, topLevelMetadata: Array[Byte]):
Any = {
+ if (schema.objectSchema == null) return invalidCast(row, topLevelMetadata)
+ val obj = row.getStruct(schema.typedIdx, schema.objectSchema.length)
+ val result = new GenericInternalRow(fieldInputIndices.length)
+ var unshreddedObject: Variant = null
+ if (needUnshreddedObject && schema.variantIdx >= 0 &&
!row.isNullAt(schema.variantIdx)) {
+ unshreddedObject = new Variant(row.getBinary(schema.variantIdx),
topLevelMetadata)
+ if (unshreddedObject.getType != Type.OBJECT) throw
QueryExecutionErrors.malformedVariant()
+ }
+ val numFields = fieldInputIndices.length
+ var i = 0
+ while (i < numFields) {
+ val inputIdx = fieldInputIndices(i)
+ if (inputIdx >= 0) {
+ // Shredded field must not be null.
+ if (obj.isNullAt(inputIdx)) throw
QueryExecutionErrors.malformedVariant()
+ val fieldSchema = schema.objectSchema(inputIdx).schema
+ val fieldInput = obj.getStruct(inputIdx, fieldSchema.numFields)
+ // Only read from the shredded field if it is not missing.
+ if ((fieldSchema.typedIdx >= 0 &&
!fieldInput.isNullAt(fieldSchema.typedIdx)) ||
+ (fieldSchema.variantIdx >= 0 &&
!fieldInput.isNullAt(fieldSchema.variantIdx))) {
+ result.update(i, fieldReaders(i).read(fieldInput, topLevelMetadata))
+ }
+ } else if (unshreddedObject != null) {
+ val fieldName = targetType.fields(i).name
+ val fieldType = targetType.fields(i).dataType
+ val unshreddedField = unshreddedObject.getFieldByKey(fieldName)
+ if (unshreddedField != null) {
+ result.update(i, VariantGet.cast(unshreddedField, fieldType,
castArgs))
+ }
+ }
+ i += 1
+ }
+ result
+ }
+}
+
+// Read Parquet variant values into a Spark array type.
+private[this] final class ArrayReader(
+ schema: VariantSchema, targetType: ArrayType, castArgs: VariantCastArgs)
+ extends ParquetVariantReader(schema, targetType, castArgs) {
+ private[this] val elementReader = if (schema.arraySchema != null) {
+ ParquetVariantReader(schema.arraySchema, targetType.elementType, castArgs)
+ } else {
+ null
+ }
+
+ override def readFromTyped(row: InternalRow, topLevelMetadata: Array[Byte]):
Any = {
+ if (schema.arraySchema == null) return invalidCast(row, topLevelMetadata)
+ val elementNumFields = schema.arraySchema.numFields
+ val arr = row.getArray(schema.typedIdx)
+ val size = arr.numElements()
+ val result = new Array[Any](size)
+ var i = 0
+ while (i < size) {
+ // Shredded array element must not be null.
+ if (arr.isNullAt(i)) throw QueryExecutionErrors.malformedVariant()
+ result(i) = elementReader.read(arr.getStruct(i, elementNumFields),
topLevelMetadata)
+ i += 1
+ }
+ new GenericArrayData(result)
+ }
+}
+
+// Read Parquet variant values into a Spark map type with string key type. The input must be object
+// for a valid cast. The resulting map contains shredded fields from object `typed_value` and
+// unshredded fields from object `value`.
+// `value` must not contain any shredded field according to the shredding spec. Unlike
+// `StructReader`, this requirement is enforced in `MapReader`. If `value` does contain a shredded
+// field, throw a MALFORMED_VARIANT error. The purpose is to avoid duplicate map keys.
+private[this] final class MapReader(
+ schema: VariantSchema, targetType: MapType, castArgs: VariantCastArgs)
+ extends ParquetVariantReader(schema, targetType, castArgs) {
+ // Readers that convert each shredded field into the map value type.
+ private[this] val valueReaders = if (schema.objectSchema != null) {
+ schema.objectSchema.map { f =>
+ ParquetVariantReader(f.schema, targetType.valueType, castArgs)
+ }
+ } else {
+ null
+ }
+ // `UTF8String` representation of shredded field names. Do the `String ->
UTF8String` once, so
+ // that `readFromTyped` doesn't need to do it repeatedly.
+ private[this] val shreddedFieldNames = if (schema.objectSchema != null) {
+ schema.objectSchema.map { f => UTF8String.fromString(f.fieldName) }
+ } else {
+ null
+ }
+
+ override def readFromTyped(row: InternalRow, topLevelMetadata: Array[Byte]):
Any = {
+ if (schema.objectSchema == null) return invalidCast(row, topLevelMetadata)
+ val obj = row.getStruct(schema.typedIdx, schema.objectSchema.length)
+ val numShreddedFields = valueReaders.length
+ var unshreddedObject: Variant = null
+ if (schema.variantIdx >= 0 && !row.isNullAt(schema.variantIdx)) {
+ unshreddedObject = new Variant(row.getBinary(schema.variantIdx),
topLevelMetadata)
+ if (unshreddedObject.getType != Type.OBJECT) throw
QueryExecutionErrors.malformedVariant()
+ }
+ val numUnshreddedFields = if (unshreddedObject != null)
unshreddedObject.objectSize() else 0
+ var keyArray = new Array[UTF8String](numShreddedFields +
numUnshreddedFields)
+ var valueArray = new Array[Any](numShreddedFields + numUnshreddedFields)
+ var mapLength = 0
+ var i = 0
+ while (i < numShreddedFields) {
+ // Shredded field must not be null.
+ if (obj.isNullAt(i)) throw QueryExecutionErrors.malformedVariant()
+ val fieldSchema = schema.objectSchema(i).schema
+ val fieldInput = obj.getStruct(i, fieldSchema.numFields)
+ // Only add the shredded field to map if it is not missing.
+ if ((fieldSchema.typedIdx >= 0 &&
!fieldInput.isNullAt(fieldSchema.typedIdx)) ||
+ (fieldSchema.variantIdx >= 0 &&
!fieldInput.isNullAt(fieldSchema.variantIdx))) {
+ keyArray(mapLength) = shreddedFieldNames(i)
+ valueArray(mapLength) = valueReaders(i).read(fieldInput,
topLevelMetadata)
+ mapLength += 1
+ }
+ i += 1
+ }
+ i = 0
+ while (i < numUnshreddedFields) {
+ val field = unshreddedObject.getFieldAtIndex(i)
+ if (schema.objectSchemaMap.containsKey(field.key)) {
+ throw QueryExecutionErrors.malformedVariant()
+ }
+ keyArray(mapLength) = UTF8String.fromString(field.key)
+ valueArray(mapLength) = VariantGet.cast(field.value,
targetType.valueType, castArgs)
+ mapLength += 1
+ i += 1
+ }
+ // Need to shrink the arrays if there are missing shredded fields.
+ if (mapLength < keyArray.length) {
+ keyArray = keyArray.slice(0, mapLength)
+ valueArray = valueArray.slice(0, mapLength)
+ }
+ ArrayBasedMapData(keyArray, valueArray)
+ }
+}
+
+// Read Parquet variant values into a Spark variant type (the binary format).
+private[this] final class VariantReader(
+ schema: VariantSchema, targetType: DataType, castArgs: VariantCastArgs,
+ // An optional optimization: the user can set it to true if the Parquet variant column is
+ // unshredded and the extraction path is empty. We are not required to do anything special, but
+ // we can avoid rebuilding the variant as an optimization.
+ private[this] val isTopLevelUnshredded: Boolean)
+ extends ParquetVariantReader(schema, targetType, castArgs) {
+ override def read(row: InternalRow, topLevelMetadata: Array[Byte]): Any = {
+ if (isTopLevelUnshredded) {
+ if (row.isNullAt(schema.variantIdx)) throw
QueryExecutionErrors.malformedVariant()
+ return new VariantVal(row.getBinary(schema.variantIdx), topLevelMetadata)
+ }
+ val v = rebuildVariant(row, topLevelMetadata)
+ new VariantVal(v.getValue, v.getMetadata)
+ }
+}
+
+// Read Parquet variant values into a Spark scalar type. When `typed_value` is not null but not a
+// scalar, all other target types should return an invalid cast, but only the string target type can
+// still build a string from array/object `typed_value`. For scalar `typed_value`, it depends on
+// `ScalarCastHelper` to perform the cast.
+// According to the shredding spec, scalar `typed_value` and `value` must not be non-null at the
+// same time. The requirement is not enforced in this reader. If they are both non-null, no error
+// will occur, and the reader will read from `typed_value`.
+private[this] final class ScalarReader(
+ schema: VariantSchema, targetType: DataType, castArgs: VariantCastArgs)
+ extends ParquetVariantReader(schema, targetType, castArgs) {
+ private[this] val castProject = if (schema.scalarSchema != null) {
+ val scalarType =
SparkShreddingUtils.scalarSchemaToSparkType(schema.scalarSchema)
+ // Read the cast input from ordinal `schema.typedIdx` in the input row.
The cast input is never
+ // null, because `readFromTyped` is only called when `typed_value` is not
null.
+ val input = BoundReference(schema.typedIdx, scalarType, nullable = false)
+ MutableProjection.create(Seq(ScalarCastHelper(input, targetType,
castArgs)))
+ } else {
+ null
+ }
+
+ override def readFromTyped(row: InternalRow, topLevelMetadata: Array[Byte]):
Any = {
+ if (castProject == null) {
+ return if (targetType.isInstanceOf[StringType]) {
+ UTF8String.fromString(rebuildVariant(row,
topLevelMetadata).toJson(castArgs.zoneId))
+ } else {
+ invalidCast(row, topLevelMetadata)
+ }
+ }
+ val result = castProject(row)
+ if (result.isNullAt(0)) null else result.get(0, targetType)
+ }
+}
+
case object SparkShreddingUtils {
val VariantValueFieldName = "value";
val TypedValueFieldName = "typed_value";
@@ -126,6 +500,11 @@ case object SparkShreddingUtils {
var objectSchema: Array[VariantSchema.ObjectField] = null
var arraySchema: VariantSchema = null
+ // The struct must not be empty or contain duplicate field names. The latter is enforced in the
+ // loop below (`if (typedIdx != -1)` and other similar checks).
+ if (schema.fields.isEmpty) {
+ throw QueryCompilationErrors.invalidVariantShreddingSchema(schema)
+ }
schema.fields.zipWithIndex.foreach { case (f, i) =>
f.name match {
case TypedValueFieldName =>
@@ -135,8 +514,11 @@ case object SparkShreddingUtils {
typedIdx = i
f.dataType match {
case StructType(fields) =>
- objectSchema =
- new Array[VariantSchema.ObjectField](fields.length)
+ // The struct must not be empty or contain duplicate field names.
+ if (fields.isEmpty || fields.map(_.name).distinct.length !=
fields.length) {
+ throw
QueryCompilationErrors.invalidVariantShreddingSchema(schema)
+ }
+ objectSchema = new
Array[VariantSchema.ObjectField](fields.length)
fields.zipWithIndex.foreach { case (field, fieldIdx) =>
field.dataType match {
case s: StructType =>
@@ -188,6 +570,32 @@ case object SparkShreddingUtils {
scalarSchema, objectSchema, arraySchema)
}
+ // Convert a scalar variant schema into a Spark scalar type.
+ def scalarSchemaToSparkType(scalar: VariantSchema.ScalarType): DataType =
scalar match {
+ case _: VariantSchema.StringType => StringType
+ case it: VariantSchema.IntegralType => it.size match {
+ case VariantSchema.IntegralSize.BYTE => ByteType
+ case VariantSchema.IntegralSize.SHORT => ShortType
+ case VariantSchema.IntegralSize.INT => IntegerType
+ case VariantSchema.IntegralSize.LONG => LongType
+ }
+ case _: VariantSchema.FloatType => FloatType
+ case _: VariantSchema.DoubleType => DoubleType
+ case _: VariantSchema.BooleanType => BooleanType
+ case _: VariantSchema.BinaryType => BinaryType
+ case dt: VariantSchema.DecimalType => DecimalType(dt.precision, dt.scale)
+ case _: VariantSchema.DateType => DateType
+ case _: VariantSchema.TimestampType => TimestampType
+ case _: VariantSchema.TimestampNTZType => TimestampNTZType
+ }
+
+ // Convert a Parquet type into a Spark data type.
+ def parquetTypeToSparkType(parquetType: ParquetType): DataType = {
+ val messageType =
ParquetTypes.buildMessage().addField(parquetType).named("foo")
+ val column = new ColumnIOFactory().getColumnIO(messageType)
+ new
ParquetToSparkSchemaConverter().convertField(column.getChild(0)).sparkType
+ }
+
class SparkShreddedResult(schema: VariantSchema) extends
VariantShreddingWriter.ShreddedResult {
// Result is stored as an InternalRow.
val row = new GenericInternalRow(schema.numFields)
@@ -243,8 +651,187 @@ case object SparkShreddingUtils {
.row
}
- def rebuild(row: InternalRow, schema: VariantSchema): VariantVal = {
+ // Return a list of fields to extract. `targetType` must be either variant or variant struct.
+ // If it is variant, return null because the target is the full variant and there is no field to
+ // extract. If it is variant struct, return a list of fields matching the variant struct fields.
+ def getFieldsToExtract(targetType: DataType, inputSchema: VariantSchema): Array[FieldToExtract] =
+ targetType match {
+ case _: VariantType => null
+ case s: StructType if VariantMetadata.isVariantStruct(s) =>
+ s.fields.map { f =>
+ val metadata = VariantMetadata.fromMetadata(f.metadata)
+ val rawPath = metadata.parsedPath()
+ val schemaPath = new Array[SchemaPathSegment](rawPath.length)
+ var schema = inputSchema
+ // Search `rawPath` in `schema` to produce `schemaPath`. If a raw path segment cannot be
+ // found at a certain level of the file type, then `typedIdx` will be -1 starting from
+ // this position, and the final `schema` will be null.
+ for (i <- rawPath.indices) {
+ val isObject = rawPath(i).isLeft
+ var typedIdx = -1
+ var extractionIdx = -1
+ rawPath(i) match {
+ case scala.util.Left(key) if schema != null &&
schema.objectSchema != null =>
+ val fieldIdx = schema.objectSchemaMap.get(key)
+ if (fieldIdx != null) {
+ typedIdx = schema.typedIdx
+ extractionIdx = fieldIdx
+ schema = schema.objectSchema(fieldIdx).schema
+ } else {
+ schema = null
+ }
+ case scala.util.Right(index) if schema != null &&
schema.arraySchema != null =>
+ typedIdx = schema.typedIdx
+ extractionIdx = index
+ schema = schema.arraySchema
+ case _ =>
+ schema = null
+ }
+ schemaPath(i) = SchemaPathSegment(rawPath(i), isObject, typedIdx,
extractionIdx)
+ }
+ val reader = ParquetVariantReader(schema, f.dataType,
VariantCastArgs(
+ metadata.failOnError,
+ Some(metadata.timeZoneId),
+ DateTimeUtils.getZoneId(metadata.timeZoneId)),
+ isTopLevelUnshredded = schemaPath.isEmpty &&
inputSchema.isUnshredded)
+ FieldToExtract(schemaPath, reader)
+ }
+ case _ =>
+ throw QueryExecutionErrors.unreachableError(s"Invalid target type:
`${targetType.sql}`")
+ }
+
+ // Extract a single variant struct field from a Parquet variant value. It steps into `inputRow`
+ // according to the variant extraction path, and reads the extracted value as the target type.
+ private def extractField(
+ inputRow: InternalRow,
+ topLevelMetadata: Array[Byte],
+ inputSchema: VariantSchema,
+ pathList: Array[SchemaPathSegment],
+ reader: ParquetVariantReader): Any = {
+ var pathIdx = 0
+ val pathLen = pathList.length
+ var row = inputRow
+ var schema = inputSchema
+ while (pathIdx < pathLen) {
+ val path = pathList(pathIdx)
+
+ if (path.typedIdx < 0) {
+ // The extraction doesn't exist in `typed_value`. Try to extract the
remaining part of the
+ // path in `value`.
+ val variantIdx = schema.variantIdx
+ if (variantIdx < 0 || row.isNullAt(variantIdx)) return null
+ var v = new Variant(row.getBinary(variantIdx), topLevelMetadata)
+ while (pathIdx < pathLen) {
+ v = pathList(pathIdx).rawPath match {
+ case scala.util.Left(key) if v.getType == Type.OBJECT =>
v.getFieldByKey(key)
+ case scala.util.Right(index) if v.getType == Type.ARRAY =>
v.getElementAtIndex(index)
+ case _ => null
+ }
+ if (v == null) return null
+ pathIdx += 1
+ }
+ return VariantGet.cast(v, reader.targetType, reader.castArgs)
+ }
+
+ if (row.isNullAt(path.typedIdx)) return null
+ if (path.isObject) {
+ val obj = row.getStruct(path.typedIdx, schema.objectSchema.length)
+ // Object field must not be null.
+ if (obj.isNullAt(path.extractionIdx)) throw
QueryExecutionErrors.malformedVariant()
+ schema = schema.objectSchema(path.extractionIdx).schema
+ row = obj.getStruct(path.extractionIdx, schema.numFields)
+ // Return null if the field is missing.
+ if ((schema.typedIdx < 0 || row.isNullAt(schema.typedIdx)) &&
+ (schema.variantIdx < 0 || row.isNullAt(schema.variantIdx))) {
+ return null
+ }
+ } else {
+ val arr = row.getArray(path.typedIdx)
+ // Return null if the extraction index is out of bound.
+ if (path.extractionIdx >= arr.numElements()) return null
+ // Array element must not be null.
+ if (arr.isNullAt(path.extractionIdx)) throw
QueryExecutionErrors.malformedVariant()
+ schema = schema.arraySchema
+ row = arr.getStruct(path.extractionIdx, schema.numFields)
+ }
+ pathIdx += 1
+ }
+ reader.read(row, topLevelMetadata)
+ }
+
+ // Assemble a variant (binary format) from a Parquet variant value.
+ def assembleVariant(row: InternalRow, schema: VariantSchema): VariantVal = {
val v = ShreddingUtils.rebuild(SparkShreddedRow(row), schema)
new VariantVal(v.getValue, v.getMetadata)
}
+
+ // Assemble a variant struct, in which each field is extracted from the
Parquet variant value.
+ def assembleVariantStruct(
+ inputRow: InternalRow,
+ schema: VariantSchema,
+ fields: Array[FieldToExtract]): InternalRow = {
+ if (inputRow.isNullAt(schema.topLevelMetadataIdx)) {
+ throw QueryExecutionErrors.malformedVariant()
+ }
+ val topLevelMetadata = inputRow.getBinary(schema.topLevelMetadataIdx)
+ val numFields = fields.length
+ val resultRow = new GenericInternalRow(numFields)
+ var fieldIdx = 0
+ while (fieldIdx < numFields) {
+ resultRow.update(fieldIdx, extractField(inputRow, topLevelMetadata,
schema,
+ fields(fieldIdx).path, fields(fieldIdx).reader))
+ fieldIdx += 1
+ }
+ resultRow
+ }
+
+ // Assemble a batch of variant (binary format) from a batch of Parquet
variant values.
+ def assembleVariantBatch(
+ input: WritableColumnVector,
+ output: WritableColumnVector,
+ schema: VariantSchema): Unit = {
+ val numRows = input.getElementsAppended
+ output.reset()
+ output.reserve(numRows)
+ val valueChild = output.getChild(0)
+ val metadataChild = output.getChild(1)
+ var i = 0
+ while (i < numRows) {
+ if (input.isNullAt(i)) {
+ output.appendStruct(true)
+ } else {
+ output.appendStruct(false)
+ val v = SparkShreddingUtils.assembleVariant(input.getStruct(i), schema)
+ valueChild.appendByteArray(v.getValue, 0, v.getValue.length)
+ metadataChild.appendByteArray(v.getMetadata, 0, v.getMetadata.length)
+ }
+ i += 1
+ }
+ }
+
+ // Assemble a batch of variant struct from a batch of Parquet variant values.
+ def assembleVariantStructBatch(
+ input: WritableColumnVector,
+ output: WritableColumnVector,
+ schema: VariantSchema,
+ fields: Array[FieldToExtract]): Unit = {
+ val numRows = input.getElementsAppended
+ output.reset()
+ output.reserve(numRows)
+ val converter = new RowToColumnConverter(StructType(Array(StructField("",
output.dataType()))))
+ val converterVectors = Array(output)
+ val converterRow = new GenericInternalRow(1)
+ output.reset()
+ output.reserve(input.getElementsAppended)
+ var i = 0
+ while (i < numRows) {
+ if (input.isNullAt(i)) {
+ converterRow.update(0, null)
+ } else {
+ converterRow.update(0, assembleVariantStruct(input.getStruct(i),
schema, fields))
+ }
+ converter.convert(converterRow, converterVectors)
+ i += 1
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantShreddingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantShreddingSuite.scala
index 5d5c44105255..b6623bb57a71 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/VariantShreddingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantShreddingSuite.scala
@@ -22,13 +22,21 @@ import java.sql.{Date, Timestamp}
import java.time.LocalDateTime
import org.apache.spark.SparkThrowable
+import org.apache.spark.sql.catalyst.InternalRow
+import
org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils
import org.apache.spark.sql.execution.datasources.parquet.{ParquetTest,
SparkShreddingUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
import org.apache.spark.types.variant._
+import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
class VariantShreddingSuite extends QueryTest with SharedSparkSession with ParquetTest {
+ def parseJson(s: String): VariantVal = {
+ val v = VariantBuilder.parseJson(s, false)
+ new VariantVal(v.getValue, v.getMetadata)
+ }
+
// Make a variant value binary by parsing a JSON string.
def value(s: String): Array[Byte] = VariantBuilder.parseJson(s,
false).getValue
@@ -53,9 +61,21 @@ class VariantShreddingSuite extends QueryTest with
SharedSparkSession with Parqu
def writeSchema(schema: DataType): StructType =
StructType(Array(StructField("v",
SparkShreddingUtils.variantShreddingSchema(schema))))
+ def withPushConfigs(pushConfigs: Seq[Boolean] = Seq(true, false))(fn: =>
Unit): Unit = {
+ for (push <- pushConfigs) {
+ withSQLConf(SQLConf.PUSH_VARIANT_INTO_SCAN.key -> push.toString) {
+ fn
+ }
+ }
+ }
+
+ def isPushEnabled: Boolean =
SQLConf.get.getConf(SQLConf.PUSH_VARIANT_INTO_SCAN)
+
def testWithTempPath(name: String)(block: File => Unit): Unit = test(name) {
- withTempPath { path =>
- block(path)
+ withPushConfigs() {
+ withTempPath { path =>
+ block(path)
+ }
}
}
@@ -63,6 +83,9 @@ class VariantShreddingSuite extends QueryTest with
SharedSparkSession with Parqu
spark.createDataFrame(spark.sparkContext.parallelize(rows.map(Row(_)),
numSlices = 1), schema)
.write.mode("overwrite").parquet(path.getAbsolutePath)
+ def writeRows(path: File, schema: String, rows: Row*): Unit =
+ writeRows(path, StructType.fromDDL(schema), rows: _*)
+
def read(path: File): DataFrame =
spark.read.schema("v variant").parquet(path.getAbsolutePath)
@@ -150,10 +173,13 @@ class VariantShreddingSuite extends QueryTest with
SharedSparkSession with Parqu
// Top-level variant must not be missing.
writeRows(path, writeSchema(IntegerType), Row(metadata(Nil), null, null))
checkException(path, "v", "MALFORMED_VARIANT")
+
// Array-element variant must not be missing.
writeRows(path, writeSchema(ArrayType(IntegerType)),
Row(metadata(Nil), null, Array(Row(null, null))))
checkException(path, "v", "MALFORMED_VARIANT")
+ checkException(path, "variant_get(v, '$[0]')", "MALFORMED_VARIANT")
+
// Shredded field must not be null.
// Construct the schema manually, because
SparkShreddingUtils.variantShreddingSchema will make
// `a` non-nullable, which would prevent us from writing the file.
@@ -164,12 +190,163 @@ class VariantShreddingSuite extends QueryTest with
SharedSparkSession with Parqu
StructField("a", StructType(Seq(
StructField("value", BinaryType),
StructField("typed_value", BinaryType))))))))))))
- writeRows(path, schema,
- Row(metadata(Seq("a")), null, Row(null)))
+ writeRows(path, schema, Row(metadata(Seq("a")), null, Row(null)))
checkException(path, "v", "MALFORMED_VARIANT")
+ checkException(path, "variant_get(v, '$.a')", "MALFORMED_VARIANT")
+
// `value` must not contain any shredded field.
writeRows(path, writeSchema(StructType.fromDDL("a int")),
Row(metadata(Seq("a")), value("""{"a": 1}"""), Row(Row(null, null))))
checkException(path, "v", "MALFORMED_VARIANT")
+ checkException(path, "cast(v as map<string, int>)", "MALFORMED_VARIANT")
+ if (isPushEnabled) {
+ checkExpr(path, "cast(v as struct<a int>)", Row(null))
+ checkExpr(path, "variant_get(v, '$.a', 'int')", null)
+ } else {
+ checkException(path, "cast(v as struct<a int>)", "MALFORMED_VARIANT")
+ checkException(path, "variant_get(v, '$.a', 'int')", "MALFORMED_VARIANT")
+ }
+
+ // Scalar reader reads from `typed_value` if both `value` and
`typed_value` are not null.
+ // Cast from `value` succeeds, cast from `typed_value` fails.
+ writeRows(path, "v struct<metadata binary, value binary, typed_value
string>",
+ Row(metadata(Nil), value("1"), "invalid"))
+ checkException(path, "cast(v as int)", "INVALID_VARIANT_CAST")
+ checkExpr(path, "try_cast(v as int)", null)
+
+ // Cast from `value` fails, cast from `typed_value` succeeds.
+ writeRows(path, "v struct<metadata binary, value binary, typed_value
string>",
+ Row(metadata(Nil), value("\"invalid\""), "1"))
+ checkExpr(path, "cast(v as int)", 1)
+ checkExpr(path, "try_cast(v as int)", 1)
+ }
+
+ testWithTempPath("extract from shredded object") { path =>
+ val keys1 = Seq("a", "b", "c", "d")
+ val keys2 = Seq("a", "b", "c", "e", "f")
+ writeRows(path, "v struct<metadata binary, value binary, typed_value
struct<" +
+ "a struct<value binary, typed_value int>, b struct<value binary>," +
+ "c struct<typed_value decimal(20, 10)>>>",
+ // {"a":1,"b":"2","c":3.3,"d":4.4}, d is in the left over value.
+ Row(metadata(keys1), shreddedValue("""{"d": 4.4}""", keys1),
+ Row(Row(null, 1), Row(value("\"2\"")), Row(Decimal("3.3")))),
+ // {"a":5.4,"b":-6,"e":{"f":[true]}}, e is in the left over value.
+ Row(metadata(keys2), shreddedValue("""{"e": {"f": [true]}}""", keys2),
+ Row(Row(value("5.4"), null), Row(value("-6")), Row(null))),
+ // [{"a":1}], the unshredded array at the top-level is put into `value`
as a whole.
+ Row(metadata(Seq("a")), value("""[{"a": 1}]"""), null))
+
+ checkAnswer(read(path).selectExpr("variant_get(v, '$.a', 'int')",
+ "variant_get(v, '$.b', 'long')", "variant_get(v, '$.c', 'double')",
+ "variant_get(v, '$.d', 'decimal(9, 4)')"),
+ Seq(Row(1, 2L, 3.3, BigDecimal("4.4")), Row(5, -6L, null, null),
Row(null, null, null, null)))
+ checkExpr(path, "variant_get(v, '$.e.f[0]', 'boolean')", null, true, null)
+ checkExpr(path, "variant_get(v, '$[0].a', 'boolean')", null, null, true)
+ checkExpr(path, "try_cast(v as struct<a float, e variant>)",
+ Row(1.0F, null), Row(5.4F, parseJson("""{"f": [true]}""")), null)
+
+ // String "2" cannot be cast into boolean.
+ checkException(path, "variant_get(v, '$.b', 'boolean')",
"INVALID_VARIANT_CAST")
+ // Decimal cannot be cast into date.
+ checkException(path, "variant_get(v, '$.c', 'date')",
"INVALID_VARIANT_CAST")
+ // The value of `c` doesn't fit into `decimal(1, 1)`.
+ checkException(path, "variant_get(v, '$.c', 'decimal(1, 1)')",
"INVALID_VARIANT_CAST")
+ checkExpr(path, "try_variant_get(v, '$.b', 'boolean')", null, true, null)
+ // Scalar cannot be cast into struct.
+ checkException(path, "variant_get(v, '$.a', 'struct<a int>')",
"INVALID_VARIANT_CAST")
+ checkExpr(path, "try_variant_get(v, '$.a', 'struct<a int>')", null, null,
null)
+
+ checkExpr(path, "try_cast(v as map<string, double>)",
+ Map("a" -> 1.0, "b" -> 2.0, "c" -> 3.3, "d" -> 4.4),
+ Map("a" -> 5.4, "b" -> -6.0, "e" -> null), null)
+ checkExpr(path, "try_cast(v as array<string>)", null, null,
Seq("""{"a":1}"""))
+
+ val strings = Seq("""{"a":1,"b":"2","c":3.3,"d":4.4}""",
+ """{"a":5.4,"b":-6,"e":{"f":[true]}}""", """[{"a":1}]""")
+ checkExpr(path, "cast(v as string)", strings: _*)
+ checkExpr(path, "v",
+ VariantExpressionEvalUtils.castToVariant(
+ InternalRow(1, UTF8String.fromString("2"), Decimal("3.3000000000"),
Decimal("4.4")),
+ StructType.fromDDL("a int, b string, c decimal(20, 10), d decimal(2,
1)")
+ ),
+ parseJson(strings(1)),
+ parseJson(strings(2))
+ )
+ }
+
+ testWithTempPath("extract from shredded array") { path =>
+ val keys = Seq("a", "b")
+ writeRows(path, "v struct<metadata binary, value binary, typed_value
array<" +
+ "struct<value binary, typed_value struct<a struct<value binary,
typed_value string>>>>>",
+ // [{"a":"2000-01-01"},{"a":"1000-01-01","b":[7]}], b is in the left
over value.
+ Row(metadata(keys), null, Array(
+ Row(null, Row(Row(null, "2000-01-01"))),
+ Row(shreddedValue("""{"b": [7]}""", keys), Row(Row(null,
"1000-01-01"))))),
+ // [null,{"a":null},{"a":"null"},{}]
+ Row(metadata(keys), null, Array(
+ Row(value("null"), null),
+ Row(null, Row(Row(value("null"), null))),
+ Row(null, Row(Row(null, "null"))),
+ Row(null, Row(Row(null, null))))))
+
+ val date1 = Date.valueOf("2000-01-01")
+ val date2 = Date.valueOf("1000-01-01")
+ checkExpr(path, "variant_get(v, '$[0].a', 'date')", date1, null)
+ // try_cast succeeds.
+ checkExpr(path, "try_variant_get(v, '$[1].a', 'date')", date2, null)
+ // The first array returns null because of out-of-bound index.
+ // The second array returns "null".
+ checkExpr(path, "variant_get(v, '$[2].a', 'string')", null, "null")
+ // Return null because of invalid cast.
+ checkExpr(path, "try_variant_get(v, '$[1].a', 'int')", null, null)
+
+ checkExpr(path, "variant_get(v, '$[0].b[0]', 'int')", null, null)
+ checkExpr(path, "variant_get(v, '$[1].b[0]', 'int')", 7, null)
+ // Validate timestamp-related casts use the session time zone correctly.
+ Seq("Etc/UTC", "America/Los_Angeles").foreach { tz =>
+ withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
+ val expected = sql("select timestamp'1000-01-01',
timestamp_ntz'1000-01-01'").head()
+ checkAnswer(read(path).selectExpr("variant_get(v, '$[1].a',
'timestamp')",
+ "variant_get(v, '$[1].a', 'timestamp_ntz')"), Seq(expected,
Row(null, null)))
+ }
+ }
+ checkException(path, "variant_get(v, '$[0]', 'int')",
"INVALID_VARIANT_CAST")
+ // An out-of-bound array access produces null. It never causes an invalid
cast.
+ checkExpr(path, "variant_get(v, '$[4]', 'int')", null, null)
+
+ checkExpr(path, "cast(v as array<struct<a string, b array<int>>>)",
+ Seq(Row("2000-01-01", null), Row("1000-01-01", Seq(7))),
+ Seq(null, Row(null, null), Row("null", null), Row(null, null)))
+ checkExpr(path, "cast(v as array<map<string, string>>)",
+ Seq(Map("a" -> "2000-01-01"), Map("a" -> "1000-01-01", "b" -> "[7]")),
+ Seq(null, Map("a" -> null), Map("a" -> "null"), Map()))
+ checkExpr(path, "try_cast(v as array<map<string, date>>)",
+ Seq(Map("a" -> date1), Map("a" -> date2, "b" -> null)),
+ Seq(null, Map("a" -> null), Map("a" -> null), Map()))
+
+ val strings = Seq("""[{"a":"2000-01-01"},{"a":"1000-01-01","b":[7]}]""",
+ """[null,{"a":null},{"a":"null"},{}]""")
+ checkExpr(path, "cast(v as string)", strings: _*)
+ checkExpr(path, "v", strings.map(parseJson): _*)
+ }
+
+ testWithTempPath("missing fields") { path =>
+ writeRows(path, "v struct<metadata binary, typed_value struct<" +
+ "a struct<value binary, typed_value int>, b struct<typed_value int>>>",
+ Row(metadata(Nil), Row(Row(null, null), Row(null))),
+ Row(metadata(Nil), Row(Row(value("null"), null), Row(null))),
+ Row(metadata(Nil), Row(Row(null, 1), Row(null))),
+ Row(metadata(Nil), Row(Row(null, null), Row(2))),
+ Row(metadata(Nil), Row(Row(value("null"), null), Row(2))),
+ Row(metadata(Nil), Row(Row(null, 3), Row(4))))
+
+ val strings = Seq("{}", """{"a":null}""", """{"a":1}""", """{"b":2}""",
"""{"a":null,"b":2}""",
+ """{"a":3,"b":4}""")
+ checkExpr(path, "cast(v as string)", strings: _*)
+ checkExpr(path, "v", strings.map(parseJson): _*)
+
+ checkExpr(path, "variant_get(v, '$.a', 'string')", null, null, "1", null,
null, "3")
+ checkExpr(path, "variant_get(v, '$.a')", null, parseJson("null"),
parseJson("1"), null,
+ parseJson("null"), parseJson("3"))
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]