sunchao commented on a change in pull request #34199:
URL: https://github.com/apache/spark/pull/34199#discussion_r724362170
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
##########
@@ -60,40 +58,106 @@ class ParquetToSparkSchemaConverter(
/**
* Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL
[[StructType]].
*/
- def convert(parquetSchema: MessageType): StructType =
convert(parquetSchema.asGroupType())
+ def convert(parquetSchema: MessageType): StructType = {
+ val column = new ColumnIOFactory().getColumnIO(parquetSchema)
+ val converted = convertInternal(column)
+ converted.sparkType.asInstanceOf[StructType]
+ }
- private def convert(parquetSchema: GroupType): StructType = {
- val fields = parquetSchema.getFields.asScala.map { field =>
- field.getRepetition match {
- case OPTIONAL =>
- StructField(field.getName, convertField(field), nullable = true)
+ /**
+ * Convert `parquetSchema` into a [[ParquetType]] which contains its
corresponding Spark
+ * SQL [[StructType]] along with other information such as the maximum
repetition and definition
+ * level of each node, column descriptor for the leaf nodes, etc.
+ *
+ * If `sparkReadSchema` is not empty, when deriving Spark SQL type from a
Parquet field this will
+ * check if the same field also exists in the schema. If so, it will use the
Spark SQL type.
+ * This is necessary since conversion from Parquet to Spark could cause
precision loss. For
+ * instance, the Spark read schema may be smallint/tinyint while Parquet only
supports int.
+ */
+ def convertParquetType(
+ parquetSchema: MessageType,
+ sparkReadSchema: Option[StructType] = None,
+ caseSensitive: Boolean = true): ParquetType = {
+ val column = new ColumnIOFactory().getColumnIO(parquetSchema)
+ convertInternal(column, sparkReadSchema, caseSensitive)
+ }
- case REQUIRED =>
- StructField(field.getName, convertField(field), nullable = false)
+ private def convertInternal(
+ groupColumn: GroupColumnIO,
+ sparkReadSchema: Option[StructType] = None,
+ caseSensitive: Boolean = true): ParquetType = {
+ val converted = (0 until groupColumn.getChildrenCount).map { i =>
+ val field = groupColumn.getChild(i)
+ var fieldReadType = sparkReadSchema.flatMap { schema =>
+ schema.find(f => isSameFieldName(f.name, field.getName,
caseSensitive)).map(_.dataType)
+ }
+
+ // If a field is repeated here, it is contained by neither a `LIST`-
nor a `MAP`-
+ // annotated group (those cases should've been handled in
`convertGroupField`), e.g.:
+ //
+ // message schema {
+ // repeated int32 int_array;
+ // }
+ // or
+ // message schema {
+ // repeated group struct_array {
+ // optional int32 field;
+ // }
+ // }
+ //
+ // the corresponding Spark read type should be an array and we should
pass the element type
+ // to the group or primitive type conversion method.
+ if (field.getType.getRepetition == REPEATED) {
+ fieldReadType = fieldReadType.flatMap {
+ case at: ArrayType => Some(at.elementType)
+ case _ =>
+ throw
QueryCompilationErrors.illegalParquetTypeError(groupColumn.toString)
+ }
+ }
+
+ val convertedField = convertField(field, fieldReadType)
+
+ field.getType.getRepetition match {
+ case OPTIONAL | REQUIRED =>
+ val nullable = field.getType.getRepetition == OPTIONAL
+ (StructField(field.getType.getName, convertedField.sparkType,
nullable = nullable),
+ convertedField)
case REPEATED =>
// A repeated field that is neither contained by a `LIST`- or
`MAP`-annotated group nor
// annotated by `LIST` or `MAP` should be interpreted as a required
list of required
// elements where the element type is the type of the field.
- val arrayType = ArrayType(convertField(field), containsNull = false)
- StructField(field.getName, arrayType, nullable = false)
+ val arrayType = ArrayType(convertedField.sparkType, containsNull =
false)
+ (StructField(field.getType.getName, arrayType, nullable = false),
+ ParquetType(arrayType, None, convertedField.repetitionLevel - 1,
+ convertedField.definitionLevel - 1, required = true,
convertedField.path,
+ Seq(convertedField.copy(required = true))))
}
}
- StructType(fields.toSeq)
+ ParquetType(StructType(converted.map(_._1)), groupColumn,
converted.map(_._2))
}
+ private def isSameFieldName(left: String, right: String, caseSensitive:
Boolean): Boolean =
+ if (caseSensitive) left.equalsIgnoreCase(right)
Review comment:
Good catch! Not sure why it wasn't caught by unit tests, but let me add
one for this.
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
##########
@@ -609,7 +610,9 @@ private[parquet] class ParquetRowConverter(
//
// If the element type does not match the Catalyst type and the
underlying repeated type
// does not belong to the legacy LIST type, then it is case 1;
otherwise, it is case 2.
- val guessedElementType = schemaConverter.convertField(repeatedType)
+ val messageType =
Types.buildMessage().addField(repeatedType).named("foo")
Review comment:
Not really — we need _some_ string to pass to `Types.buildMessage`, which
is why this is here. I'll add a comment explaining that.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]