gene-db commented on code in PR #52406:
URL: https://github.com/apache/spark/pull/52406#discussion_r2388585612


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/InferVariantShreddingSchema.scala:
##########
@@ -0,0 +1,380 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.parquet
+
+import org.apache.spark.SparkRuntimeException
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.types.variant._
+import org.apache.spark.types.variant.VariantUtil.Type
+import org.apache.spark.unsafe.types._
+
+/**
+ *
+ * Infer a schema when there are Variant values in the shredding schema.
+ * Only VariantType values at the top level or nested in struct fields are replaced.
+ * VariantType nested in arrays or maps are not modified.
+ * @param schema The original schema containing VariantType.
+ */
+class InferVariantShreddingSchema(val schema: StructType) {
+
+  /**
+   * Create a list of paths to Variant values in the schema.
+   * Variant fields nested in arrays or maps are not included.
+   * For example, if the schema is
+   * struct<v: variant, struct<a: int, b: int, c: variant>>
+   * the function will return [[0], [1, 2]]
+   */
+  private def getPathsToVariant(s: StructType): Seq[Seq[Int]] = {
+    s.fields.zipWithIndex
+      .map {
+        case (field, idx) =>
+          field.dataType match {
+            case VariantType =>
+              Seq(Seq(idx))
+            case inner: StructType =>
+              // Prepend this index to each downstream path.
+              getPathsToVariant(inner).map { path =>
+                idx +: path
+              }
+            case _ => Seq()
+          }
+      }
+      .toSeq
+      .flatten
+  }
+
+  private def getValueAtPath(s: StructType, row: InternalRow, p: Seq[Int]): Option[VariantVal] = {

Review Comment:
   NIT:
   ```suggestion
     private def getValueAtPath(schema: StructType, row: InternalRow, path: Seq[Int]): Option[VariantVal] = {
   ```
   
   Also, can we comment on what this will return? It looks like it will return `None` if the field is null?
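To make the path semantics concrete, here is a minimal, self-contained sketch that mirrors `getPathsToVariant` and `getValueAtPath` on a simplified model. The types `SType`, `SVariant`, `SInt`, `SStruct`, `SArray`, `RValue`, `RNull`, `RLeaf`, `RRow` and the object `VariantPathSketch` are hypothetical stand-ins for Spark's `StructType`, `InternalRow`, and `VariantVal`; this is not the PR's code. Under that assumption, `valueAtPath` returns `None` as soon as the value at any step of the path is null, which corresponds to the `isNullAt(p.head)` check made at each recursion level in the diff.

```scala
// Hypothetical, simplified schema model: only the distinctions the traversal cares about.
sealed trait SType
case object SVariant extends SType
case object SInt extends SType
final case class SStruct(fields: Vector[(String, SType)]) extends SType
final case class SArray(element: SType) extends SType

// Hypothetical row model: a null, a leaf payload (stands in for a VariantVal), or a nested row.
sealed trait RValue
case object RNull extends RValue
final case class RLeaf(payload: String) extends RValue
final case class RRow(values: Vector[RValue]) extends RValue

object VariantPathSketch {
  // Collect index paths to every SVariant, descending into structs only.
  // Variants inside arrays (or any other non-struct type) are not visited.
  def pathsToVariant(s: SStruct): Seq[Seq[Int]] =
    s.fields.zipWithIndex.flatMap { case ((_, dt), idx) =>
      dt match {
        case SVariant => Seq(Seq(idx))
        case inner: SStruct => pathsToVariant(inner).map(idx +: _)
        case _ => Seq.empty[Seq[Int]]
      }
    }

  // Follow `path` through nested rows. Returns None if the value at any step is null;
  // otherwise returns the leaf at the end of the path.
  def valueAtPath(row: RRow, path: Seq[Int]): Option[String] =
    row.values(path.head) match {
      case RNull => None
      case RLeaf(payload) if path.length == 1 => Some(payload)
      case child: RRow if path.length > 1 => valueAtPath(child, path.tail)
      case other =>
        throw new IllegalArgumentException(s"value $other does not match path $path")
    }
}
```

For the doc-comment example, written here as `struct<v: variant, s: struct<a: int, b: int, c: variant>>` (the nested field name `s` is added only so the type is well formed), `pathsToVariant` yields `Seq(Seq(0), Seq(1, 2))`, and a variant nested only inside an array produces no path at all.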



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/InferVariantShreddingSchema.scala:
##########
@@ -0,0 +1,380 @@
+/**
+ *
+ * Infer a schema when there are Variant values in the shredding schema.
+ * Only VariantType values at the top level or nested in struct fields are replaced.

Review Comment:
   What are they replaced with?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/InferVariantShreddingSchema.scala:
##########
@@ -0,0 +1,380 @@
+/**
+ *
+ * Infer a schema when there are Variant values in the shredding schema.
+ * Only VariantType values at the top level or nested in struct fields are replaced.
+ * VariantType nested in arrays or maps are not modified.
+ * @param schema The original schema containing VariantType.
+ */
+class InferVariantShreddingSchema(val schema: StructType) {
+
+  /**
+   * Create a list of paths to Variant values in the schema.
+   * Variant fields nested in arrays or maps are not included.
+   * For example, if the schema is
+   * struct<v: variant, struct<a: int, b: int, c: variant>>
+   * the function will return [[0], [1, 2]]
+   */
+  private def getPathsToVariant(s: StructType): Seq[Seq[Int]] = {
+    s.fields.zipWithIndex
+      .map {
+        case (field, idx) =>
+          field.dataType match {
+            case VariantType =>
+              Seq(Seq(idx))
+            case inner: StructType =>
+              // Prepend this index to each downstream path.
+              getPathsToVariant(inner).map { path =>
+                idx +: path
+              }
+            case _ => Seq()
+          }
+      }
+      .toSeq
+      .flatten
+  }
+
+  private def getValueAtPath(s: StructType, row: InternalRow, p: Seq[Int]): Option[VariantVal] = {
+    if (row.isNullAt(p.head)) {
+      None
+    } else if (p.length == 1) {
+      // We've reached the Variant value.
+      Some(row.getVariant(p.head))
+    } else {
+      // The field must be a struct.
+      val childStruct = s.fields(p.head).dataType.asInstanceOf[StructType]
+      getValueAtPath(
+        childStruct,
+        row.getStruct(p.head, childStruct.length),
+        p.tail
+      )
+    }
+  }
+
+  private val pathsToVariant = getPathsToVariant(schema)
+
+  private val maxShreddedFieldsPerFile =
+    SQLConf.get.getConf(SQLConf.VARIANT_SHREDDING_MAX_SCHEMA_WIDTH)
+
+  private val maxShreddingDepth =
+    SQLConf.get.getConf(SQLConf.VARIANT_SHREDDING_MAX_SCHEMA_DEPTH)
+
+  private val COUNT_METADATA_KEY = "COUNT"
+
+  /**
+   * Return an appropriate schema for shredding a Variant value.
+   * It is similar to the SchemaOfVariant expression, but the rules are somewhat different, because
+   * we want the types to be consistent with what will be allowed during shredding. E.g.
+   * SchemaOfVariant will consider the common type across Integer and Double to be double, but we
+   * consider it to be VariantType, since shredding will not allow those types to be written to
+   * the same typed_value.
+   * We also maintain metadata on struct fields to track how frequently they occur. Rare fields
+   * are dropped in the final schema.
+   */
+  private def schemaOf(v: Variant, maxDepth: Int): DataType = v.getType match {
+    case Type.OBJECT =>
+      if (maxDepth <= 0) return VariantType
+      val size = v.objectSize()
+      val fields = new Array[StructField](size)
+      for (i <- 0 until size) {
+        val field = v.getFieldAtIndex(i)
+        fields(i) = StructField(field.key, schemaOf(field.value, maxDepth - 1),
+          metadata = new MetadataBuilder().putLong(COUNT_METADATA_KEY, 1).build())
+      }
+      // According to the variant spec, object fields must be sorted alphabetically. So we don't
+      // have to sort, but just need to validate they are sorted.
+      for (i <- 1 until size) {
+        if (fields(i - 1).name >= fields(i).name) {
+          throw new SparkRuntimeException(
+            errorClass = "MALFORMED_VARIANT",
+            messageParameters = Map.empty
+          )
+        }
+      }
+      StructType(fields)
+    case Type.ARRAY =>
+      if (maxDepth <= 0) return VariantType
+      var elementType: DataType = NullType
+      for (i <- 0 until v.arraySize()) {
+        elementType = mergeSchema(elementType, schemaOf(v.getElementAtIndex(i), maxDepth - 1))
+      }
+      ArrayType(elementType)
+    case Type.NULL => NullType
+    case Type.BOOLEAN => BooleanType
+    case Type.LONG =>
+      // Compute the smallest decimal that can contain this value.
+      // This will allow us to merge with decimal later without introducing excessive precision.
+      // If we only end up encountering integer values, we'll convert back to LongType when we
+      // finalize.
+      val d = BigDecimal(v.getLong())
+      val precision = d.precision
+      if (precision <= Decimal.MAX_LONG_DIGITS) {
+        DecimalType(precision, 0)
+      } else {
+        // Value is too large for Decimal(18, 0), so record its type as long.
+        LongType
+      }
+    case Type.STRING => StringType
+    case Type.DOUBLE => DoubleType
+    case Type.DECIMAL =>
+      // Don't strip trailing zeros to determine scale. Even if we allow scale relaxation during
+      // shredding, it's useful to take trailing zeros as a hint that the extra digits may be used
+      // in later values, and use the larger scale.
+      val d = Decimal(v.getDecimalWithOriginalScale())
+      DecimalType(d.precision, d.scale)
+    case Type.DATE => DateType
+    case Type.TIMESTAMP => TimestampType
+    case Type.TIMESTAMP_NTZ => TimestampNTZType
+    case Type.FLOAT => FloatType
+    case Type.BINARY => BinaryType
+    // Spark doesn't support UUID, so shred it as an untyped value.
+    case Type.UUID => VariantType
+  }
+
+  private def getFieldCount(field: StructField): Long = {
+    field.metadata.getLong(COUNT_METADATA_KEY)
+  }
+
+  // Merge two decimals with possibly different scales.
+  private def mergeDecimal(d1: DecimalType, d2: DecimalType): DataType = {
+    val scale = Math.max(d1.scale, d2.scale)
+    val range = Math.max(d1.precision - d1.scale, d2.precision - d2.scale)
+    if (range + scale > DecimalType.MAX_PRECISION) {
+      // DecimalType can't support precision > 38
+      VariantType
+    } else {
+      DecimalType(range + scale, scale)
+    }
+  }
+
+  private def mergeDecimalWithLong(d: DecimalType): DataType = {
+    if (d.scale == 0 && d.precision <= 18) {
+      // It's an integer-like Decimal. Rather than widen to a precision of 19, we can
+      // use LongType
+      LongType
+    } else {
+      // Long can always fit in a Decimal(19, 0)
+      mergeDecimal(d, DecimalType(19, 0))
+    }
+  }
+
+  private def mergeSchema(dt1: DataType, dt2: DataType): DataType = {
+    (dt1, dt2) match {
+      // Allow VariantNull to appear in any typed schema
+      case (NullType, t) => t
+      case (t, NullType) => t
+      case (d1: DecimalType, d2: DecimalType) =>
+        mergeDecimal(d1, d2)
+      case (d: DecimalType, LongType) =>
+        mergeDecimalWithLong(d)
+      case (LongType, d: DecimalType) =>
+        mergeDecimalWithLong(d)
+      case (StructType(fields1), StructType(fields2)) =>
+        // Rely on fields being sorted by name, and merge fields with the same name recursively.
+        val newFields = new java.util.ArrayList[StructField]()
+
+        var f1Idx = 0
+        var f2Idx = 0
+        // We end up dropping all but 300 fields in the final schema, but add a cap on how many
+        // we'll try to track to avoid memory/time blow-ups in the intermediate state.
+        val maxStructSize = 1000
+
+        while (f1Idx < fields1.length && f2Idx < fields2.length && newFields.size < maxStructSize) {
+          val f1Name = fields1(f1Idx).name
+          val f2Name = fields2(f2Idx).name
+          val comp = f1Name.compareTo(f2Name)
+          if (comp == 0) {
+            val dataType = mergeSchema(fields1(f1Idx).dataType, fields2(f2Idx).dataType)
+            val c1 = getFieldCount(fields1(f1Idx))
+            val c2 = getFieldCount(fields2(f2Idx))
+            newFields.add(
+              StructField(
+                f1Name,
+                dataType,
+                metadata = new MetadataBuilder().putLong(COUNT_METADATA_KEY, c1 + c2).build()
+              )
+            )
+            f1Idx += 1
+            f2Idx += 1
+          } else if (comp < 0) { // f1Name < f2Name
+            newFields.add(fields1(f1Idx))
+            f1Idx += 1
+          } else { // f1Name > f2Name
+            newFields.add(fields2(f2Idx))
+            f2Idx += 1
+          }
+        }
+        while (f1Idx < fields1.length && newFields.size < maxStructSize) {
+          newFields.add(fields1(f1Idx))
+          f1Idx += 1
+        }
+        while (f2Idx < fields2.length && newFields.size < maxStructSize) {
+          newFields.add(fields2(f2Idx))
+          f2Idx += 1
+        }
+        StructType(newFields.toArray(Array.empty[StructField]))
+      case (ArrayType(e1, _), ArrayType(e2, _)) =>
+        ArrayType(mergeSchema(e1, e2))
+      // For any other scalar types, the types must be identical, or we give up and use Variant.
+      case (_, _) if dt1 == dt2 => dt1
+      case _ => VariantType
+    }
+  }
+
+  /**
+   * Update each VariantType with its inferred schema.

Review Comment:
   Does it return a copy of the schema (with the updates), or does it update in place?
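The type-merging rules in this hunk are easier to review against a small model. The following is a minimal sketch under stated assumptions: `IType`, `NullT`, `LongT`, `StringT`, `VariantT`, `DecimalT`, `Field`, `StructT`, and `MergeSketch` are hypothetical stand-ins for Spark's `DataType` classes and the `COUNT` field metadata, arrays and the remaining scalar types are left out, and it is not the PR's implementation. It mirrors the `Type.LONG` branch of `schemaOf`, `mergeDecimal`, `mergeDecimalWithLong`, and the name-sorted struct-field merge that sums counts.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical, simplified model of the inferred types (arrays and most scalars omitted).
sealed trait IType
case object NullT extends IType
case object LongT extends IType
case object StringT extends IType
case object VariantT extends IType
final case class DecimalT(precision: Int, scale: Int) extends IType
// `count` plays the role of the COUNT metadata; fields are kept sorted by name.
final case class Field(name: String, dt: IType, count: Long)
final case class StructT(fields: Vector[Field]) extends IType

object MergeSketch {
  val MaxPrecision = 38    // value of DecimalType.MAX_PRECISION
  val MaxLongDigits = 18   // value of Decimal.MAX_LONG_DIGITS
  val MaxStructSize = 1000 // cap on tracked fields, as in the diff

  // Like schemaOf for Type.LONG: the smallest Decimal(p, 0) that holds v, else Long.
  def typeOfLong(v: Long): IType = {
    val precision = BigDecimal(v).precision
    if (precision <= MaxLongDigits) DecimalT(precision, 0) else LongT
  }

  // Keep the larger scale and the larger number of integer digits.
  def mergeDecimal(d1: DecimalT, d2: DecimalT): IType = {
    val scale = math.max(d1.scale, d2.scale)
    val range = math.max(d1.precision - d1.scale, d2.precision - d2.scale)
    if (range + scale > MaxPrecision) VariantT else DecimalT(range + scale, scale)
  }

  // Integer-like decimals collapse back to Long; otherwise widen against Decimal(19, 0).
  def mergeDecimalWithLong(d: DecimalT): IType =
    if (d.scale == 0 && d.precision <= MaxLongDigits) LongT
    else mergeDecimal(d, DecimalT(19, 0))

  def merge(a: IType, b: IType): IType = (a, b) match {
    case (NullT, t) => t
    case (t, NullT) => t
    case (d1: DecimalT, d2: DecimalT) => mergeDecimal(d1, d2)
    case (d: DecimalT, LongT) => mergeDecimalWithLong(d)
    case (LongT, d: DecimalT) => mergeDecimalWithLong(d)
    case (StructT(f1), StructT(f2)) => StructT(mergeFields(f1, f2))
    case _ if a == b => a
    case _ => VariantT // incompatible scalars fall back to an untyped Variant
  }

  // Two-pointer merge over name-sorted fields; matching names merge types and add counts.
  private def mergeFields(f1: Vector[Field], f2: Vector[Field]): Vector[Field] = {
    val out = ArrayBuffer.empty[Field]
    var i = 0
    var j = 0
    while (i < f1.length && j < f2.length && out.size < MaxStructSize) {
      val cmp = f1(i).name.compareTo(f2(j).name)
      if (cmp == 0) {
        out += Field(f1(i).name, merge(f1(i).dt, f2(j).dt), f1(i).count + f2(j).count)
        i += 1
        j += 1
      } else if (cmp < 0) {
        out += f1(i)
        i += 1
      } else {
        out += f2(j)
        j += 1
      }
    }
    while (i < f1.length && out.size < MaxStructSize) { out += f1(i); i += 1 }
    while (j < f2.length && out.size < MaxStructSize) { out += f2(j); j += 1 }
    out.toVector
  }
}
```

For example, `Decimal(5, 2)` and `Decimal(4, 0)` merge to `Decimal(6, 2)` (scale 2, four integer digits). A non-integer decimal merged with a long is widened against `Decimal(19, 0)`, so `Decimal(5, 2)` becomes `Decimal(21, 2)`, while an integer-like `Decimal(3, 0)` collapses back to a long. If the widened precision would exceed 38, the result falls back to Variant.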



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/InferVariantShreddingSchema.scala:
##########
@@ -0,0 +1,380 @@
+/**
+ *
+ * Infer a schema when there are Variant values in the shredding schema.
+ * Only VariantType values at the top level or nested in struct fields are replaced.
+ * VariantType nested in arrays or maps are not modified.
+ * @param schema The original schema containing VariantType.

Review Comment:
   Does this require that the schema must contain a `VariantType`, and does the `VariantType` have to be a top-level field?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
