cloud-fan commented on code in PR #54343:
URL: https://github.com/apache/spark/pull/54343#discussion_r3263703852
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/InferVariantShreddingSchema.scala:
##########
@@ -351,36 +311,252 @@ class InferVariantShreddingSchema(val schema:
StructType) {
}
def inferSchema(rows: Seq[InternalRow]): StructType = {
- // For each path to a Variant value, iterate over all rows and update the
inferred schema.
- // Add the result to a map, which we'll use to update the full schema.
- // maxShreddedFieldsPerFile is a global max for all fields, so initialize
it here.
+ // For each variant path, collect field statistics using a single pass
val maxFields = MaxFields(maxShreddedFieldsPerFile)
+
val inferredSchemas = pathsToVariant.map { path =>
- var numNonNullValues = 0
- val simpleSchema = rows.foldLeft(NullType: DataType) {
- case (partialSchema, row) =>
- getValueAtPath(schema, row, path).map { variantVal =>
- numNonNullValues += 1
- val v = new Variant(variantVal.getValue, variantVal.getMetadata)
- val schemaOfRow = schemaOf(v, maxShreddingDepth)
- mergeSchema(partialSchema, schemaOfRow)
- // If getValueAtPath returned None, the value is null in this row;
just ignore.
- }
- .getOrElse(partialSchema)
- // If we didn't find any non-null rows, use an unshredded schema.
- }
+ val rootNode = FieldNode(NullType)
+ var numNonNullVariants = 0
- // Don't infer a schema for fields that appear in less than 10% of rows.
- // Ensure that minCardinality is at least 1 if we have any rows.
- val minCardinality = (numNonNullValues + 9) / 10
+ // Single pass: process all rows for this variant path
+ rows.zipWithIndex.foreach { case (row, rowIdx) =>
+ getValueAtPath(schema, row, path).foreach { variantVal =>
+ numNonNullVariants += 1
+ val v = new Variant(variantVal.getValue, variantVal.getMetadata)
+ rootNode.dataType = mergeSchema(rootNode.dataType,
inferPrimitiveType(v, 0))
+ // Traverse variant and update field stats tree
+ collectFieldStats(v, rootNode, rowIdx, 0, inArrayContext = false)
+ }
+ }
+ // Build final schema from collected statistics
+ val minCardinality = (numNonNullVariants + 9) / 10
+ val simpleSchema = buildSchemaFromStats(
+ rootNode,
+ minCardinality,
+ inArrayContext = false,
+ isArray = rootNode.arrayElementNode.isDefined)
val finalizedSchema = finalizeSimpleSchema(simpleSchema, minCardinality,
maxFields)
val shreddingSchema =
SparkShreddingUtils.variantShreddingSchema(finalizedSchema)
val schemaWithMetadata =
SparkShreddingUtils.addWriteShreddingMetadata(shreddingSchema)
(path, schemaWithMetadata)
}.toMap
- // Insert each inferred schema into the full schema.
+ // Insert each inferred schema into the full schema
updateSchema(schema, inferredSchemas)
}
+
+ /**
+ * Recursively traverse a variant value and build field statistics tree.
+ * For each field encountered, record its type and track distinct row count.
+ * For fields inside arrays, also increment the occurrence count.
+ */
+ private def collectFieldStats(
+ v: Variant,
+ currentNode: FieldNode,
+ rowIdx: Int,
+ depth: Int,
+ inArrayContext: Boolean): Unit = {
+
+ if (depth >= maxShreddingDepth) return
+
+ v.getType match {
+ case Type.OBJECT =>
+ val size = v.objectSize()
+ // Validate fields are sorted (per variant spec)
+ for (i <- 1 until size) {
+ val prevKey = v.getFieldAtIndex(i - 1).key
+ val currKey = v.getFieldAtIndex(i).key
+ if (prevKey >= currKey) {
+ throw new SparkRuntimeException(
+ errorClass = "MALFORMED_VARIANT",
+ messageParameters = Map.empty
+ )
+ }
+ }
+
+ // Process each field
+ for (i <- 0 until size) {
+ val field = v.getFieldAtIndex(i)
+ val fieldName = field.key
+
+ // Get or create child node (O(1) map access - no path string
building!)
+ val childNode = currentNode.getOrCreateChild(fieldName)
+
+ // Track row-level presence only outside array context.
+ if (inArrayContext) {
+ childNode.arrayElementCount += 1
+ } else if (childNode.lastSeenRow != rowIdx) {
+ childNode.rowCount += 1
+ childNode.lastSeenRow = rowIdx
+ }
+
+ // Infer and merge type
+ val fieldType = inferPrimitiveType(field.value, depth)
+ childNode.dataType = mergeSchema(childNode.dataType, fieldType)
+
+ // Recurse into nested structures (pass child node, not path string)
+ collectFieldStats(field.value, childNode, rowIdx, depth + 1,
inArrayContext)
+ }
+
+ case Type.ARRAY =>
+ val arrayNode = currentNode.getOrCreateArrayElement()
+
+ // Track distinct row count for the array field itself
+ if (arrayNode.lastSeenRow != rowIdx) {
+ arrayNode.rowCount += 1
+ arrayNode.lastSeenRow = rowIdx
+ }
+
+ val arraySize = v.arraySize()
+ if (arraySize > 0) {
+ // Process array elements
+ for (i <- 0 until arraySize) {
+ val element = v.getElementAtIndex(i)
+ val elementTypeClass = element.getType
+
+ // Primitives merge into `dataType` only; objects and arrays need
tree descent.
+ if (elementTypeClass != Type.OBJECT && elementTypeClass !=
Type.ARRAY) {
+ val primitiveType = inferPrimitiveType(element, depth)
+ arrayNode.dataType = mergeSchema(arrayNode.dataType,
primitiveType)
+ } else {
+ collectFieldStats(element, arrayNode, rowIdx, depth + 1,
inArrayContext = true)
+ }
+ }
+ }
+
+ case _ =>
+ }
+ }
+
+ /**
+ * Infer the type of a variant value without recursive field collection.
+ * For objects and arrays, return a marker type; recursive collection is
done separately.
+ */
+ private def inferPrimitiveType(v: Variant, depth: Int): DataType = {
+ if (depth >= maxShreddingDepth) return VariantType
+
+ v.getType match {
+ case Type.OBJECT =>
+ // Return empty struct as marker; fields collected separately
+ StructType(Seq.empty)
+ case Type.ARRAY =>
+ // Return array with null element as marker; elements processed
separately
+ ArrayType(NullType)
+ case Type.NULL => NullType
+ case Type.BOOLEAN => BooleanType
+ case Type.LONG =>
+ val d = BigDecimal(v.getLong())
+ val precision = d.precision
+ if (precision <= Decimal.MAX_LONG_DIGITS) {
+ DecimalType(precision, 0)
+ } else {
+ LongType
+ }
+ case Type.STRING => StringType
+ case Type.DOUBLE => DoubleType
+ case Type.DECIMAL =>
+ val d = Decimal(v.getDecimalWithOriginalScale())
+ DecimalType(d.precision, d.scale)
+ case Type.DATE => DateType
+ case Type.TIMESTAMP => TimestampType
+ case Type.TIMESTAMP_NTZ => TimestampNTZType
+ case Type.FLOAT => FloatType
+ case Type.BINARY => BinaryType
+ case Type.UUID => VariantType
+ }
+ }
+
+ /**
+ * Build a schema from collected field statistics tree.
+ *
+ * When isArray=true the function builds and returns the full ArrayType for
this node
+ * (using its arrayElementNode to determine the element type).
+ * When isArray=false it returns the type for the node itself (scalar,
VariantType,
+ * or StructType).
+ *
+ * Cardinality metric:
+ * - inArrayContext=true uses arrayElementCount (total occurrences across
array positions).
+ * - inArrayContext=false uses rowCount (distinct rows containing the
field).
+ */
+ private def buildSchemaFromStats(
+ currentNode: FieldNode,
+ minCardinality: Int,
+ inArrayContext: Boolean,
+ isArray: Boolean): DataType = {
+
+ // Pick the right counter for this context; reused in filter, sort, and
metadata below.
+ def cardinality(n: FieldNode): Long =
+ if (inArrayContext) n.arrayElementCount else n.rowCount
+
+ // Array branch
+ if (isArray) {
+ // Case 1: mixed array and non-array rows at the same path merged
dataType to VariantType.
+ // The whole node is variant, not an array.
+ if (currentNode.dataType == VariantType) {
+ return VariantType
+ }
+ // Case 2: object elements and inner-array elements coexist on the same
element aggregate
+ // (children from objects, arrayElementNode from inner arrays):
element is variant.
+ if (currentNode.children.nonEmpty &&
currentNode.arrayElementNode.isDefined) {
Review Comment:
Is Case 2 actually reachable? Tracing each `isArray=true` call site:
- Line 334 (root call): for `rootNode` to have both `children` and
`arrayElementNode` populated, rows must be a mix of object and array shapes,
which makes `rootNode.dataType` collapse to `VariantType` via `mergeSchema` —
Case 1 catches it.
- Line 514 (Case 3 recursion): the if-then short-circuit at lines 511–512
already returns `VariantType` when `elemNode` has both, before this recursion
runs.
- Line 550 (struct-field ArrayType recursion): if `childNode.dataType`
matched `ArrayType(_, _)`, the field was consistently an array, so
`childNode.children` is empty.
If I'm missing a scenario, a comment naming it would help. If it's purely
defensive, a note saying so would clarify intent.
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/InferVariantShreddingSchema.scala:
##########
@@ -93,74 +95,29 @@ class InferVariantShreddingSchema(val schema: StructType) {
private val COUNT_METADATA_KEY = "COUNT"
- /**
- * Return an appropriate schema for shredding a Variant value.
- * It is similar to the SchemaOfVariant expression, but the rules are
somewhat different, because
- * we want the types to be consistent with what will be allowed during
shredding. E.g.
- * SchemaOfVariant will consider the common type across Integer and Double
to be double, but we
- * consider it to be VariantType, since shredding will not allow those types
to be written to
- * the same typed_value.
- * We also maintain metadata on struct fields to track how frequently they
occur. Rare fields
- * are dropped in the final schema.
- */
- private def schemaOf(v: Variant, maxDepth: Int): DataType = v.getType match {
- case Type.OBJECT =>
- if (maxDepth <= 0) return VariantType
- val size = v.objectSize()
- val fields = new Array[StructField](size)
- for (i <- 0 until size) {
- val field = v.getFieldAtIndex(i)
- fields(i) = StructField(field.key, schemaOf(field.value, maxDepth - 1),
- metadata = new MetadataBuilder().putLong(COUNT_METADATA_KEY,
1).build())
- }
- // According to the variant spec, object fields must be sorted
alphabetically. So we don't
- // have to sort, but just need to validate they are sorted.
- for (i <- 1 until size) {
- if (fields(i - 1).name >= fields(i).name) {
- throw new SparkRuntimeException(
- errorClass = "MALFORMED_VARIANT",
- messageParameters = Map.empty
- )
- }
- }
- StructType(fields)
- case Type.ARRAY =>
- if (maxDepth <= 0) return VariantType
- var elementType: DataType = NullType
- for (i <- 0 until v.arraySize()) {
- elementType = mergeSchema(elementType,
schemaOf(v.getElementAtIndex(i), maxDepth - 1))
- }
- ArrayType(elementType)
- case Type.NULL => NullType
- case Type.BOOLEAN => BooleanType
- case Type.LONG =>
- // Compute the smallest decimal that can contain this value.
- // This will allow us to merge with decimal later without introducing
excessive precision.
- // If we only end up encountering integer values, we'll convert back to
LongType when we
- // finalize.
- val d = BigDecimal(v.getLong())
- val precision = d.precision
- if (precision <= Decimal.MAX_LONG_DIGITS) {
- DecimalType(precision, 0)
- } else {
- // Value is too large for Decimal(18, 0), so record its type as long.
- LongType
+ // Node for tree-based field tracking
+ private case class FieldNode(
+ var dataType: DataType, // type summary of the field, not fully
defined
Review Comment:
"not fully defined" is vague. State what's incomplete: for OBJECT/ARRAY values
this is only a marker, and the structural shape lives elsewhere.
```suggestion
var dataType: DataType, // scalar type or a marker (StructType(empty) / ArrayType(NullType)); structural shape lives in `children` and `arrayElementNode`
```
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/VariantInferShreddingSuite.scala:
##########
@@ -226,15 +342,23 @@ class VariantInferShreddingSuite extends
SharedSparkSession with ParquetTest {
val footers = getFooters(dir)
assert(footers.size == 1)
- // We can't call checkFileSchema, because it only handles the case of
one Variant column in
- // the file.
- val largeExpected =
SparkShreddingUtils.variantShreddingSchema(DataType.fromDDL("variant"))
+ // With cardinality-based sorting, v should now have a shredded schema
+ // for the high-cardinality last_* fields (not an unshredded schema like
+ // master would produce). v2 should still be shredded correctly.
+ val actual = getFileSchema(dir)
+ val v_schema = actual.fields(0).dataType.asInstanceOf[StructType]
+ val v2_schema = actual.fields(1).dataType.asInstanceOf[StructType]
+
+ // v should have shredded typed_value (struct with nested last_* fields)
+ assert(v_schema.fieldNames.contains("typed_value"))
+ val v_typed = v_schema("typed_value").dataType.asInstanceOf[StructType]
+ assert(v_typed.fields.exists(_.name.startsWith("last_")))
Review Comment:
These assertions are noticeably weaker than the pre-PR version, which
checked the exact schema. The cardinality-based selection is the key behavior
change this PR makes, so it should be locked in. Currently there is no check
that the `first_*_<id>` fields are excluded, that all 50 `last_*` fields
survive, or that each one keeps the `struct<x long, y long>` shape.
```suggestion
assert(!v_typed.fieldNames.exists(_.startsWith("first_")))
assert(v_typed.fieldNames.count(_.startsWith("last_")) == 50)
val last50 = v_typed.fields.find(_.name == "last_50").get
assert(last50.dataType ==
SparkShreddingUtils.variantShreddingSchema(DataType.fromDDL("struct<x long, y long>")))
```
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/VariantInferShreddingSuite.scala:
##########
@@ -203,13 +322,10 @@ class VariantInferShreddingSuite extends
SharedSparkSession with ParquetTest {
}
testWithTempDir("infer shredding key as data") { dir =>
- // The first 10 fields in each object include the row ID in the field
name, so they'll be
- // unique. Because we impose a 1000-field limit when building up the
schema, we'll end up
- // dropping all but the first 1000, so we won't include the non-unique
fields in the schema.
- // Since the unique names are below the count threshold, we'll end up
with an unshredded
- // schema.
- // In the future, we could consider trying to improve this by dropping
the least-common fields
- // when we hit the limit of 1000.
+ // The first 50 fields include the row ID in the field name, so they're
+ // unique (low cardinality). The last 50 fields are shared across all
rows
+ // (high cardinality). With cardinality-based sorting,
+ // we now correctly shred the high-cardinality last_* fields
Review Comment:
Line breaks fragment the sentence, and the comment lacks a terminating period.
```suggestion
// The first 50 fields include the row ID in the field name, so they're unique
// (low cardinality). The last 50 fields are shared across all rows (high
// cardinality). With cardinality-based sorting, the high-cardinality `last_*`
// fields are now shredded.
```
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/VariantInferShreddingSuite.scala:
##########
@@ -634,4 +758,111 @@ class VariantInferShreddingSuite extends
SharedSparkSession with ParquetTest {
checkFileSchema(expected, dir)
checkAnswer(spark.read.parquet(dir.getAbsolutePath), df.collect())
}
+
+ testWithTempDir("special characters in field names - dots") { dir =>
+ val df = spark.sql(
+ """
+ |select parse_json(
+ | '{"field.with.dots": ' || id || ', "another.dotted.field": "value"}'
+ |) as v
+ |from range(0, 100, 1, 1)
+ """.stripMargin)
+ df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+
+ // Verify the schema contains fields with dots
+ val schema = getFileSchema(dir)
+ val vSchema = schema("v").dataType.asInstanceOf[StructType]
+ val typedValue = vSchema("typed_value").dataType.asInstanceOf[StructType]
+ assert(typedValue.fieldNames.contains("another.dotted.field"))
+ assert(typedValue.fieldNames.contains("field.with.dots"))
+
+ // Verify we can read the data back
+ val result = spark.read.parquet(dir.getAbsolutePath)
+ assert(result.count() == 100)
Review Comment:
This test (and the four other `special characters in field names - *` tests
below) only assert schema field names and `result.count() == 100`. They don't
verify that variant values are read back correctly. The convention elsewhere in
this file is `checkAnswer(spark.read.parquet(dir.getAbsolutePath),
df.collect())` — using that instead of `result.count()` would catch value-level
corruption (e.g., metadata mis-encoding for dotted keys) that schema-only
checks miss.
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/InferVariantShreddingSchema.scala:
##########
@@ -351,36 +311,252 @@ class InferVariantShreddingSchema(val schema:
StructType) {
}
def inferSchema(rows: Seq[InternalRow]): StructType = {
- // For each path to a Variant value, iterate over all rows and update the
inferred schema.
- // Add the result to a map, which we'll use to update the full schema.
- // maxShreddedFieldsPerFile is a global max for all fields, so initialize
it here.
+ // For each variant path, collect field statistics using a single pass
val maxFields = MaxFields(maxShreddedFieldsPerFile)
+
val inferredSchemas = pathsToVariant.map { path =>
- var numNonNullValues = 0
- val simpleSchema = rows.foldLeft(NullType: DataType) {
- case (partialSchema, row) =>
- getValueAtPath(schema, row, path).map { variantVal =>
- numNonNullValues += 1
- val v = new Variant(variantVal.getValue, variantVal.getMetadata)
- val schemaOfRow = schemaOf(v, maxShreddingDepth)
- mergeSchema(partialSchema, schemaOfRow)
- // If getValueAtPath returned None, the value is null in this row;
just ignore.
- }
- .getOrElse(partialSchema)
- // If we didn't find any non-null rows, use an unshredded schema.
- }
+ val rootNode = FieldNode(NullType)
+ var numNonNullVariants = 0
- // Don't infer a schema for fields that appear in less than 10% of rows.
- // Ensure that minCardinality is at least 1 if we have any rows.
- val minCardinality = (numNonNullValues + 9) / 10
+ // Single pass: process all rows for this variant path
+ rows.zipWithIndex.foreach { case (row, rowIdx) =>
+ getValueAtPath(schema, row, path).foreach { variantVal =>
+ numNonNullVariants += 1
+ val v = new Variant(variantVal.getValue, variantVal.getMetadata)
+ rootNode.dataType = mergeSchema(rootNode.dataType,
inferPrimitiveType(v, 0))
+ // Traverse variant and update field stats tree
+ collectFieldStats(v, rootNode, rowIdx, 0, inArrayContext = false)
+ }
+ }
+ // Build final schema from collected statistics
+ val minCardinality = (numNonNullVariants + 9) / 10
+ val simpleSchema = buildSchemaFromStats(
+ rootNode,
+ minCardinality,
+ inArrayContext = false,
+ isArray = rootNode.arrayElementNode.isDefined)
val finalizedSchema = finalizeSimpleSchema(simpleSchema, minCardinality,
maxFields)
val shreddingSchema =
SparkShreddingUtils.variantShreddingSchema(finalizedSchema)
val schemaWithMetadata =
SparkShreddingUtils.addWriteShreddingMetadata(shreddingSchema)
(path, schemaWithMetadata)
}.toMap
- // Insert each inferred schema into the full schema.
+ // Insert each inferred schema into the full schema
updateSchema(schema, inferredSchemas)
}
+
+ /**
+ * Recursively traverse a variant value and build field statistics tree.
Review Comment:
Missing article.
```suggestion
* Recursively traverse a variant value and build a field statistics tree.
```
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/VariantInferShreddingSuite.scala:
##########
@@ -226,15 +342,23 @@ class VariantInferShreddingSuite extends
SharedSparkSession with ParquetTest {
val footers = getFooters(dir)
assert(footers.size == 1)
- // We can't call checkFileSchema, because it only handles the case of
one Variant column in
- // the file.
- val largeExpected =
SparkShreddingUtils.variantShreddingSchema(DataType.fromDDL("variant"))
+ // With cardinality-based sorting, v should now have a shredded schema
+ // for the high-cardinality last_* fields (not an unshredded schema like
+ // master would produce). v2 should still be shredded correctly.
Review Comment:
"like master would produce" anchors the comment to the pre-merge state —
once this PR lands, the comparison loses its reference. Rewrite in terms of the
current invariant.
```suggestion
// The 50 first_*_<id> fields are unique per row (low cardinality) and are
// filtered out; the 50 last_* fields are shared across all rows (high
// cardinality) and are shredded into typed_value. v2 is shredded independently.
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]