sunchao commented on a change in pull request #32777:
URL: https://github.com/apache/spark/pull/32777#discussion_r648483520
##########
File path: sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
##########
@@ -0,0 +1,979 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.parquet;
+
+import org.apache.parquet.column.ColumnDescriptor;
+import org.apache.parquet.column.Dictionary;
+import org.apache.parquet.io.api.Binary;
+import org.apache.parquet.schema.LogicalTypeAnnotation;
+import org.apache.parquet.schema.LogicalTypeAnnotation.IntLogicalTypeAnnotation;
+import org.apache.parquet.schema.LogicalTypeAnnotation.DecimalLogicalTypeAnnotation;
+import org.apache.parquet.schema.LogicalTypeAnnotation.TimestampLogicalTypeAnnotation;
+import org.apache.parquet.schema.PrimitiveType;
+
+import org.apache.spark.sql.catalyst.util.DateTimeUtils;
+import org.apache.spark.sql.catalyst.util.RebaseDateTime;
+import org.apache.spark.sql.execution.datasources.DataSourceUtils;
+import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException;
+import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.DecimalType;
+
+import java.math.BigInteger;
+import java.time.ZoneId;
+import java.time.ZoneOffset;
+import java.util.Arrays;
+
+public class ParquetVectorUpdaterFactory {
+  private static final ZoneId UTC = ZoneOffset.UTC;
+
+  private final LogicalTypeAnnotation logicalTypeAnnotation;
+  // The timezone conversion to apply to int96 timestamps. Null if no conversion.
+  private final ZoneId convertTz;
+  private final String datetimeRebaseMode;
+  private final String int96RebaseMode;
+
+  ParquetVectorUpdaterFactory(
+      LogicalTypeAnnotation logicalTypeAnnotation,
+      ZoneId convertTz,
+      String datetimeRebaseMode,
+      String int96RebaseMode) {
+    this.logicalTypeAnnotation = logicalTypeAnnotation;
+    this.convertTz = convertTz;
+    this.datetimeRebaseMode = datetimeRebaseMode;
+    this.int96RebaseMode = int96RebaseMode;
+  }
+
+  public ParquetVectorUpdater getUpdater(ColumnDescriptor descriptor, DataType sparkType) {
+    PrimitiveType.PrimitiveTypeName typeName = descriptor.getPrimitiveType().getPrimitiveTypeName();
+
+    switch (typeName) {
+      case BOOLEAN:
+        if (sparkType == DataTypes.BooleanType) {
+          return new BooleanUpdater();
+        }
+        break;
+      case INT32:
+        if (sparkType == DataTypes.IntegerType || canReadAsIntDecimal(descriptor, sparkType)) {
+          return new IntegerUpdater();
+        } else if (sparkType == DataTypes.LongType && isUnsignedIntTypeMatched(32)) {
+          // In `ParquetToSparkSchemaConverter`, we map parquet UINT32 to our LongType.
+          // For unsigned int32, it stores as plain signed int32 in Parquet when dictionary
+          // fallbacks. We read them as long values.
+          return new UnsignedIntegerUpdater();
+        } else if (sparkType == DataTypes.ByteType) {
+          return new ByteUpdater();
+        } else if (sparkType == DataTypes.ShortType) {
+          return new ShortUpdater();
+        } else if (sparkType == DataTypes.DateType) {
+          if ("CORRECTED".equals(datetimeRebaseMode)) {
+            return new IntegerUpdater();
+          } else {
+            boolean failIfRebase = "EXCEPTION".equals(datetimeRebaseMode);
+            return new IntegerWithRebaseUpdater(failIfRebase);
+          }
+        }
+        break;
+      case INT64:
+        // This is where we implement support for the valid type conversions.
+        if (sparkType == DataTypes.LongType || canReadAsLongDecimal(descriptor, sparkType)) {
+          return new LongUpdater(DecimalType.is32BitDecimalType(sparkType));
+        } else if (isUnsignedIntTypeMatched(64)) {
+          // In `ParquetToSparkSchemaConverter`, we map parquet UINT64 to our Decimal(20, 0).
+          // For unsigned int64, it stores as plain signed int64 in Parquet when dictionary
+          // fallbacks. We read them as decimal values.
+          return new UnsignedLongUpdater();
+        } else if (isTimestampTypeMatched(LogicalTypeAnnotation.TimeUnit.MICROS)) {
+          if ("CORRECTED".equals(datetimeRebaseMode)) {
+            return new LongUpdater(false);
+          } else {
+            boolean failIfRebase = "EXCEPTION".equals(datetimeRebaseMode);
+            return new LongWithRebaseUpdater(failIfRebase);
+          }
+        } else if (isTimestampTypeMatched(LogicalTypeAnnotation.TimeUnit.MILLIS)) {
+          if ("CORRECTED".equals(datetimeRebaseMode)) {
+            return new LongAsMicrosUpdater();
+          } else {
+            final boolean failIfRebase = "EXCEPTION".equals(datetimeRebaseMode);
+            return new LongAsMicrosRebaseUpdater(failIfRebase);
+          }
+        }
+        break;
+      case FLOAT:
+        if (sparkType == DataTypes.FloatType) {
+          return new FloatUpdater();
+        }
+        break;
+      case DOUBLE:
+        if (sparkType == DataTypes.DoubleType) {
+          return new DoubleUpdater();
+        }
+        break;
+      case INT96:
+        if (sparkType == DataTypes.TimestampType) {
+          final boolean failIfRebase = "EXCEPTION".equals(int96RebaseMode);
+          if (!shouldConvertTimestamps()) {
+            if ("CORRECTED".equals(int96RebaseMode)) {
+              return new BinaryToSQLTimestampUpdater();
+            } else {
+              return new BinaryToSQLTimestampRebaseUpdater(failIfRebase);
+            }
+          } else {
+            if ("CORRECTED".equals(int96RebaseMode)) {
+              return new BinaryToSQLTimestampConvertTzUpdater(convertTz);
+            } else {
+              return new BinaryToSQLTimestampConvertTzRebaseUpdater(failIfRebase, convertTz);
+            }
+          }
+        }
+        break;
+      case BINARY:
+        if (sparkType == DataTypes.StringType || sparkType == DataTypes.BinaryType ||
+          canReadAsBinaryDecimal(descriptor, sparkType)) {
+          return new BinaryUpdater();
+        }
+        break;
+      case FIXED_LEN_BYTE_ARRAY:
+        int arrayLen = descriptor.getPrimitiveType().getTypeLength();
+        if (canReadAsIntDecimal(descriptor, sparkType)) {
+          return new FixedLenByteArrayAsIntUpdater(arrayLen);
+        } else if (canReadAsLongDecimal(descriptor, sparkType)) {
+          return new FixedLenByteArrayAsLongUpdater(arrayLen);
+        } else if (canReadAsBinaryDecimal(descriptor, sparkType)) {
+          return new FixedLenByteArrayUpdater(arrayLen);
+        }
+        break;
+      default:
+        break;
+    }
+
+    // If we get here, it means the combination of Spark and Parquet type is invalid or not
+    // supported.
+    throw constructConvertNotSupportedException(descriptor, sparkType);
+  }
+
+  boolean isTimestampTypeMatched(LogicalTypeAnnotation.TimeUnit unit) {
+    return logicalTypeAnnotation instanceof TimestampLogicalTypeAnnotation &&
+      ((TimestampLogicalTypeAnnotation) logicalTypeAnnotation).getUnit() == unit;
+  }
+
+  boolean isUnsignedIntTypeMatched(int bitWidth) {
+    return logicalTypeAnnotation instanceof IntLogicalTypeAnnotation &&
+      !((IntLogicalTypeAnnotation) logicalTypeAnnotation).isSigned() &&
+      ((IntLogicalTypeAnnotation) logicalTypeAnnotation).getBitWidth() == bitWidth;
+  }
+
+  private static class BooleanUpdater implements ParquetVectorUpdater {
+    @Override
+    public void updateBatch(
+        int total,
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      valuesReader.readBooleans(total, values, offset);
+    }
+
+    @Override
+    public void update(
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      values.putBoolean(offset, valuesReader.readBoolean());
+    }
+
+    @Override
+    public void decodeSingleDictionaryId(
+        int offset,
+        WritableColumnVector values,
+        WritableColumnVector dictionaryIds,
+        Dictionary dictionary) {
+      throw new UnsupportedOperationException("Boolean is not supported");
+    }
+  }
+
+  private static class IntegerUpdater implements ParquetVectorUpdater {
+    @Override
+    public void updateBatch(
+        int total,
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      valuesReader.readIntegers(total, values, offset);
+    }
+
+    @Override
+    public void update(
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      values.putInt(offset, valuesReader.readInteger());
+    }
+
+    @Override
+    public void decodeSingleDictionaryId(
+        int offset,
+        WritableColumnVector values,
+        WritableColumnVector dictionaryIds,
+        Dictionary dictionary) {
+      values.putInt(offset, dictionary.decodeToInt(dictionaryIds.getDictId(offset)));
+    }
+  }
+
+  private static class UnsignedIntegerUpdater implements ParquetVectorUpdater {
+    @Override
+    public void updateBatch(
+        int total,
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      valuesReader.readUnsignedIntegers(total, values, offset);
+    }
+
+    @Override
+    public void update(
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      values.putLong(offset, Integer.toUnsignedLong(valuesReader.readInteger()));
+    }
+
+    @Override
+    public void decodeSingleDictionaryId(
+        int offset,
+        WritableColumnVector values,
+        WritableColumnVector dictionaryIds,
+        Dictionary dictionary) {
+      values.putLong(offset, Integer.toUnsignedLong(
+        dictionary.decodeToInt(dictionaryIds.getDictId(offset))));
+    }
+  }
+
+  private static class ByteUpdater implements ParquetVectorUpdater {
+    @Override
+    public void updateBatch(
+        int total,
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      valuesReader.readBytes(total, values, offset);
+    }
+
+    @Override
+    public void update(
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      values.putByte(offset, valuesReader.readByte());
+    }
+
+    @Override
+    public void decodeSingleDictionaryId(
+        int offset,
+        WritableColumnVector values,
+        WritableColumnVector dictionaryIds,
+        Dictionary dictionary) {
+      values.putByte(offset, (byte) dictionary.decodeToInt(dictionaryIds.getDictId(offset)));
+    }
+  }
+
+  private static class ShortUpdater implements ParquetVectorUpdater {
+    @Override
+    public void updateBatch(
+        int total,
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      valuesReader.readShorts(total, values, offset);
+    }
+
+    @Override
+    public void update(
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      values.putShort(offset, valuesReader.readShort());
+    }
+
+    @Override
+    public void decodeSingleDictionaryId(
+        int offset,
+        WritableColumnVector values,
+        WritableColumnVector dictionaryIds,
+        Dictionary dictionary) {
+      values.putShort(offset, (short) dictionary.decodeToInt(dictionaryIds.getDictId(offset)));
+    }
+  }
+
+  private static class IntegerWithRebaseUpdater implements ParquetVectorUpdater {
+    private final boolean failIfRebase;
+
+    IntegerWithRebaseUpdater(boolean failIfRebase) {
+      this.failIfRebase = failIfRebase;
+    }
+
+    @Override
+    public void updateBatch(
+        int total,
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      valuesReader.readIntegersWithRebase(total, values, offset, failIfRebase);
+    }
+
+    @Override
+    public void update(
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      int julianDays = valuesReader.readInteger();
+      values.putInt(offset, rebaseDays(julianDays, failIfRebase));
+    }
+
+    @Override
+    public void decodeSingleDictionaryId(
+        int offset,
+        WritableColumnVector values,
+        WritableColumnVector dictionaryIds,
+        Dictionary dictionary) {
+      int julianDays = dictionary.decodeToInt(dictionaryIds.getDictId(offset));
+      values.putInt(offset, rebaseDays(julianDays, failIfRebase));
+    }
+  }
+
+  private static class LongUpdater implements ParquetVectorUpdater {
+    private final boolean downCastLongToInt;
+
+    LongUpdater(boolean downCastLongToInt) {
+      this.downCastLongToInt = downCastLongToInt;
+    }
+
+    @Override
+    public void updateBatch(
+        int total,
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      if (downCastLongToInt) {
+        for (int i = 0; i < total; ++i) {
+          values.putInt(offset + i, (int) valuesReader.readLong());
+        }
+      } else {
+        valuesReader.readLongs(total, values, offset);
+      }
+    }
+
+    @Override
+    public void update(
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      if (downCastLongToInt) {

Review comment:
       Yea I thought about it too. I think it's a good idea. Will do.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
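The suggestion being agreed to above is not quoted in this thread. Assuming it was to pull the `downCastLongToInt` branch of `LongUpdater` out into a dedicated updater, a minimal sketch could look like the class below; the name `DowncastLongUpdater` is hypothetical and not part of the diff, while the interface methods and helper calls mirror the updaters shown above.

// Hypothetical sketch, not part of the PR diff: a separate updater for the
// long-to-int downcast path (e.g. a 32-bit decimal stored as Parquet INT64).
private static class DowncastLongUpdater implements ParquetVectorUpdater {
  @Override
  public void updateBatch(
      int total,
      int offset,
      WritableColumnVector values,
      VectorizedValuesReader valuesReader) {
    // Read each INT64 value and narrow it to int before storing it in the vector.
    for (int i = 0; i < total; ++i) {
      values.putInt(offset + i, (int) valuesReader.readLong());
    }
  }

  @Override
  public void update(
      int offset,
      WritableColumnVector values,
      VectorizedValuesReader valuesReader) {
    values.putInt(offset, (int) valuesReader.readLong());
  }

  @Override
  public void decodeSingleDictionaryId(
      int offset,
      WritableColumnVector values,
      WritableColumnVector dictionaryIds,
      Dictionary dictionary) {
    // Decode the dictionary entry as a long, then narrow it the same way.
    values.putInt(offset, (int) dictionary.decodeToLong(dictionaryIds.getDictId(offset)));
  }
}

With a split like this, `getUpdater` would return the downcast variant for the 32-bit-decimal case and a plain `LongUpdater` otherwise, so neither updater needs a per-value branch.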
