This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 5c634845f084 [SPARK-50614][SQL] Add Variant shredding support for
Parquet
5c634845f084 is described below
commit 5c634845f0849b3d5bf6a98c4ecddff26b71572b
Author: cashmand <[email protected]>
AuthorDate: Thu Jan 2 12:39:07 2025 +0800
[SPARK-50614][SQL] Add Variant shredding support for Parquet
### What changes were proposed in this pull request?
Adds support for shredding in the Parquet writer code. Currently, the only
way to enable shredding is through a SQLConf that provides the schema to use
for shredding. This doesn't make sense as a user API, and is added only for
testing. The exact API for Spark to determine a shredding schema is still TBD,
but likely candidates are to infer it at the task level by inspecting the first
few rows of data, or add an API to specify the schema for a given column.
Either way, the code in this [...]
### Why are the changes needed?
Needed for Variant shredding support.
### Does this PR introduce _any_ user-facing change?
No, the feature is new in Spark 4.0, and is currently disabled, and only
usable as a test feature.
### How was this patch tested?
Added a unit test suite.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #49234 from cashmand/SPARK-50614.
Authored-by: cashmand <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../org/apache/spark/sql/internal/SQLConf.scala | 16 ++
.../parquet/ParquetSchemaConverter.scala | 8 +
.../datasources/parquet/ParquetUtils.scala | 34 ++-
.../datasources/parquet/ParquetWriteSupport.scala | 65 +++++-
.../datasources/parquet/SparkShreddingUtils.scala | 23 +++
.../parquet/ParquetVariantShreddingSuite.scala | 229 +++++++++++++++++++++
6 files changed, 362 insertions(+), 13 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 5e630577638a..6a45380b7a99 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -4659,6 +4659,22 @@ object SQLConf {
.booleanConf
.createWithDefault(false)
+ val VARIANT_WRITE_SHREDDING_ENABLED =
+ buildConf("spark.sql.variant.writeShredding.enabled")
+ .internal()
+ .doc("When true, the Parquet writer is allowed to write shredded
variant. ")
+ .version("4.0.0")
+ .booleanConf
+ .createWithDefault(false)
+
+ val VARIANT_FORCE_SHREDDING_SCHEMA_FOR_TEST =
+ buildConf("spark.sql.variant.forceShreddingSchemaForTest")
+ .internal()
+ .doc("FOR INTERNAL TESTING ONLY. Sets shredding schema for Variant.")
+ .version("4.0.0")
+ .stringConf
+ .createWithDefault("")
+
val LEGACY_CSV_ENABLE_DATE_TIME_PARSING_FALLBACK =
buildConf("spark.sql.legacy.csv.enableDateTimeParsingFallback")
.internal()
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
index 64c2a3126ca9..daeb8e88a924 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
@@ -756,6 +756,14 @@ class SparkToParquetSchemaConverter(
.addField(convertField(StructField("metadata", BinaryType, nullable
= false)))
.named(field.name)
+ case s: StructType if SparkShreddingUtils.isVariantShreddingStruct(s) =>
+ // Variant struct takes a Variant and writes to Parquet as a shredded
schema.
+ val group = Types.buildGroup(repetition)
+ s.fields.foreach { f =>
+ group.addField(convertField(f))
+ }
+ group.named(field.name)
+
case StructType(fields) =>
fields.foldLeft(Types.buildGroup(repetition)) { (builder, field) =>
builder.addField(convertField(field))
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
index 3e111252bc6f..a609a4e0a25f 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
@@ -45,7 +45,7 @@ import
org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, Outpu
import org.apache.spark.sql.execution.datasources.v2.V2ColumnUtils
import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
import org.apache.spark.sql.internal.SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED
-import org.apache.spark.sql.types.{ArrayType, AtomicType, DataType, MapType,
StructField, StructType, UserDefinedType}
+import org.apache.spark.sql.types.{ArrayType, AtomicType, DataType, MapType,
StructField, StructType, UserDefinedType, VariantType}
import org.apache.spark.util.ArrayImplicits._
object ParquetUtils extends Logging {
@@ -420,6 +420,22 @@ object ParquetUtils extends Logging {
statistics.getNumNulls;
}
+ // Replaces each VariantType in the schema with the corresponding type in
the shredding schema.
+ // Used for testing, where we force a single shredding schema for all
Variant fields.
+ // Does not touch Variant fields nested in arrays, maps, or UDTs.
+ private def replaceVariantTypes(schema: StructType, shreddingSchema:
StructType): StructType = {
+ val newFields = schema.fields.zip(shreddingSchema.fields).map {
+ case (field, shreddingField) =>
+ field.dataType match {
+ case s: StructType =>
+ field.copy(dataType = replaceVariantTypes(s, shreddingSchema))
+ case VariantType => field.copy(dataType = shreddingSchema)
+ case _ => field
+ }
+ }
+ StructType(newFields)
+ }
+
def prepareWrite(
sqlConf: SQLConf,
job: Job,
@@ -454,8 +470,22 @@ object ParquetUtils extends Logging {
ParquetOutputFormat.setWriteSupportClass(job, classOf[ParquetWriteSupport])
+ val shreddingSchema = if
(sqlConf.getConf(SQLConf.VARIANT_WRITE_SHREDDING_ENABLED) &&
+
!sqlConf.getConf(SQLConf.VARIANT_FORCE_SHREDDING_SCHEMA_FOR_TEST).isEmpty) {
+ // Convert the schema to a shredding schema, and replace it anywhere
that there is a
+ // VariantType in the original schema.
+ val simpleShreddingSchema = DataType.fromDDL(
+ sqlConf.getConf(SQLConf.VARIANT_FORCE_SHREDDING_SCHEMA_FOR_TEST)
+ )
+ val oneShreddingSchema =
SparkShreddingUtils.variantShreddingSchema(simpleShreddingSchema)
+ val schemaWithMetadata =
SparkShreddingUtils.addWriteShreddingMetadata(oneShreddingSchema)
+ Some(replaceVariantTypes(dataSchema, schemaWithMetadata))
+ } else {
+ None
+ }
+
// This metadata is useful for keeping UDTs like Vector/Matrix.
- ParquetWriteSupport.setSchema(dataSchema, conf)
+ ParquetWriteSupport.setSchema(dataSchema, conf, shreddingSchema)
// Sets flags for `ParquetWriteSupport`, which converts Catalyst schema to
Parquet
// schema and writes actual rows to Parquet files.
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala
index 89a1cd5d4375..9402f5638094 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala
@@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.datasources.DataSourceUtils
import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
import org.apache.spark.sql.types._
+import org.apache.spark.types.variant.Variant
/**
* A Parquet [[WriteSupport]] implementation that writes Catalyst
[[InternalRow]]s as Parquet
@@ -59,6 +60,10 @@ class ParquetWriteSupport extends WriteSupport[InternalRow]
with Logging {
// Schema of the `InternalRow`s to be written
private var schema: StructType = _
+ // Schema of the `InternalRow`s to be written, with VariantType replaced
with its shredding
+ // schema, if appropriate.
+ private var shreddedSchema: StructType = _
+
// `ValueWriter`s for all fields of the schema
private var rootFieldWriters: Array[ValueWriter] = _
@@ -95,7 +100,16 @@ class ParquetWriteSupport extends WriteSupport[InternalRow]
with Logging {
override def init(configuration: Configuration): WriteContext = {
val schemaString = configuration.get(ParquetWriteSupport.SPARK_ROW_SCHEMA)
+ val shreddedSchemaString =
configuration.get(ParquetWriteSupport.SPARK_VARIANT_SHREDDING_SCHEMA)
this.schema = StructType.fromString(schemaString)
+ // If shreddingSchemaString is provided, we use that everywhere in the
writer, except for
+ // setting the spark schema in the Parquet metadata. If it isn't provided,
it means that there
+ // are no shredded Variant columns, so it is identical to this.schema.
+ this.shreddedSchema = if (shreddedSchemaString == null) {
+ this.schema
+ } else {
+ StructType.fromString(shreddedSchemaString)
+ }
this.writeLegacyParquetFormat = {
// `SQLConf.PARQUET_WRITE_LEGACY_FORMAT` should always be explicitly set
in ParquetRelation
assert(configuration.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key) !=
null)
@@ -108,9 +122,9 @@ class ParquetWriteSupport extends WriteSupport[InternalRow]
with Logging {
SQLConf.ParquetOutputTimestampType.withName(configuration.get(key))
}
- this.rootFieldWriters =
schema.map(_.dataType).map(makeWriter).toArray[ValueWriter]
+ this.rootFieldWriters =
shreddedSchema.map(_.dataType).map(makeWriter).toArray[ValueWriter]
- val messageType = new
SparkToParquetSchemaConverter(configuration).convert(schema)
+ val messageType = new
SparkToParquetSchemaConverter(configuration).convert(shreddedSchema)
val metadata = Map(
SPARK_VERSION_METADATA_KEY -> SPARK_VERSION_SHORT,
ParquetReadSupport.SPARK_METADATA_KEY -> schemaString
@@ -132,13 +146,23 @@ class ParquetWriteSupport extends
WriteSupport[InternalRow] with Logging {
}
}
- logDebug(
- s"""Initialized Parquet WriteSupport with Catalyst schema:
- |${schema.prettyJson}
- |and corresponding Parquet message type:
- |$messageType
- """.stripMargin)
-
+ if (shreddedSchemaString == null) {
+ logDebug(
+ s"""Initialized Parquet WriteSupport with Catalyst schema:
+ |${schema.prettyJson}
+ |and corresponding Parquet message type:
+ |$messageType
+ """.stripMargin)
+ } else {
+ logDebug(
+ s"""Initialized Parquet WriteSupport with Catalyst schema:
+ |${schema.prettyJson}
+ |and shredding schema:
+ |${shreddedSchema.prettyJson}
+ |and corresponding Parquet message type:
+ |$messageType
+ """.stripMargin)
+ }
new WriteContext(messageType, metadata.asJava)
}
@@ -148,7 +172,7 @@ class ParquetWriteSupport extends WriteSupport[InternalRow]
with Logging {
override def write(row: InternalRow): Unit = {
consumeMessage {
- writeFields(row, schema, rootFieldWriters)
+ writeFields(row, shreddedSchema, rootFieldWriters)
}
}
@@ -250,6 +274,17 @@ class ParquetWriteSupport extends
WriteSupport[InternalRow] with Logging {
}
}
+ case s: StructType if SparkShreddingUtils.isVariantShreddingStruct(s) =>
+ val fieldWriters =
s.map(_.dataType).map(makeWriter).toArray[ValueWriter]
+ val variantShreddingSchema = SparkShreddingUtils.buildVariantSchema(s)
+ (row: SpecializedGetters, ordinal: Int) =>
+ val v = row.getVariant(ordinal)
+ val variant = new Variant(v.getValue, v.getMetadata)
+ val shreddedValues = SparkShreddingUtils.castShredded(variant,
variantShreddingSchema)
+ consumeGroup {
+ writeFields(shreddedValues, s, fieldWriters)
+ }
+
case t: StructType =>
val fieldWriters =
t.map(_.dataType).map(makeWriter).toArray[ValueWriter]
(row: SpecializedGetters, ordinal: Int) =>
@@ -499,11 +534,19 @@ class ParquetWriteSupport extends
WriteSupport[InternalRow] with Logging {
object ParquetWriteSupport {
val SPARK_ROW_SCHEMA: String = "org.apache.spark.sql.parquet.row.attributes"
+ // A version of `SPARK_ROW_SCHEMA`, where one or more Variant attributes
have been replaced with a
+ // shredded struct schema.
+ val SPARK_VARIANT_SHREDDING_SCHEMA: String =
+ "org.apache.spark.sql.parquet.variant.shredding.attributes"
- def setSchema(schema: StructType, configuration: Configuration): Unit = {
+ def setSchema(schema: StructType, configuration: Configuration,
+ shreddingSchema: Option[StructType]): Unit = {
configuration.set(SPARK_ROW_SCHEMA, schema.json)
configuration.setIfUnset(
ParquetOutputFormat.WRITER_VERSION,
ParquetProperties.WriterVersion.PARQUET_1_0.toString)
+ shreddingSchema.foreach { s =>
+ configuration.set(SPARK_VARIANT_SHREDDING_SCHEMA, s.json)
+ }
}
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
index 34c167aea363..c0c490034415 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
@@ -448,6 +448,8 @@ case object SparkShreddingUtils {
val TypedValueFieldName = "typed_value";
val MetadataFieldName = "metadata";
+ val VARIANT_WRITE_SHREDDING_KEY: String = "__VARIANT_WRITE_SHREDDING_KEY"
+
def buildVariantSchema(schema: DataType): VariantSchema = {
schema match {
case s: StructType => buildVariantSchema(s, topLevel = true)
@@ -512,6 +514,27 @@ case object SparkShreddingUtils {
}
}
+ /**
+ * Given a schema that represents a valid shredding schema (e.g. constructed
by
+ * SparkShreddingUtils.variantShreddingSchema), add metadata to the
top-level fields to mark it
+ * as a shredding schema for writers.
+ */
+ def addWriteShreddingMetadata(schema: StructType): StructType = {
+ val newFields = schema.fields.map { f =>
+ f.copy(metadata = new
+ MetadataBuilder()
+ .withMetadata(f.metadata)
+ .putNull(VARIANT_WRITE_SHREDDING_KEY).build())
+ }
+ StructType(newFields)
+ }
+
+ // Check if the struct is marked with metadata set by
addWriteShreddingMetadata - i.e. it
+ // represents a Variant converted to a shredding schema for writing.
+ def isVariantShreddingStruct(s: StructType): Boolean = {
+ s.fields.length > 0 &&
s.fields.forall(_.metadata.contains(VARIANT_WRITE_SHREDDING_KEY))
+ }
+
/*
* Given a Spark schema that represents a valid shredding schema (e.g.
constructed by
* SparkShreddingUtils.variantShreddingSchema), return the corresponding
VariantSchema.
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVariantShreddingSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVariantShreddingSuite.scala
new file mode 100644
index 000000000000..8bb5a4b1d0bc
--- /dev/null
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVariantShreddingSuite.scala
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.parquet
+
+import java.io.File
+
+import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.unsafe.types.VariantVal
+
+/**
+ * Test shredding Variant values in the Parquet reader/writer.
+ */
+class ParquetVariantShreddingSuite extends QueryTest with ParquetTest with
SharedSparkSession {
+
+ private def testWithTempDir(name: String)(block: File => Unit): Unit =
test(name) {
+ withTempDir { dir =>
+ block(dir)
+ }
+ }
+
+ testWithTempDir("write shredded variant basic") { dir =>
+ val schema = "a int, b string, c decimal(15, 1)"
+ val df = spark.sql(
+ """
+ | select case
+ | when id = 0 then parse_json('{"a": 1, "b": "2", "c": 3.3, "d": 4.4}')
+ | when id = 1 then parse_json('{"a": [1,2,3], "b": "hello", "c": {"x":
0}}')
+ | when id = 2 then parse_json('{"A": 1, "c": 1.23}')
+ | end v from range(3)
+ |""".stripMargin)
+ val fullSchema = "v struct<metadata binary, value binary, typed_value
struct<" +
+ "a struct<value binary, typed_value int>, b struct<value binary,
typed_value string>," +
+ "c struct<value binary, typed_value decimal(15, 1)>>>"
+ withSQLConf(SQLConf.VARIANT_WRITE_SHREDDING_ENABLED.key -> true.toString,
+ SQLConf.VARIANT_FORCE_SHREDDING_SCHEMA_FOR_TEST.key -> schema) {
+ df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+
+
+ // Verify that we can read the full variant. The exact binary layout can
change before and
+ // after shredding, so just check that the JSON representation matches.
+ checkAnswer(
+ spark.read.parquet(dir.getAbsolutePath).selectExpr("to_json(v)"),
+ df.selectExpr("to_json(v)").collect()
+ )
+
+ // Verify that it was shredded to the expected fields.
+
+ val shreddedDf =
spark.read.schema(fullSchema).parquet(dir.getAbsolutePath)
+ // Metadata should be unchanged.
+ checkAnswer(shreddedDf.selectExpr("v.metadata"),
+ df.collect().map(v =>
Row(v.get(0).asInstanceOf[VariantVal].getMetadata))
+ )
+
+ // Check typed values.
+ // Second row is not an integer, and third is A, not a
+ checkAnswer(
+ shreddedDf.selectExpr("v.typed_value.a.typed_value"),
+ Seq(Row(1), Row(null), Row(null)))
+ // b is missing from third row.
+ checkAnswer(
+ shreddedDf.selectExpr("v.typed_value.b.typed_value"),
+ Seq(Row("2"), Row("hello"), Row(null)))
+ // Second row is an object, third is the wrong scale. (Note: we may
eventually allow the
+ // latter, in which case this test should be updated.)
+ checkAnswer(
+ shreddedDf.selectExpr("v.typed_value.c.typed_value"),
+ Seq(Row(3.3), Row(null), Row(null)))
+
+ // Untyped values are more awkward to check, so for now just check their
nullness. We
+ // can do more thorough checking once the reader is ready.
+ checkAnswer(
+ shreddedDf.selectExpr("v.value is null"),
+ // First row has "d" and third has "A".
+ Seq(Row(false), Row(true), Row(false)))
+ checkAnswer(
+ shreddedDf.selectExpr("v.typed_value.a.value is null"),
+ // First row is fully shredded, third is missing.
+ Seq(Row(true), Row(false), Row(true)))
+ checkAnswer(
+ shreddedDf.selectExpr("v.typed_value.b.value is null"),
+ // b is always fully shredded or missing.
+ Seq(Row(true), Row(true), Row(true)))
+ checkAnswer(
+ shreddedDf.selectExpr("v.typed_value.c.value is null"),
+ Seq(Row(true), Row(false), Row(false)))
+ // The a/b/c levels are not null, even if the field is missing.
+ checkAnswer(
+ shreddedDf.selectExpr(
+ "v.typed_value.a is null or v.typed_value.b is null or
v.typed_value.c is null"),
+ Seq(Row(false), Row(false), Row(false)))
+ }
+ }
+
+ testWithTempDir("write shredded variant array") { dir =>
+ val schema = "array<int>"
+ val df = spark.sql(
+ """
+ | select case
+ | when id = 0 then parse_json('[1, "2", 3.5, null, 5]')
+ | when id = 1 then parse_json('{"a": [1, 2, 3]}')
+ | when id = 2 then parse_json('1')
+ | when id = 3 then parse_json('null')
+ | end v from range(4)
+ |""".stripMargin)
+ val fullSchema = "v struct<metadata binary, value binary, typed_value
array<" +
+ "struct<value binary, typed_value int>>>"
+ withSQLConf(SQLConf.VARIANT_WRITE_SHREDDING_ENABLED.key -> true.toString,
+ SQLConf.VARIANT_FORCE_SHREDDING_SCHEMA_FOR_TEST.key -> schema) {
+ df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+
+ // Verify that we can read the full variant.
+ checkAnswer(
+ spark.read.parquet(dir.getAbsolutePath).selectExpr("to_json(v)"),
+ df.selectExpr("to_json(v)").collect()
+ )
+
+ // Verify that it was shredded to the expected fields.
+
+ val shreddedDf =
spark.read.schema(fullSchema).parquet(dir.getAbsolutePath)
+ // Metadata should be unchanged.
+ checkAnswer(shreddedDf.selectExpr("v.metadata"),
+ df.collect().map(v =>
Row(v.get(0).asInstanceOf[VariantVal].getMetadata))
+ )
+
+ // Check typed values.
+ checkAnswer(
+ shreddedDf.selectExpr("v.typed_value.typed_value"),
+ Seq(Row(Array(1, null, null, null, 5)), Row(null), Row(null),
Row(null)))
+
+ // All the other array elements should have non-null value.
+ checkAnswer(
+ shreddedDf.selectExpr("transform(v.typed_value.value, v -> v is
null)"),
+ Seq(Row(Array(true, false, false, false, true)), Row(null), Row(null),
Row(null)))
+
+ // The non-arrays should have non-null top-level value.
+ checkAnswer(
+ shreddedDf.selectExpr("v.value is null"),
+ Seq(Row(true), Row(false), Row(false), Row(false)))
+ }
+ }
+
+ testWithTempDir("write no shredding schema") { dir =>
+ // Check that we can write and read normally when shredding is enabled if
+ // we don't provide a shredding schema.
+ withSQLConf(SQLConf.VARIANT_WRITE_SHREDDING_ENABLED.key -> true.toString) {
+ val df = spark.sql(
+ """
+ | select parse_json('{"a": ' || id || ', "b": 2}') as v,
+ | array(parse_json('{"c": 3}'), 123::variant) as a
+ | from range(1, 3, 1, 1)
+ |""".stripMargin)
+ df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+ checkAnswer(
+ spark.read.parquet(dir.getAbsolutePath), df.collect()
+ )
+ }
+ }
+
+ testWithTempDir("arrays and maps ignore shredding schema") { dir =>
+ // Check that we don't try to shred array or map elements, even if a
shredding schema
+ // is specified.
+ val schema = "a int"
+ val df = spark.sql(
+ """ select v, array(v) as arr, map('myKey', v) as m from
+ | (select parse_json('{"a":' || id || '}') v from range(3))
+ |""".stripMargin)
+ val fullSchema = "v struct<metadata binary, value binary, typed_value
struct<" +
+ "a struct<value binary, typed_value int>>>, " +
+ "arr array<struct<metadata binary, value binary>>, " +
+ "m map<string, struct<metadata binary, value binary>>"
+ withSQLConf(SQLConf.VARIANT_WRITE_SHREDDING_ENABLED.key -> true.toString,
+ SQLConf.VARIANT_FORCE_SHREDDING_SCHEMA_FOR_TEST.key -> schema) {
+ df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+
+ // Verify that we can read the full variant.
+ checkAnswer(
+ spark.read.parquet(dir.getAbsolutePath).selectExpr("to_json(v)"),
+ df.selectExpr("to_json(v)").collect()
+ )
+
+ // Verify that it was shredded to the expected fields.
+
+ val shreddedDf =
spark.read.schema(fullSchema).parquet(dir.getAbsolutePath)
+ // Metadata should be unchanged.
+ checkAnswer(shreddedDf.selectExpr("v.metadata"),
+ df.selectExpr("v").collect().map(v =>
Row(v.get(0).asInstanceOf[VariantVal].getMetadata))
+ )
+ checkAnswer(shreddedDf.selectExpr("arr[0].metadata"),
+ df.selectExpr("arr[0]").collect().map(v =>
+ Row(v.get(0).asInstanceOf[VariantVal].getMetadata))
+ )
+ checkAnswer(shreddedDf.selectExpr("m['myKey'].metadata"),
+ df.selectExpr("m['myKey']").collect().map(
+ v => Row(v.get(0).asInstanceOf[VariantVal].getMetadata))
+ )
+
+ // v should be fully shredded, but the array and map should not be.
+ checkAnswer(
+ shreddedDf.selectExpr(
+ "v.value is null"),
+ Seq(Row(true), Row(true), Row(true)))
+ checkAnswer(
+ shreddedDf.selectExpr(
+ "arr[0].value is null"),
+ Seq(Row(false), Row(false), Row(false)))
+ checkAnswer(
+ shreddedDf.selectExpr(
+ "m['myKey'].value is null"),
+ Seq(Row(false), Row(false), Row(false)))
+ }
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]