This is an automated email from the ASF dual-hosted git repository.

xiangfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 226087f29c Fix the bug of reading decimal value stored in int32 or int64 (#11840)
226087f29c is described below

commit 226087f29c4badf4aff066acedea2a3b8f91c811
Author: Xiang Fu <[email protected]>
AuthorDate: Fri Oct 20 22:20:39 2023 -0700

    Fix the bug of reading decimal value stored in int32 or int64 (#11840)
---
 .../parquet/ParquetNativeRecordExtractor.java      |  53 ++++----
 .../parquet/ParquetReaderLogicalTypesTest.java     | 140 +++++++++++++++++++++
 2 files changed, 171 insertions(+), 22 deletions(-)

diff --git 
a/pinot-plugins/pinot-input-format/pinot-parquet/src/main/java/org/apache/pinot/plugin/inputformat/parquet/ParquetNativeRecordExtractor.java
 
b/pinot-plugins/pinot-input-format/pinot-parquet/src/main/java/org/apache/pinot/plugin/inputformat/parquet/ParquetNativeRecordExtractor.java
index 678ea1175d..d7a1dd0e22 100644
--- 
a/pinot-plugins/pinot-input-format/pinot-parquet/src/main/java/org/apache/pinot/plugin/inputformat/parquet/ParquetNativeRecordExtractor.java
+++ 
b/pinot-plugins/pinot-input-format/pinot-parquet/src/main/java/org/apache/pinot/plugin/inputformat/parquet/ParquetNativeRecordExtractor.java
@@ -33,9 +33,8 @@ import java.util.Set;
 import javax.annotation.Nullable;
 import org.apache.parquet.example.data.Group;
 import org.apache.parquet.io.api.Binary;
-import org.apache.parquet.schema.DecimalMetadata;
 import org.apache.parquet.schema.GroupType;
-import org.apache.parquet.schema.OriginalType;
+import org.apache.parquet.schema.LogicalTypeAnnotation;
 import org.apache.parquet.schema.PrimitiveType;
 import org.apache.parquet.schema.Type;
 import org.apache.pinot.spi.data.readers.BaseRecordExtractor;
@@ -143,14 +142,26 @@ public class ParquetNativeRecordExtractor extends 
BaseRecordExtractor<Group> {
   }
 
   private Object extractValue(Group from, int fieldIndex, Type fieldType, int 
index) {
-    OriginalType originalType = fieldType.getOriginalType();
+    LogicalTypeAnnotation logicalTypeAnnotation = 
fieldType.getLogicalTypeAnnotation();
     if (fieldType.isPrimitive()) {
       PrimitiveType.PrimitiveTypeName primitiveTypeName = 
fieldType.asPrimitiveType().getPrimitiveTypeName();
       switch (primitiveTypeName) {
         case INT32:
-          return from.getInteger(fieldIndex, index);
+          int intValue = from.getInteger(fieldIndex, index);
+          if (logicalTypeAnnotation instanceof 
LogicalTypeAnnotation.DecimalLogicalTypeAnnotation) {
+            LogicalTypeAnnotation.DecimalLogicalTypeAnnotation 
decimalLogicalTypeAnnotation =
+                (LogicalTypeAnnotation.DecimalLogicalTypeAnnotation) 
logicalTypeAnnotation;
+            return BigDecimal.valueOf(intValue, 
decimalLogicalTypeAnnotation.getScale());
+          }
+          return intValue;
         case INT64:
-          return from.getLong(fieldIndex, index);
+          long longValue = from.getLong(fieldIndex, index);
+          if (logicalTypeAnnotation instanceof 
LogicalTypeAnnotation.DecimalLogicalTypeAnnotation) {
+            LogicalTypeAnnotation.DecimalLogicalTypeAnnotation 
decimalLogicalTypeAnnotation =
+                (LogicalTypeAnnotation.DecimalLogicalTypeAnnotation) 
logicalTypeAnnotation;
+            return BigDecimal.valueOf(longValue, 
decimalLogicalTypeAnnotation.getScale());
+          }
+          return longValue;
         case FLOAT:
           return from.getFloat(fieldIndex, index);
         case DOUBLE:
@@ -160,34 +171,32 @@ public class ParquetNativeRecordExtractor extends 
BaseRecordExtractor<Group> {
         case INT96:
           Binary int96 = from.getInt96(fieldIndex, index);
           ByteBuffer buf = 
ByteBuffer.wrap(int96.getBytes()).order(ByteOrder.LITTLE_ENDIAN);
-          long dateTime = (buf.getInt(8) - JULIAN_DAY_NUMBER_FOR_UNIX_EPOCH) * 
DateTimeConstants.MILLIS_PER_DAY
+          return (buf.getInt(8) - JULIAN_DAY_NUMBER_FOR_UNIX_EPOCH) * 
DateTimeConstants.MILLIS_PER_DAY
               + buf.getLong(0) / NANOS_PER_MILLISECOND;
-          return dateTime;
         case BINARY:
         case FIXED_LEN_BYTE_ARRAY:
-          if (originalType != null) {
-            switch (originalType) {
-              case UTF8:
-              case ENUM:
-                return from.getValueToString(fieldIndex, index);
-              case DECIMAL:
-                DecimalMetadata decimalMetadata = 
fieldType.asPrimitiveType().getDecimalMetadata();
-                return binaryToDecimal(from.getBinary(fieldIndex, index), 
decimalMetadata.getPrecision(),
-                    decimalMetadata.getScale());
-              default:
-                break;
-            }
+          if (logicalTypeAnnotation instanceof 
LogicalTypeAnnotation.DecimalLogicalTypeAnnotation) {
+            LogicalTypeAnnotation.DecimalLogicalTypeAnnotation 
decimalLogicalTypeAnnotation =
+                (LogicalTypeAnnotation.DecimalLogicalTypeAnnotation) 
logicalTypeAnnotation;
+            return binaryToDecimal(from.getBinary(fieldIndex, index), 
decimalLogicalTypeAnnotation.getPrecision(),
+                decimalLogicalTypeAnnotation.getScale());
+          }
+          if (logicalTypeAnnotation instanceof 
LogicalTypeAnnotation.StringLogicalTypeAnnotation) {
+            return from.getValueToString(fieldIndex, index);
+          }
+          if (logicalTypeAnnotation instanceof 
LogicalTypeAnnotation.EnumLogicalTypeAnnotation) {
+            return from.getValueToString(fieldIndex, index);
           }
           return from.getBinary(fieldIndex, index).getBytes();
         default:
           throw new IllegalArgumentException(
-              String.format("Unsupported field type: %s, primitive type: %s, 
original type: %s", fieldType,
-                  primitiveTypeName, originalType));
+              String.format("Unsupported field type: %s, primitive type: %s, 
logical type: %s", fieldType,
+                  primitiveTypeName, logicalTypeAnnotation));
       }
     } else if ((fieldType.isRepetition(Type.Repetition.OPTIONAL)) || 
(fieldType.isRepetition(Type.Repetition.REQUIRED))
         || (fieldType.isRepetition(Type.Repetition.REPEATED))) {
       Group group = from.getGroup(fieldIndex, index);
-      if (originalType == OriginalType.LIST) {
+      if (logicalTypeAnnotation instanceof 
LogicalTypeAnnotation.ListLogicalTypeAnnotation) {
         return extractList(group);
       }
       return extractMap(group);
diff --git 
a/pinot-plugins/pinot-input-format/pinot-parquet/src/test/java/org/apache/pinot/plugin/inputformat/parquet/ParquetReaderLogicalTypesTest.java
 
b/pinot-plugins/pinot-input-format/pinot-parquet/src/test/java/org/apache/pinot/plugin/inputformat/parquet/ParquetReaderLogicalTypesTest.java
new file mode 100644
index 0000000000..4efd9689b1
--- /dev/null
+++ 
b/pinot-plugins/pinot-input-format/pinot-parquet/src/test/java/org/apache/pinot/plugin/inputformat/parquet/ParquetReaderLogicalTypesTest.java
@@ -0,0 +1,140 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.plugin.inputformat.parquet;
+
+import java.io.File;
+import java.io.IOException;
+import java.math.BigDecimal;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.commons.io.FileUtils;
+import org.apache.hadoop.fs.Path;
+import org.apache.parquet.example.data.Group;
+import org.apache.parquet.example.data.simple.NanoTime;
+import org.apache.parquet.example.data.simple.SimpleGroupFactory;
+import org.apache.parquet.hadoop.ParquetWriter;
+import org.apache.parquet.hadoop.example.ExampleParquetWriter;
+import org.apache.parquet.hadoop.metadata.CompressionCodecName;
+import org.apache.parquet.io.api.Binary;
+import org.apache.parquet.schema.MessageType;
+import org.apache.parquet.schema.MessageTypeParser;
+import org.apache.pinot.spi.data.readers.GenericRow;
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertEquals;
+
+
+/**
+ * TODO: add test for other logical types
+ */
+public class ParquetReaderLogicalTypesTest {
+
+  @Test
+  public void testDecimalType() {
+    String schemaStr = "message DecimalExample {"
+        + "required int32 decimal_from_int32 (DECIMAL(9, 2));"
+        + "required int64 decimal_from_int64 (DECIMAL(18, 2));"
+        + "required fixed_len_byte_array(16) decimal_from_fixed (DECIMAL(38, 
9));"
+        + "required binary decimal_from_binary (DECIMAL(38, 9));"
+        + "}";
+
+    byte[] byteArray = new 
BigDecimal("12345678901234567890123456789.012345678").unscaledValue().toByteArray();
+    byte[] paddedArray = new byte[16];
+    System.arraycopy(byteArray, 0, paddedArray, 16 - byteArray.length, 
byteArray.length);
+    Map<String, Object> row = new HashMap<>();
+    row.put("decimal_from_int32", new 
BigDecimal("123.45").unscaledValue().intValue());
+    row.put("decimal_from_int64", new 
BigDecimal("1234567890123.45").unscaledValue().longValue());
+    row.put("decimal_from_fixed", Binary.fromConstantByteArray(paddedArray));
+    row.put("decimal_from_binary", Binary.fromConstantByteArray(byteArray));
+
+    String outputPath = null;
+    try {
+      outputPath = writeToFile(schemaStr, row);
+      try (ParquetNativeRecordReader recordReader = new 
ParquetNativeRecordReader()) {
+        recordReader.init(new File(outputPath), null, null);
+        recordReader.rewind();
+        GenericRow genericRow = recordReader.next();
+        assertEquals(genericRow.getValue("decimal_from_int32"), new 
BigDecimal("123.45"));
+        assertEquals(genericRow.getValue("decimal_from_int64"), new 
BigDecimal("1234567890123.45"));
+        assertEquals(genericRow.getValue("decimal_from_fixed"),
+            new BigDecimal("12345678901234567890123456789.012345678"));
+        assertEquals(genericRow.getValue("decimal_from_binary"),
+            new BigDecimal("12345678901234567890123456789.012345678"));
+      }
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    } finally {
+      if (outputPath != null) {
+        FileUtils.deleteQuietly(new File(outputPath));
+      }
+    }
+  }
+
+  /**
+   * Write the given row object to a random parquet file and return the path 
of the file
+   */
+  private String writeToFile(String schemaStr, Map<String, Object> rowObject)
+      throws IOException {
+    Path outputPath = new Path(FileUtils.getTempDirectory().toString(), 
"example.parquet");
+    MessageType schema = MessageTypeParser.parseMessageType(schemaStr);
+    try (ParquetWriter<Group> writer = ExampleParquetWriter.builder(outputPath)
+        .withType(schema)
+        .withCompressionCodec(CompressionCodecName.SNAPPY)
+        .withPageSize(ParquetWriter.DEFAULT_PAGE_SIZE)
+        .build()) {
+      SimpleGroupFactory groupFactory = new SimpleGroupFactory(schema);
+      Group group = groupFactory.newGroup();
+      for (Map.Entry<String, Object> entry : rowObject.entrySet()) {
+        if (entry.getValue() instanceof Integer) {
+          group.append(entry.getKey(), (int) entry.getValue());
+          continue;
+        }
+        if (entry.getValue() instanceof Long) {
+          group.append(entry.getKey(), (long) entry.getValue());
+          continue;
+        }
+        if (entry.getValue() instanceof Float) {
+          group.append(entry.getKey(), (Float) entry.getValue());
+          continue;
+        }
+        if (entry.getValue() instanceof Double) {
+          group.append(entry.getKey(), (Double) entry.getValue());
+          continue;
+        }
+        if (entry.getValue() instanceof NanoTime) {
+          group.append(entry.getKey(), (NanoTime) entry.getValue());
+          continue;
+        }
+        if (entry.getValue() instanceof String) {
+          group.append(entry.getKey(), (String) entry.getValue());
+          continue;
+        }
+        if (entry.getValue() instanceof Boolean) {
+          group.append(entry.getKey(), (Boolean) entry.getValue());
+          continue;
+        }
+        if (entry.getValue() instanceof Binary) {
+          group.append(entry.getKey(), (Binary) entry.getValue());
+        }
+      }
+      writer.write(group);
+    }
+    return outputPath.toString();
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to