chenhao-db commented on code in PR #49263:
URL: https://github.com/apache/spark/pull/49263#discussion_r1896919919


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala:
##########
@@ -45,6 +56,369 @@ case class SparkShreddedRow(row: SpecializedGetters) 
extends ShreddingUtils.Shre
   override def numElements(): Int = row.asInstanceOf[ArrayData].numElements()
 }
 
+// The search result of a `PathSegment` in a `VariantSchema`.
+case class SchemaPathSegment(
+    rawPath: PathSegment,
+    // Whether this path segment is an object or array extraction.
+    isObject: Boolean,
+    // `schema.typedIdx`, if the path exists in the schema (for object 
extraction, the schema
+    // should contain an object `typed_value` containing the requested field; 
similar for array
+    // extraction). Negative otherwise.
+    typedIdx: Int,
+    // For object extraction, it is the index of the desired field in 
`schema.objectSchema`. If the
+    // requested field doesn't exist, both `extractionIdx/typedIdx` are set to 
negative.
+    // For array extraction, it is the array index. The information is already 
stored in `rawPath`,
+    // but accessing a raw int should be more efficient than `rawPath`, which 
is an `Either`.
+    extractionIdx: Int)
+
+// Represent a single field in a variant struct (see `VariantMetadata` for 
definition), that is, a
+// single requested field that the scan should produce by extracting from the 
variant column.
+case class FieldToExtract(path: Array[SchemaPathSegment], reader: 
ParquetVariantReader)
+
+// A helper class to cast from scalar `typed_value` into a scalar `dataType`. 
Need a custom
+// expression because it has different error reporting code than `Cast`.
+case class ScalarCastHelper(
+    child: Expression,
+    dataType: DataType,
+    castArgs: VariantCastArgs) extends UnaryExpression {
+  // The expression is only for the internal use of `ScalarReader`, which can 
guarantee the child
+  // is not nullable.
+  assert(!child.nullable)
+
+  // If `cast` is null, it means the cast always fails because the type 
combination is not allowed.
+  private val cast = if (Cast.canAnsiCast(child.dataType, dataType)) {
+    Cast(child, dataType, castArgs.zoneStr, EvalMode.TRY)
+  } else {
+    null
+  }
+  // Cast the input to string. Only used for reporting an invalid cast.
+  private val castToString = Cast(child, StringType, castArgs.zoneStr, 
EvalMode.ANSI)
+
+  override def nullable: Boolean = !castArgs.failOnError
+  override def withNewChildInternal(newChild: Expression): UnaryExpression = 
copy(child = newChild)
+
+  // No need to define the interpreted version of `eval`: the codegen must 
succeed.
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode = {
+    // Throw an error or do nothing, depending on `castArgs.failOnError`.
+    val invalidCastCode = if (castArgs.failOnError) {
+      val castToStringCode = castToString.genCode(ctx)
+      val typeObj = ctx.addReferenceObj("dataType", dataType)
+      val cls = classOf[ScalarCastHelper].getName
+      s"""
+        ${castToStringCode.code}
+        $cls.throwInvalidVariantCast(${castToStringCode.value}, $typeObj);
+      """
+    } else {
+      ""
+    }
+    if (cast != null) {
+      val castCode = cast.genCode(ctx)
+      val code = code"""
+        ${castCode.code}
+        boolean ${ev.isNull} = ${castCode.isNull};
+        ${CodeGenerator.javaType(dataType)} ${ev.value} = ${castCode.value};
+        if (${ev.isNull}) { $invalidCastCode }
+      """
+      ev.copy(code = code)
+    } else {
+      val code = code"""
+        boolean ${ev.isNull} = true;
+        ${CodeGenerator.javaType(dataType)} ${ev.value} = 
${CodeGenerator.defaultValue(dataType)};
+        if (${ev.isNull}) { $invalidCastCode }
+      """
+      ev.copy(code = code)
+    }
+  }
+}
+
+object ScalarCastHelper {
+  // A helper function for codegen. The java compiler doesn't allow throwing a 
`Throwable` in a
+  // method without `throws` annotation.
+  def throwInvalidVariantCast(value: UTF8String, dataType: DataType): Any =
+    throw QueryExecutionErrors.invalidVariantCast(value.toString, dataType)
+}
+
+// The base class to read Parquet variant values into a Spark type.
+// For convenience, we also allow creating an instance of the base class 
itself. None of its
+// functions can be used, but it can serve as a container of `targetType` and 
`castArgs`.
+class ParquetVariantReader(
+    val schema: VariantSchema, val targetType: DataType, val castArgs: 
VariantCastArgs) {
+  // Read from a row containing a Parquet variant value (shredded or 
unshredded) and return a value
+  // of `targetType`. The row schema is described by `schema`.
+  // This function throws MALFORMED_VARIANT if the variant is missing. If the 
variant can be
+  // legally missing (the only possible situation is struct fields in object 
`typed_value`), the
+  // caller should check for it and avoid calling this function if the variant 
is missing.
+  def read(row: InternalRow, topLevelMetadata: Array[Byte]): Any = {
+    if (schema.typedIdx < 0 || row.isNullAt(schema.typedIdx)) {
+      if (schema.variantIdx < 0 || row.isNullAt(schema.variantIdx)) {
+        // Both `typed_value` and `value` are null, meaning the variant is 
missing.
+        throw QueryExecutionErrors.malformedVariant()
+      }
+      val v = new Variant(row.getBinary(schema.variantIdx), topLevelMetadata)
+      VariantGet.cast(v, targetType, castArgs)
+    } else {
+      readFromTyped(row, topLevelMetadata)
+    }
+  }
+
+  // Subclasses should override it to produce the read result when 
`typed_value` is not null.
+  protected def readFromTyped(row: InternalRow, topLevelMetadata: 
Array[Byte]): Any =
+    throw QueryExecutionErrors.unreachableError()
+
+  // A util function to rebuild the variant in binary format from a Parquet 
variant value.
+  protected final def rebuildVariant(row: InternalRow, topLevelMetadata: 
Array[Byte]): Variant = {
+    val builder = new VariantBuilder(false)
+    ShreddingUtils.rebuild(SparkShreddedRow(row), topLevelMetadata, schema, 
builder)
+    builder.result()
+  }
+
+  // A util function to throw error or return null when an invalid cast 
happens.
+  protected final def invalidCast(row: InternalRow, topLevelMetadata: 
Array[Byte]): Any = {
+    if (castArgs.failOnError) {
+      throw QueryExecutionErrors.invalidVariantCast(
+        rebuildVariant(row, topLevelMetadata).toJson(castArgs.zoneId), 
targetType)
+    } else {
+      null
+    }
+  }
+}
+
+object ParquetVariantReader {
+  // Create a reader for `targetType`. If `schema` is null, meaning that the 
extraction path doesn't
+  // exist in `typed_value`, it returns an instance of `ParquetVariantReader`. 
As described in the
+  // class comment, the reader is only a container of `targetType` and 
`castArgs` in this case.
+  def apply(schema: VariantSchema, targetType: DataType, castArgs: 
VariantCastArgs,
+            isTopLevelUnshredded: Boolean = false): ParquetVariantReader = 
targetType match {
+    case _ if schema == null => new ParquetVariantReader(schema, targetType, 
castArgs)
+    case s: StructType => new StructReader(schema, s, castArgs)
+    case a: ArrayType => new ArrayReader(schema, a, castArgs)
+    case m@MapType(_: StringType, _, _) => new MapReader(schema, m, castArgs)
+    case v: VariantType => new VariantReader(schema, v, castArgs, 
isTopLevelUnshredded)
+    case s: AtomicType => new ScalarReader(schema, s, castArgs)
+    case _ =>
+      // Type check should have rejected map with non-string type.
+      throw QueryExecutionErrors.unreachableError(s"Invalid target type: 
`${targetType.sql}`")
+  }
+}
+
+// Read Parquet variant values into a Spark struct type. It reads unshredded 
fields (fields that are
+// not in the typed object) from the `value`, and reads the shredded fields 
from the object
+// `typed_value`.
+// `value` must not contain any shredded field according to the shredding 
spec, but this requirement
+// is not enforced. If `value` does contain a shredded field, no error will 
occur, and the field in
+// object `typed_value` will be the final result.
+private[this] final class StructReader(
+  schema: VariantSchema, targetType: StructType, castArgs: VariantCastArgs)
+  extends ParquetVariantReader(schema, targetType, castArgs) {
+  // For each field in `targetType`, store the index of the field with the 
same name in object
+  // `typed_value`, or -1 if it doesn't exist in object `typed_value`.
+  private[this] val fieldInputIndices: Array[Int] = targetType.fields.map { f 
=>
+    val inputIdx = if (schema.objectSchemaMap != null) 
schema.objectSchemaMap.get(f.name) else null
+    if (inputIdx != null) inputIdx.intValue() else -1
+  }
+  // For each field in `targetType`, store the reader from the corresponding 
field in object
+  // `typed_value`, or null if it doesn't exist in object `typed_value`.
+  private[this] val fieldReaders: Array[ParquetVariantReader] =
+    targetType.fields.zip(fieldInputIndices).map { case (f, inputIdx) =>
+      if (inputIdx >= 0) {
+        val fieldSchema = schema.objectSchema(inputIdx).schema
+        ParquetVariantReader(fieldSchema, f.dataType, castArgs)
+      } else {
+        null
+      }
+    }
+  // If all fields in `targetType` can be found in object `typed_value`, then 
the reader doesn't
+  // need to read from `value`.
+  private[this] val needUnshreddedObject: Boolean = fieldInputIndices.exists(_ 
< 0)
+
+  override def readFromTyped(row: InternalRow, topLevelMetadata: Array[Byte]): 
Any = {
+    if (schema.objectSchema == null) return invalidCast(row, topLevelMetadata)
+    val obj = row.getStruct(schema.typedIdx, schema.objectSchema.length)
+    val result = new GenericInternalRow(fieldInputIndices.length)
+    var unshreddedObject: Variant = null
+    if (needUnshreddedObject && schema.variantIdx >= 0 && 
!row.isNullAt(schema.variantIdx)) {
+      unshreddedObject = new Variant(row.getBinary(schema.variantIdx), 
topLevelMetadata)
+      if (unshreddedObject.getType != Type.OBJECT) throw 
QueryExecutionErrors.malformedVariant()
+    }
+    val numFields = fieldInputIndices.length
+    var i = 0
+    while (i < numFields) {
+      val inputIdx = fieldInputIndices(i)
+      if (inputIdx >= 0) {
+        // Shredded field must not be null.
+        if (obj.isNullAt(inputIdx)) throw 
QueryExecutionErrors.malformedVariant()
+        val fieldSchema = schema.objectSchema(inputIdx).schema
+        val fieldInput = obj.getStruct(inputIdx, fieldSchema.numFields)
+        // Only read from the shredded field if it is not missing.
+        if ((fieldSchema.typedIdx >= 0 && 
!fieldInput.isNullAt(fieldSchema.typedIdx)) ||
+          (fieldSchema.variantIdx >= 0 && 
!fieldInput.isNullAt(fieldSchema.variantIdx))) {
+          result.update(i, fieldReaders(i).read(fieldInput, topLevelMetadata))
+        }
+      } else if (unshreddedObject != null) {
+        val fieldName = targetType.fields(i).name
+        val fieldType = targetType.fields(i).dataType
+        val unshreddedField = unshreddedObject.getFieldByKey(fieldName)
+        if (unshreddedField != null) {
+          result.update(i, VariantGet.cast(unshreddedField, fieldType, 
castArgs))
+        }
+      }
+      i += 1
+    }
+    result
+  }
+}
+
+// Read Parquet variant values into a Spark array type.
+private[this] final class ArrayReader(
+    schema: VariantSchema, targetType: ArrayType, castArgs: VariantCastArgs)
+  extends ParquetVariantReader(schema, targetType, castArgs) {
+  private[this] val elementReader = if (schema.arraySchema != null) {
+    ParquetVariantReader(schema.arraySchema, targetType.elementType, castArgs)
+  } else {
+    null
+  }
+
+  override def readFromTyped(row: InternalRow, topLevelMetadata: Array[Byte]): 
Any = {
+    if (schema.arraySchema == null) return invalidCast(row, topLevelMetadata)
+    val elementNumFields = schema.arraySchema.numFields
+    val arr = row.getArray(schema.typedIdx)
+    val size = arr.numElements()
+    val result = new Array[Any](size)
+    var i = 0
+    while (i < size) {
+      // Shredded array element must not be null.
+      if (arr.isNullAt(i)) throw QueryExecutionErrors.malformedVariant()
+      result(i) = elementReader.read(arr.getStruct(i, elementNumFields), 
topLevelMetadata)
+      i += 1
+    }
+    new GenericArrayData(result)
+  }
+}
+
+// Read Parquet variant values into a Spark map type with string key type. The 
input must be object
+// for a valid cast. The resulting map contains shredded fields from object 
`typed_value` and
+// unshredded fields from object `value`.
+// `value` must not contain any shredded field according to the shredding 
spec. Unlike
+// `StructReader`, this requirement is enforced in `MapReader`. If `value` 
does contain a shredded
+// field, throw a MALFORMED_VARIANT error. The purpose is to avoid duplicate 
map keys.
+private[this] final class MapReader(
+    schema: VariantSchema, targetType: MapType, castArgs: VariantCastArgs)
+  extends ParquetVariantReader(schema, targetType, castArgs) {
+  // Readers that convert each shredded field into the map value type.
+  private[this] val valueReaders = if (schema.objectSchema != null) {
+    schema.objectSchema.map { f =>
+      ParquetVariantReader(f.schema, targetType.valueType, castArgs)
+    }
+  } else {
+    null
+  }
+  // `UTF8String` representation of shredded field names. Do the `String -> 
UTF8String` once, so
+  // that `readFromTyped` doesn't need to do it repeatedly.
+  private[this] val shreddedFieldNames = if (schema.objectSchema != null) {
+    schema.objectSchema.map { f => UTF8String.fromString(f.fieldName) }
+  } else {
+    null
+  }
+
+  override def readFromTyped(row: InternalRow, topLevelMetadata: Array[Byte]): 
Any = {
+    if (schema.objectSchema == null) return invalidCast(row, topLevelMetadata)
+    val obj = row.getStruct(schema.typedIdx, schema.objectSchema.length)
+    val numShreddedFields = valueReaders.length
+    var unshreddedObject: Variant = null
+    if (schema.variantIdx >= 0 && !row.isNullAt(schema.variantIdx)) {
+      unshreddedObject = new Variant(row.getBinary(schema.variantIdx), 
topLevelMetadata)
+      if (unshreddedObject.getType != Type.OBJECT) throw 
QueryExecutionErrors.malformedVariant()
+    }
+    val numUnshreddedFields = if (unshreddedObject != null) 
unshreddedObject.objectSize() else 0
+    var keyArray = new Array[UTF8String](numShreddedFields + 
numUnshreddedFields)
+    var valueArray = new Array[Any](numShreddedFields + numUnshreddedFields)
+    var mapLength = 0
+    var i = 0
+    while (i < numShreddedFields) {
+      // Shredded field must not be null.
+      if (obj.isNullAt(i)) throw QueryExecutionErrors.malformedVariant()
+      val fieldSchema = schema.objectSchema(i).schema
+      val fieldInput = obj.getStruct(i, fieldSchema.numFields)
+      // Only add the shredded field to map if it is not missing.
+      if ((fieldSchema.typedIdx >= 0 && 
!fieldInput.isNullAt(fieldSchema.typedIdx)) ||
+        (fieldSchema.variantIdx >= 0 && 
!fieldInput.isNullAt(fieldSchema.variantIdx))) {
+        keyArray(mapLength) = shreddedFieldNames(i)
+        valueArray(mapLength) = valueReaders(i).read(fieldInput, 
topLevelMetadata)
+        mapLength += 1
+      }
+      i += 1
+    }
+    i = 0
+    while (i < numUnshreddedFields) {
+      val field = unshreddedObject.getFieldAtIndex(i)
+      if (schema.objectSchemaMap.containsKey(field.key)) {
+        throw QueryExecutionErrors.malformedVariant()
+      }
+      keyArray(mapLength) = UTF8String.fromString(field.key)
+      valueArray(mapLength) = VariantGet.cast(field.value, 
targetType.valueType, castArgs)
+      mapLength += 1
+      i += 1
+    }
+    // Need to shrink the arrays if there are missing shredded fields.
+    if (mapLength < keyArray.length) {
+      keyArray = keyArray.slice(0, mapLength)
+      valueArray = valueArray.slice(0, mapLength)
+    }
+    ArrayBasedMapData(keyArray, valueArray)
+  }
+}
+
+// Read Parquet variant values into a Spark variant type (the binary format).
+private[this] final class VariantReader(
+    schema: VariantSchema, targetType: DataType, castArgs: VariantCastArgs,
+    // An optional optimization: the user can set it to true if the Parquet 
variant column is
+    // unshredded and the extraction path is empty. We are not required to do 
anything special, but
+    // we can avoid rebuilding variant for optimization purpose.
+    private[this] val isTopLevelUnshredded: Boolean)
+  extends ParquetVariantReader(schema, targetType, castArgs) {
+  override def read(row: InternalRow, topLevelMetadata: Array[Byte]): Any = {
+    if (isTopLevelUnshredded) {
+      if (row.isNullAt(schema.variantIdx)) throw 
QueryExecutionErrors.malformedVariant()
+      return new VariantVal(row.getBinary(schema.variantIdx), topLevelMetadata)
+    }
+    val v = rebuildVariant(row, topLevelMetadata)
+    new VariantVal(v.getValue, v.getMetadata)
+  }
+}
+
+// Read Parquet variant values into a Spark scalar type. When `typed_value` is 
not null but not a
+// scalar, all other target types should return an invalid cast, but only the 
string target type can
+// still build a string from array/object `typed_value`. For scalar 
`typed_value`, it depends on
+// `ScalarCastHelper` to perform the cast.
+// According to the shredding spec, scalar `typed_value` and `value` must not 
be non-null at the
+// same time. The requirement is not enforced in this reader. If they are both 
non-null, no error
+// will occur, and the reader will read from `typed_value`.
+private[this] final class ScalarReader(
+    schema: VariantSchema, targetType: DataType, castArgs: VariantCastArgs)
+  extends ParquetVariantReader(schema, targetType, castArgs) {
+  private[this] val castProject = if (schema.scalarSchema != null) {
+    val scalarType = 
SparkShreddingUtils.scalarSchemaToSparkType(schema.scalarSchema)
+    // Read the cast input from ordinal `schema.typedIdx` in the input row. 
The cast input is never
+    // null, because `readFromTyped` is only called when `typed_value` is not 
null.
+    val input = BoundReference(schema.typedIdx, scalarType, nullable = false)
+    MutableProjection.create(Seq(ScalarCastHelper(input, targetType, 
castArgs)))
+  } else {
+    null
+  }
+
+  override def readFromTyped(row: InternalRow, topLevelMetadata: Array[Byte]): 
Any = {
+    if (castProject == null) {
+      return if (targetType.isInstanceOf[StringType]) {
+        UTF8String.fromString(rebuildVariant(row, 
topLevelMetadata).toJson(castArgs.zoneId))

Review Comment:
   I think there is a misunderstanding. If the target type is string and the 
`typed_value` type is also string, `castProject` will not be null, and the code 
will not take the rebuild path. I also measured the cost of `castProject`, and 
it turns out to be small. For string -> string specifically, if I replace the 
whole `readFromTyped` with `row.getUTF8String(schema.typedIdx)`, the 
performance improvement is <10%.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to