sunchao commented on a change in pull request #34199: URL: https://github.com/apache/spark/pull/34199#discussion_r741271037
########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumn.scala ########## @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import scala.collection.mutable + +import org.apache.parquet.column.ColumnDescriptor +import org.apache.parquet.io.{ColumnIOUtil, GroupColumnIO, PrimitiveColumnIO} +import org.apache.parquet.schema.Type.Repetition + +import org.apache.spark.sql.types.DataType + +/** + * Rich information for a Parquet column together with its SparkSQL type. + */ +case class ParquetColumn( + sparkType: DataType, + descriptor: Option[ColumnDescriptor], // only set when this is a primitive column + repetitionLevel: Int, + definitionLevel: Int, + required: Boolean, + path: Seq[String], + children: Seq[ParquetColumn]) { + + def isPrimitive: Boolean = descriptor.nonEmpty + + /** + * Returns all the leaves (i.e., primitive columns) of this, in depth-first order. 
+ */ + def leaves: Seq[ParquetColumn] = { + val buffer = mutable.ArrayBuffer[ParquetColumn]() + leaves0(buffer) + buffer.toSeq + } + + private def leaves0(buffer: mutable.ArrayBuffer[ParquetColumn]): Unit = { + children.foreach(_.leaves0(buffer)) + } +} Review comment: Oops my bad. I don't think we need this method yet so I'll remove it here. ########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala ########## @@ -43,57 +41,128 @@ import org.apache.spark.sql.types._ * [[StringType]] fields. * @param assumeInt96IsTimestamp Whether unannotated INT96 fields should be assumed to be Spark SQL * [[TimestampType]] fields. + * @param caseSensitive Whether use case sensitive analysis when comparing Spark catalyst read + * schema with Parquet schema */ class ParquetToSparkSchemaConverter( assumeBinaryIsString: Boolean = SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get, - assumeInt96IsTimestamp: Boolean = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get) { + assumeInt96IsTimestamp: Boolean = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get, + caseSensitive: Boolean = SQLConf.CASE_SENSITIVE.defaultValue.get) { def this(conf: SQLConf) = this( assumeBinaryIsString = conf.isParquetBinaryAsString, - assumeInt96IsTimestamp = conf.isParquetINT96AsTimestamp) + assumeInt96IsTimestamp = conf.isParquetINT96AsTimestamp, + caseSensitive = conf.caseSensitiveAnalysis) def this(conf: Configuration) = this( assumeBinaryIsString = conf.get(SQLConf.PARQUET_BINARY_AS_STRING.key).toBoolean, - assumeInt96IsTimestamp = conf.get(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key).toBoolean) + assumeInt96IsTimestamp = conf.get(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key).toBoolean, + caseSensitive = conf.get(SQLConf.CASE_SENSITIVE.key).toBoolean) /** * Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]]. 
*/ - def convert(parquetSchema: MessageType): StructType = convert(parquetSchema.asGroupType()) + def convert(parquetSchema: MessageType): StructType = { + val column = new ColumnIOFactory().getColumnIO(parquetSchema) + val converted = convertInternal(column) + converted.sparkType.asInstanceOf[StructType] + } - private def convert(parquetSchema: GroupType): StructType = { - val fields = parquetSchema.getFields.asScala.map { field => - field.getRepetition match { - case OPTIONAL => - StructField(field.getName, convertField(field), nullable = true) + /** + * Convert `parquetSchema` into a [[ParquetColumn]] which contains its corresponding Spark + * SQL [[StructType]] along with other information such as the maximum repetition and definition + * level of each node, column descriptor for the leave nodes, etc. + * + * If `sparkReadSchema` is not empty, when deriving Spark SQL type from a Parquet field this will + * check if the same field also exists in the schema. If so, it will use the Spark SQL type. + * This is necessary since conversion from Parquet to Spark could cause precision loss. For + * instance, Spark read schema is smallint/tinyint but Parquet only support int. 
+ */ + def convertParquetColumn( + parquetSchema: MessageType, + sparkReadSchema: Option[StructType] = None): ParquetColumn = { + val column = new ColumnIOFactory().getColumnIO(parquetSchema) + convertInternal(column, sparkReadSchema) + } - case REQUIRED => - StructField(field.getName, convertField(field), nullable = false) + private def convertInternal( + groupColumn: GroupColumnIO, + sparkReadSchema: Option[StructType] = None): ParquetColumn = { + val converted = (0 until groupColumn.getChildrenCount).map { i => + val field = groupColumn.getChild(i) + val fieldFromReadSchema = sparkReadSchema.flatMap { schema => + schema.find(f => isSameFieldName(f.name, field.getName, caseSensitive)) + } + var fieldReadType = fieldFromReadSchema.map(_.dataType) + + // If a field is repeated here then it is neither contained by a `LIST` nor `MAP` + // annotated group (these should've been handled in `convertGroupField`), e.g.: + // + // message schema { + // repeated int32 int_array; + // } + // or + // message schema { + // repeated group struct_array { + // optional int32 field; + // } + // } + // + // the corresponding Spark read type should be an array and we should pass the element type + // to the group or primitive type conversion method. 
+ if (field.getType.getRepetition == REPEATED) { + fieldReadType = fieldReadType.flatMap { + case at: ArrayType => Some(at.elementType) + case _ => + throw QueryCompilationErrors.illegalParquetTypeError(groupColumn.toString) + } + } + + val convertedField = convertField(field, fieldReadType) + val fieldName = fieldFromReadSchema.map(_.name).getOrElse(field.getType.getName) + + field.getType.getRepetition match { + case OPTIONAL | REQUIRED => + val nullable = field.getType.getRepetition == OPTIONAL + (StructField(fieldName, convertedField.sparkType, nullable = nullable), + convertedField) case REPEATED => // A repeated field that is neither contained by a `LIST`- or `MAP`-annotated group nor // annotated by `LIST` or `MAP` should be interpreted as a required list of required // elements where the element type is the type of the field. - val arrayType = ArrayType(convertField(field), containsNull = false) - StructField(field.getName, arrayType, nullable = false) + val arrayType = ArrayType(convertedField.sparkType, containsNull = false) + (StructField(fieldName, arrayType, nullable = false), + ParquetColumn(arrayType, None, convertedField.repetitionLevel - 1, + convertedField.definitionLevel - 1, required = true, convertedField.path, + Seq(convertedField.copy(required = true)))) } } - StructType(fields.toSeq) + ParquetColumn(StructType(converted.map(_._1)), groupColumn, converted.map(_._2)) } + private def isSameFieldName(left: String, right: String, caseSensitive: Boolean): Boolean = + if (!caseSensitive) left.equalsIgnoreCase(right) + else left == right + Review comment: It doesn't seem easier since we also need to initialize `ParquetToSparkSchemaConverter` with `Configuration` which doesn't have a `resolver` available, so we still need to write a similar method I think. 
########## File path: sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java ########## @@ -152,6 +152,7 @@ protected void initialize(String path, List<String> columns) throws IOException Configuration config = new Configuration(); config.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING().key() , false); config.setBoolean(SQLConf.PARQUET_INT96_AS_TIMESTAMP().key(), false); + config.setBoolean(SQLConf.CASE_SENSITIVE().key(), false); Review comment: This path is only used for testing, so I followed the above lines to just set the default value for this config. ########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala ########## @@ -43,57 +41,128 @@ import org.apache.spark.sql.types._ * [[StringType]] fields. * @param assumeInt96IsTimestamp Whether unannotated INT96 fields should be assumed to be Spark SQL * [[TimestampType]] fields. + * @param caseSensitive Whether use case sensitive analysis when comparing Spark catalyst read + * schema with Parquet schema */ class ParquetToSparkSchemaConverter( assumeBinaryIsString: Boolean = SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get, - assumeInt96IsTimestamp: Boolean = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get) { + assumeInt96IsTimestamp: Boolean = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get, + caseSensitive: Boolean = SQLConf.CASE_SENSITIVE.defaultValue.get) { def this(conf: SQLConf) = this( assumeBinaryIsString = conf.isParquetBinaryAsString, - assumeInt96IsTimestamp = conf.isParquetINT96AsTimestamp) + assumeInt96IsTimestamp = conf.isParquetINT96AsTimestamp, + caseSensitive = conf.caseSensitiveAnalysis) def this(conf: Configuration) = this( assumeBinaryIsString = conf.get(SQLConf.PARQUET_BINARY_AS_STRING.key).toBoolean, - assumeInt96IsTimestamp = conf.get(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key).toBoolean) + assumeInt96IsTimestamp = 
conf.get(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key).toBoolean, + caseSensitive = conf.get(SQLConf.CASE_SENSITIVE.key).toBoolean) /** * Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]]. */ - def convert(parquetSchema: MessageType): StructType = convert(parquetSchema.asGroupType()) + def convert(parquetSchema: MessageType): StructType = { + val column = new ColumnIOFactory().getColumnIO(parquetSchema) + val converted = convertInternal(column) + converted.sparkType.asInstanceOf[StructType] + } - private def convert(parquetSchema: GroupType): StructType = { - val fields = parquetSchema.getFields.asScala.map { field => - field.getRepetition match { - case OPTIONAL => - StructField(field.getName, convertField(field), nullable = true) + /** + * Convert `parquetSchema` into a [[ParquetColumn]] which contains its corresponding Spark + * SQL [[StructType]] along with other information such as the maximum repetition and definition + * level of each node, column descriptor for the leave nodes, etc. + * + * If `sparkReadSchema` is not empty, when deriving Spark SQL type from a Parquet field this will + * check if the same field also exists in the schema. If so, it will use the Spark SQL type. + * This is necessary since conversion from Parquet to Spark could cause precision loss. For + * instance, Spark read schema is smallint/tinyint but Parquet only support int. 
+ */ + def convertParquetColumn( + parquetSchema: MessageType, + sparkReadSchema: Option[StructType] = None): ParquetColumn = { + val column = new ColumnIOFactory().getColumnIO(parquetSchema) + convertInternal(column, sparkReadSchema) + } - case REQUIRED => - StructField(field.getName, convertField(field), nullable = false) + private def convertInternal( + groupColumn: GroupColumnIO, + sparkReadSchema: Option[StructType] = None): ParquetColumn = { + val converted = (0 until groupColumn.getChildrenCount).map { i => + val field = groupColumn.getChild(i) + val fieldFromReadSchema = sparkReadSchema.flatMap { schema => + schema.find(f => isSameFieldName(f.name, field.getName, caseSensitive)) + } + var fieldReadType = fieldFromReadSchema.map(_.dataType) + + // If a field is repeated here then it is neither contained by a `LIST` nor `MAP` + // annotated group (these should've been handled in `convertGroupField`), e.g.: + // + // message schema { + // repeated int32 int_array; + // } + // or + // message schema { + // repeated group struct_array { + // optional int32 field; + // } + // } + // + // the corresponding Spark read type should be an array and we should pass the element type + // to the group or primitive type conversion method. 
+ if (field.getType.getRepetition == REPEATED) { + fieldReadType = fieldReadType.flatMap { + case at: ArrayType => Some(at.elementType) + case _ => + throw QueryCompilationErrors.illegalParquetTypeError(groupColumn.toString) + } + } + + val convertedField = convertField(field, fieldReadType) + val fieldName = fieldFromReadSchema.map(_.name).getOrElse(field.getType.getName) + + field.getType.getRepetition match { + case OPTIONAL | REQUIRED => + val nullable = field.getType.getRepetition == OPTIONAL + (StructField(fieldName, convertedField.sparkType, nullable = nullable), + convertedField) case REPEATED => // A repeated field that is neither contained by a `LIST`- or `MAP`-annotated group nor // annotated by `LIST` or `MAP` should be interpreted as a required list of required // elements where the element type is the type of the field. - val arrayType = ArrayType(convertField(field), containsNull = false) - StructField(field.getName, arrayType, nullable = false) + val arrayType = ArrayType(convertedField.sparkType, containsNull = false) + (StructField(fieldName, arrayType, nullable = false), + ParquetColumn(arrayType, None, convertedField.repetitionLevel - 1, + convertedField.definitionLevel - 1, required = true, convertedField.path, + Seq(convertedField.copy(required = true)))) } } - StructType(fields.toSeq) + ParquetColumn(StructType(converted.map(_._1)), groupColumn, converted.map(_._2)) } + private def isSameFieldName(left: String, right: String, caseSensitive: Boolean): Boolean = + if (!caseSensitive) left.equalsIgnoreCase(right) + else left == right + /** * Converts a Parquet [[Type]] to a Spark SQL [[DataType]]. Review comment: Will update the comments here. 
########## File path: sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala ########## @@ -114,7 +130,66 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSparkSession { sqlSchema, parquetSchema, binaryAsString, - int96AsTimestamp) + int96AsTimestamp, + expectedParquetColumn = expectedParquetColumn) + } + + protected def compareParquetColumn(actual: ParquetColumn, expected: ParquetColumn): Unit = { + assert(actual.sparkType == expected.sparkType, "sparkType mismatch: " + + s"actual = ${actual.sparkType}, expected = ${expected.sparkType}") + assert(actual.descriptor === expected.descriptor, "column descriptor mismatch: " + + s"actual = ${actual.descriptor}, expected = ${expected.descriptor})") + // Parquet ColumnDescriptor equals only compare path so we'll need to compare other fields Review comment: I think `path` equality is already compared above via `actual.descriptor === expected.descriptor`? Fixed the comments. ########## File path: sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala ########## @@ -114,7 +130,66 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSparkSession { sqlSchema, parquetSchema, binaryAsString, - int96AsTimestamp) + int96AsTimestamp, + expectedParquetColumn = expectedParquetColumn) + } + + protected def compareParquetColumn(actual: ParquetColumn, expected: ParquetColumn): Unit = { + assert(actual.sparkType == expected.sparkType, "sparkType mismatch: " + + s"actual = ${actual.sparkType}, expected = ${expected.sparkType}") + assert(actual.descriptor === expected.descriptor, "column descriptor mismatch: " + + s"actual = ${actual.descriptor}, expected = ${expected.descriptor})") + // Parquet ColumnDescriptor equals only compare path so we'll need to compare other fields Review comment: I think `path` equality is already compared above via `actual.descriptor === expected.descriptor`? Fixed the comments. 
########## File path: sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala ########## @@ -114,7 +130,66 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSparkSession { sqlSchema, parquetSchema, binaryAsString, - int96AsTimestamp) + int96AsTimestamp, + expectedParquetColumn = expectedParquetColumn) + } + + protected def compareParquetColumn(actual: ParquetColumn, expected: ParquetColumn): Unit = { + assert(actual.sparkType == expected.sparkType, "sparkType mismatch: " + + s"actual = ${actual.sparkType}, expected = ${expected.sparkType}") + assert(actual.descriptor === expected.descriptor, "column descriptor mismatch: " + + s"actual = ${actual.descriptor}, expected = ${expected.descriptor})") + // Parquet ColumnDescriptor equals only compare path so we'll need to compare other fields + // explicitly here + if (actual.descriptor.isDefined && expected.descriptor.isDefined) { + val actualDesc = actual.descriptor.get + val expectedDesc = expected.descriptor.get + assert(actualDesc.getMaxRepetitionLevel == expectedDesc.getMaxRepetitionLevel) Review comment: It looks like a bug to me, although I don't know the history (can't find it either). ########## File path: sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala ########## @@ -902,6 +1890,181 @@ class ParquetSchemaSuite extends ParquetSchemaTest { """.stripMargin, writeLegacyParquetFormat = true) + testParquetToCatalyst( + "SPARK-36935: test case insensitive when converting Parquet schema", + StructType(Seq(StructField("F1", ShortType))), + """message root { + | optional int32 f1; + |} + |""".stripMargin, Review comment: Eh I just followed the previous test case on this.. 
########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala ########## @@ -609,7 +610,13 @@ private[parquet] class ParquetRowConverter( // // If the element type does not match the Catalyst type and the underlying repeated type // does not belong to the legacy LIST type, then it is case 1; otherwise, it is case 2. - val guessedElementType = schemaConverter.convertField(repeatedType) + // + // Since `convertField` method requires a Parquet `ColumnIO` as input, here we first create + // a dummy message type which wraps the given repeated type, and then convert it to the + // `ColumnIO` using Parquet API. + val messageType = Types.buildMessage().addField(repeatedType).named("foo") Review comment: Hm not sure if it's necessary since the comments above already explained the reason and this is only the place where the dummy name is used. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
