This is an automated email from the ASF dual-hosted git repository.

kishoreg pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new deb3891  Add support for Decimal with Precision Sum aggregation (#6053)
deb3891 is described below

commit deb389182209db4a18761be9e20d7dcbf037b16b
Author: Kartik Khare <[email protected]>
AuthorDate: Fri Oct 2 00:57:32 2020 +0530

    Add support for Decimal with Precision Sum aggregation (#6053)
    
    * Add support for big decimal
    
    * Add transform function to factory
    
    * Add support for decimal with precision addition
    
    * Add license header
    
    * Add Sum Precision aggregation function
    
    * Remove add with precision transform function
    
    * Add license header
    
    * Refactor: Correct import of Scalar function
    
    * Add function to convert normal string to bigdecimal bytes
    
    * Add test for big decimal
    
    * Add test for big decimal precision
    
    * Add support for scale along with precision
    
    * Add license header
    
    * Add base64 encode functions
    
    * typo fix
    
    * Move arguments logic inside constructor
    
    * set scale and precision only in final result
    
    * Reduce scale bytes from 4 to 2
    
    * Add java docs for sum with precision function
    
    * Rename sumWithPrecision to sumPrecision
    
    * Adding methods to directly take big decimal input
    
    Co-authored-by: Kartik Khare <[email protected]>
---
 .../common/function/AggregationFunctionType.java   |   1 +
 .../scalar/DataTypeConversionFunctions.java        | 142 +++++++++++++
 .../apache/pinot/core/common/ObjectSerDeUtils.java |  31 ++-
 .../function/AggregationFunctionFactory.java       |   2 +
 .../function/SumPrecisionAggregationFunction.java  | 180 +++++++++++++++++
 .../function/AggregationFunctionFactoryTest.java   |   7 +
 .../apache/pinot/queries/SumWithPrecisionTest.java | 221 +++++++++++++++++++++
 7 files changed, 581 insertions(+), 3 deletions(-)

diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/AggregationFunctionType.java b/pinot-common/src/main/java/org/apache/pinot/common/function/AggregationFunctionType.java
index aeae907..37517b2 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/function/AggregationFunctionType.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/function/AggregationFunctionType.java
@@ -30,6 +30,7 @@ public enum AggregationFunctionType {
   MIN("min"),
   MAX("max"),
   SUM("sum"),
+  SUMPRECISION("sumPrecision"),
   AVG("avg"),
   MINMAXRANGE("minMaxRange"),
   DISTINCTCOUNT("distinctCount"),
diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DataTypeConversionFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DataTypeConversionFunctions.java
new file mode 100644
index 0000000..d9ec3b1
--- /dev/null
+++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DataTypeConversionFunctions.java
@@ -0,0 +1,142 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.common.function.scalar;
+
+import java.math.BigDecimal;
+import java.math.BigInteger;
+import java.util.Base64;
+import org.apache.pinot.spi.annotations.ScalarFunction;
+
+
+/**
+ * Contains functions to convert values from one data type to another.
+ */
+public class DataTypeConversionFunctions {
+  private DataTypeConversionFunctions() {
+
+  }
+
+  /**
+   * Converts big decimal string representation to bytes.
+   * Only a scale of up to 2 bytes is supported by this function.
+   * @param number big decimal number as a plain string, e.g. '1234.12121'
+   * @return byte array containing the bytes of the unscaled value appended to 2 bytes of the scale in BIG ENDIAN order
+   */
+  @ScalarFunction
+  public static byte[] bigDecimalToBytes(String number) {
+    BigDecimal bigDecimal = new BigDecimal(number);
+    return bigDecimalToBytes(bigDecimal);
+  }
+
+  /**
+   * Converts the bytes representation generated by {@link #bigDecimalToBytes(String)} back to a big decimal string
+   * @param bytes array that contains the bytes of the unscaled value appended to 2 bytes of the scale in BIG ENDIAN order
+   * @return plain string representation of the big decimal
+   */
+  @ScalarFunction
+  public static String bytesToBigDecimal(byte[] bytes) {
+    BigDecimal number = bytesToBigDecimalObject(bytes);
+    return number.toString();
+  }
+
+  /**
+   * Converts a plain hex string to a byte array
+   * @param hex a plain hex string e.g. 'f0e1a3b2'
+   * @return byte array representation of hex string
+   */
+  @ScalarFunction
+  public static byte[] hexToBytes(String hex) {
+    int len = hex.length();
+    byte[] data = new byte[len / 2];
+    for (int i = 0; i < len; i += 2) {
+      data[i / 2] = (byte) ((Character.digit(hex.charAt(i), 16) << 4) + Character.digit(hex.charAt(i + 1), 16));
+    }
+    return data;
+  }
+
+  /**
+   * Converts a byte array to a plain hex string
+   * @param bytes any byte array
+   * @return plain hex string e.g. 'f012be3c'
+   */
+  @ScalarFunction
+  public static String bytesToHex(byte[] bytes) {
+    StringBuilder sb = new StringBuilder();
+    for (byte b : bytes) {
+      sb.append(String.format("%02x", b));
+    }
+
+    return sb.toString();
+  }
+
+  /**
+   * Converts the bytes representation generated by {@link #bigDecimalToBytes(String)} back to a big decimal object
+   * @param bytes array that contains the bytes of the unscaled value appended to 2 bytes of the scale in BIG ENDIAN order
+   * @return big decimal object
+   */
+  public static BigDecimal bytesToBigDecimalObject(byte[] bytes) {
+    int scale = 0;
+    scale += ((bytes[0] & 0xFF) << 8);
+    scale += (bytes[1] & 0xFF);
+    byte[] vals = new byte[bytes.length - 2];
+    System.arraycopy(bytes, 2, vals, 0, vals.length);
+    BigInteger unscaled = new BigInteger(vals);
+    BigDecimal number = new BigDecimal(unscaled, scale);
+    return number;
+  }
+
+  /**
+   * Converts a big decimal object to bytes.
+   * Only a scale of up to 2 bytes is supported by this function.
+   * @param bigDecimal big decimal number object
+   * @return byte array containing the bytes of the unscaled value appended to 2 bytes of the scale in BIG ENDIAN order
+   */
+  public static byte[] bigDecimalToBytes(BigDecimal bigDecimal) {
+    int scale = bigDecimal.scale();
+    BigInteger unscaled = bigDecimal.unscaledValue();
+    byte[] value = unscaled.toByteArray();
+    byte[] bigDecimalBytesArray = new byte[value.length + 2];
+
+    bigDecimalBytesArray[0] = (byte) (scale >>> 8);
+    bigDecimalBytesArray[1] = (byte) (scale);
+
+    System.arraycopy(value, 0, bigDecimalBytesArray, 2, value.length);
+    return bigDecimalBytesArray;
+  }
+
+  /**
+   * Encodes a byte array to base64 using {@link Base64}
+   * @param input original byte array
+   * @return base64 encoded byte array
+   */
+  @ScalarFunction
+  public static byte[] base64Encode(byte[] input) {
+    return Base64.getEncoder().encodeToString(input).getBytes();
+  }
+
+  /**
+   * Decodes a base64 encoded string to bytes using {@link Base64}
+   * @param input base64 encoded string
+   * @return decoded byte array
+   */
+  @ScalarFunction
+  public static byte[] base64Decode(String input) {
+    return Base64.getDecoder().decode(input.getBytes());
+  }
+}
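
As an aside for readers of this patch: the byte layout produced by bigDecimalToBytes is simply 2 big-endian scale bytes followed by the bytes of the unscaled value. The following minimal sketch is not part of the commit (the class name and sample value are illustrative only) and round-trips a value through the helpers added above.

import java.math.BigDecimal;

import org.apache.pinot.common.function.scalar.DataTypeConversionFunctions;

public class BigDecimalBytesRoundTripSketch {
  public static void main(String[] args) {
    BigDecimal original = new BigDecimal("1234.12121");

    // Encode: the first 2 bytes hold the scale (big-endian), the rest hold the unscaled value.
    byte[] encoded = DataTypeConversionFunctions.bigDecimalToBytes(original);
    int scale = ((encoded[0] & 0xFF) << 8) | (encoded[1] & 0xFF);
    System.out.println(scale); // 5, the number of digits after the decimal point in 1234.12121

    // Decode back and verify the round trip is lossless.
    BigDecimal decoded = DataTypeConversionFunctions.bytesToBigDecimalObject(encoded);
    System.out.println(original.equals(decoded)); // true
  }
}
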
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/common/ObjectSerDeUtils.java b/pinot-core/src/main/java/org/apache/pinot/core/common/ObjectSerDeUtils.java
index 45ed146..0367f29 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/common/ObjectSerDeUtils.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/common/ObjectSerDeUtils.java
@@ -40,6 +40,7 @@ import it.unimi.dsi.fastutil.objects.ObjectSet;
 import java.io.ByteArrayOutputStream;
 import java.io.DataOutputStream;
 import java.io.IOException;
+import java.math.BigDecimal;
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.HashMap;
@@ -49,6 +50,7 @@ import java.util.Map;
 import java.util.Set;
 import org.apache.datasketches.memory.Memory;
 import org.apache.datasketches.theta.Sketch;
+import org.apache.pinot.common.function.scalar.DataTypeConversionFunctions;
 import org.apache.pinot.common.utils.StringUtil;
 import org.apache.pinot.core.geospatial.serde.GeometrySerializer;
 import org.apache.pinot.core.query.aggregation.function.customobject.AvgPair;
@@ -94,8 +96,8 @@ public class ObjectSerDeUtils {
     StringSet(18),
     BytesSet(19),
     IdSet(20),
-    List(21);
-
+    List(21),
+    BigDecimal(22);
     private final int _value;
 
     ObjectType(int value) {
@@ -113,6 +115,8 @@ public class ObjectSerDeUtils {
         return ObjectType.Long;
       } else if (value instanceof Double) {
         return ObjectType.Double;
+      } else if (value instanceof BigDecimal) {
+        return ObjectType.BigDecimal;
       } else if (value instanceof DoubleArrayList) {
         return ObjectType.DoubleArrayList;
       } else if (value instanceof AvgPair) {
@@ -850,6 +854,26 @@ public class ObjectSerDeUtils {
     }
   };
 
+  public static final ObjectSerDe<BigDecimal> BIGDECIMAL_SER_DE = new ObjectSerDe<BigDecimal>() {
+
+    @Override
+    public byte[] serialize(BigDecimal value) {
+      return DataTypeConversionFunctions.bigDecimalToBytes(value.toString());
+    }
+
+    @Override
+    public BigDecimal deserialize(byte[] bytes) {
+      return new BigDecimal(DataTypeConversionFunctions.bytesToBigDecimal(bytes));
+    }
+
+    @Override
+    public BigDecimal deserialize(ByteBuffer byteBuffer) {
+      byte[] bytes = new byte[byteBuffer.remaining()];
+      byteBuffer.get(bytes);
+      return deserialize(bytes);
+    }
+  };
+
   // NOTE: DO NOT change the order, it has to be the same order as the ObjectType
   //@formatter:off
   private static final ObjectSerDe[] SER_DES = {
@@ -874,7 +898,8 @@ public class ObjectSerDeUtils {
       STRING_SET_SER_DE,
       BYTES_SET_SER_DE,
       ID_SET_SER_DE,
-      LIST_SER_DE
+      LIST_SER_DE,
+      BIGDECIMAL_SER_DE
   };
   //@formatter:on
 
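
For context, the new serde entry can be exercised directly through the public BIGDECIMAL_SER_DE field declared above. A small sketch (not part of the commit; the sample value is arbitrary):

import java.math.BigDecimal;

import org.apache.pinot.core.common.ObjectSerDeUtils;

public class BigDecimalSerDeSketch {
  public static void main(String[] args) {
    BigDecimal value = new BigDecimal("42.4242");

    // Serialize using the new serde (delegates to the scalar conversion functions).
    byte[] bytes = ObjectSerDeUtils.BIGDECIMAL_SER_DE.serialize(value);

    // Deserialize and confirm the intermediate-result encoding preserves the value.
    BigDecimal back = ObjectSerDeUtils.BIGDECIMAL_SER_DE.deserialize(bytes);
    System.out.println(value.equals(back)); // true
  }
}
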
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
index d74db3e..9093a8c 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
@@ -117,6 +117,8 @@ public class AggregationFunctionFactory {
             return new MaxAggregationFunction(firstArgument);
           case SUM:
             return new SumAggregationFunction(firstArgument);
+          case SUMPRECISION:
+            return new SumPrecisionAggregationFunction(arguments);
           case AVG:
             return new AvgAggregationFunction(firstArgument);
           case MINMAXRANGE:
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumPrecisionAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumPrecisionAggregationFunction.java
new file mode 100644
index 0000000..8f4897d
--- /dev/null
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumPrecisionAggregationFunction.java
@@ -0,0 +1,180 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.aggregation.function;
+
+import java.math.BigDecimal;
+import java.math.MathContext;
+import java.util.List;
+import java.util.Map;
+import org.apache.pinot.common.function.AggregationFunctionType;
+import org.apache.pinot.common.function.scalar.DataTypeConversionFunctions;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.core.common.BlockValSet;
+import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
+import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
+import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
+import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
+import org.apache.pinot.core.query.request.context.ExpressionContext;
+
+
+/**
+ * This function is used for BigDecimal calculations. It supports the sum aggregation using both precision and scale.
+ * The function can be used as SUMPRECISION(column, 10, 2).
+ * The following arguments are supported:
+ * bytes column - a column that contains big decimal values as bytes
+ * precision - precision to be set on the final result
+ * scale - scale to be set on the final result
+ */
+public class SumPrecisionAggregationFunction extends BaseSingleInputAggregationFunction<BigDecimal, BigDecimal> {
+  MathContext _mathContext = new MathContext(0);
+  Integer _scale = null;
+
+  public SumPrecisionAggregationFunction(List<ExpressionContext> arguments) {
+    super(arguments.get(0));
+    int numArguments = arguments.size();
+
+    if (numArguments == 3) {
+      Integer precision = Integer.parseInt(arguments.get(1).getLiteral());
+      _scale = Integer.parseInt(arguments.get(2).getLiteral());
+      _mathContext = new MathContext(precision);
+    } else if (numArguments == 2) {
+      Integer precision = Integer.parseInt(arguments.get(1).getLiteral());
+      _mathContext = new MathContext(precision);
+    }
+  }
+
+  @Override
+  public AggregationFunctionType getType() {
+    return AggregationFunctionType.SUMPRECISION;
+  }
+
+  @Override
+  public AggregationResultHolder createAggregationResultHolder() {
+    return new ObjectAggregationResultHolder();
+  }
+
+  @Override
+  public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int maxCapacity) {
+    return new ObjectGroupByResultHolder(initialCapacity, maxCapacity);
+  }
+
+  @Override
+  public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+      Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    byte[][] valueArray = blockValSetMap.get(_expression).getBytesValuesSV();
+    BigDecimal sumValue = getDefaultResult(aggregationResultHolder);
+    for (int i = 0; i < length; i++) {
+      BigDecimal value = DataTypeConversionFunctions.bytesToBigDecimalObject(valueArray[i]);
+      sumValue = sumValue.add(value);
+    }
+    aggregationResultHolder.setValue(sumValue);
+  }
+
+  @Override
+  public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
+      Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    byte[][] valueArray = blockValSetMap.get(_expression).getBytesValuesSV();
+    for (int i = 0; i < length; i++) {
+      int groupKey = groupKeyArray[i];
+      BigDecimal groupByResultValue = getDefaultResult(groupByResultHolder, groupKey);
+      BigDecimal value = DataTypeConversionFunctions.bytesToBigDecimalObject(valueArray[i]);
+      groupByResultValue = groupByResultValue.add(value);
+      groupByResultHolder.setValueForKey(groupKey, groupByResultValue);
+    }
+  }
+
+  @Override
+  public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
+      Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    byte[][] valueArray = blockValSetMap.get(_expression).getBytesValuesSV();
+    for (int i = 0; i < length; i++) {
+      byte[] value = valueArray[i];
+      for (int groupKey : groupKeysArray[i]) {
+        BigDecimal groupByResultValue = getDefaultResult(groupByResultHolder, groupKey);
+        BigDecimal valueBigDecimal = DataTypeConversionFunctions.bytesToBigDecimalObject(value);
+        groupByResultValue = groupByResultValue.add(valueBigDecimal);
+        groupByResultHolder.setValueForKey(groupKey, groupByResultValue);
+      }
+    }
+  }
+
+  @Override
+  public BigDecimal extractAggregationResult(AggregationResultHolder aggregationResultHolder) {
+    return getDefaultResult(aggregationResultHolder);
+  }
+
+  @Override
+  public BigDecimal extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) {
+    return getDefaultResult(groupByResultHolder, groupKey);
+  }
+
+  @Override
+  public BigDecimal merge(BigDecimal intermediateResult1, BigDecimal intermediateResult2) {
+    try {
+      return intermediateResult1.add(intermediateResult2);
+    } catch (Exception e) {
+      throw new RuntimeException("Caught Exception while merging results in sum with precision function", e);
+    }
+  }
+
+  @Override
+  public boolean isIntermediateResultComparable() {
+    return true;
+  }
+
+  @Override
+  public DataSchema.ColumnDataType getIntermediateResultColumnType() {
+    return DataSchema.ColumnDataType.OBJECT;
+  }
+
+  @Override
+  public DataSchema.ColumnDataType getFinalResultColumnType() {
+    return DataSchema.ColumnDataType.STRING;
+  }
+
+  @Override
+  public BigDecimal extractFinalResult(BigDecimal intermediateResult) {
+    return setScale(new BigDecimal(intermediateResult.toString(), _mathContext));
+  }
+
+  public BigDecimal getDefaultResult(AggregationResultHolder aggregationResultHolder) {
+    BigDecimal result = aggregationResultHolder.getResult();
+    if (result == null) {
+      result = new BigDecimal(0);
+      aggregationResultHolder.setValue(result);
+    }
+    return result;
+  }
+
+  public BigDecimal getDefaultResult(GroupByResultHolder groupByResultHolder, int groupKey) {
+    BigDecimal result = groupByResultHolder.getResult(groupKey);
+    if (result == null) {
+      result = new BigDecimal(0);
+      groupByResultHolder.setValueForKey(groupKey, result);
+    }
+    return result;
+  }
+
+  private BigDecimal setScale(BigDecimal value) {
+    if (_scale != null) {
+      value = value.setScale(_scale, BigDecimal.ROUND_HALF_EVEN);
+    }
+    return value;
+  }
+}
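
To illustrate how the optional precision and scale arguments affect only the final result (mirroring extractFinalResult and setScale above, for a query of the form SELECT SUMPRECISION(column, 10, 2)), here is a standalone sketch; it is not part of the commit and the numbers are made up.

import java.math.BigDecimal;
import java.math.MathContext;
import java.math.RoundingMode;

public class SumPrecisionFinalResultSketch {
  public static void main(String[] args) {
    // Intermediate sum, accumulated at full precision by the aggregation.
    BigDecimal intermediateSum = new BigDecimal("12345.678912");

    // Precision argument (10) -> MathContext applied when building the final result.
    BigDecimal withPrecision = new BigDecimal(intermediateSum.toString(), new MathContext(10));

    // Scale argument (2) -> setScale applied only to the final result, HALF_EVEN rounding.
    BigDecimal finalResult = withPrecision.setScale(2, RoundingMode.HALF_EVEN);

    System.out.println(finalResult); // 12345.68
  }
}
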
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactoryTest.java b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactoryTest.java
index 476c086..f286c0b 100644
--- a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactoryTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactoryTest.java
@@ -65,6 +65,13 @@ public class AggregationFunctionFactoryTest {
     assertEquals(aggregationFunction.getColumnName(), "sum_column");
     assertEquals(aggregationFunction.getResultColumnName(), function.toString());
 
+    function = getFunction("SuMPreCIsiON");
+    aggregationFunction = AggregationFunctionFactory.getAggregationFunction(function, DUMMY_QUERY_CONTEXT);
+    assertTrue(aggregationFunction instanceof SumPrecisionAggregationFunction);
+    assertEquals(aggregationFunction.getType(), AggregationFunctionType.SUMPRECISION);
+    assertEquals(aggregationFunction.getColumnName(), "sumPrecision_column");
+    assertEquals(aggregationFunction.getResultColumnName(), function.toString());
+
     function = getFunction("AvG");
     aggregationFunction = AggregationFunctionFactory.getAggregationFunction(function, DUMMY_QUERY_CONTEXT);
     assertTrue(aggregationFunction instanceof AvgAggregationFunction);
diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/SumWithPrecisionTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/SumWithPrecisionTest.java
new file mode 100644
index 0000000..8754bc4
--- /dev/null
+++ b/pinot-core/src/test/java/org/apache/pinot/queries/SumWithPrecisionTest.java
@@ -0,0 +1,221 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.queries;
+
+import java.io.File;
+import java.io.IOException;
+import java.math.BigDecimal;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Random;
+import org.apache.commons.io.FileUtils;
+import org.apache.pinot.common.function.scalar.DataTypeConversionFunctions;
+import org.apache.pinot.common.segment.ReadMode;
+import org.apache.pinot.core.common.Operator;
+import org.apache.pinot.core.data.readers.GenericRowRecordReader;
+import org.apache.pinot.core.indexsegment.IndexSegment;
+import org.apache.pinot.core.indexsegment.generator.SegmentGeneratorConfig;
+import org.apache.pinot.core.indexsegment.immutable.ImmutableSegment;
+import org.apache.pinot.core.indexsegment.immutable.ImmutableSegmentLoader;
+import org.apache.pinot.core.operator.blocks.IntermediateResultsBlock;
+import org.apache.pinot.core.operator.query.AggregationOperator;
+import org.apache.pinot.core.segment.creator.impl.SegmentIndexCreationDriverImpl;
+import org.apache.pinot.spi.config.table.TableConfig;
+import org.apache.pinot.spi.config.table.TableType;
+import org.apache.pinot.spi.data.FieldSpec;
+import org.apache.pinot.spi.data.Schema;
+import org.apache.pinot.spi.data.readers.GenericRow;
+import org.apache.pinot.spi.utils.builder.TableConfigBuilder;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertNotNull;
+import static org.testng.Assert.assertTrue;
+
+
+public class SumWithPrecisionTest extends BaseSingleValueQueriesTest {
+  private static final File INDEX_DIR = new File(FileUtils.getTempDirectory(), "SumWithPrecisionTest");
+  private static final String RAW_TABLE_NAME = "testTable";
+  private static final String SEGMENT_NAME = "testSegment";
+  private static final Random RANDOM = new Random();
+
+  private static final int NUM_RECORDS = 2000;
+  private static final int MAX_VALUE = 1000;
+
+  private static final String INT_COLUMN = "intColumn";
+  private static final String LONG_COLUMN = "longColumn";
+  private static final String FLOAT_COLUMN = "floatColumn";
+  private static final String DOUBLE_COLUMN = "doubleColumn";
+  private static final String STRING_COLUMN = "stringColumn";
+  private static final String BYTES_COLUMN = "bytesColumn";
+  private static final Schema SCHEMA =
+      new Schema.SchemaBuilder().addSingleValueDimension(INT_COLUMN, FieldSpec.DataType.INT)
+          .addSingleValueDimension(LONG_COLUMN, FieldSpec.DataType.LONG)
+          .addSingleValueDimension(FLOAT_COLUMN, FieldSpec.DataType.FLOAT)
+          .addSingleValueDimension(DOUBLE_COLUMN, FieldSpec.DataType.DOUBLE)
+          .addSingleValueDimension(STRING_COLUMN, FieldSpec.DataType.STRING)
+          .addSingleValueDimension(BYTES_COLUMN, FieldSpec.DataType.BYTES).build();
+  private static final TableConfig TABLE_CONFIG =
+      new TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME).build();
+
+  private BigDecimal _aggregatedValuePerSegment;
+  private IndexSegment _indexSegment;
+  private List<IndexSegment> _indexSegments;
+
+  @Override
+  protected String getFilter() {
+    // NOTE: Use a match all filter to switch between DictionaryBasedAggregationOperator and AggregationOperator
+    return " WHERE intColumn >= 0";
+  }
+
+  @Override
+  protected IndexSegment getIndexSegment() {
+    return _indexSegment;
+  }
+
+  @Override
+  protected List<IndexSegment> getIndexSegments() {
+    return _indexSegments;
+  }
+
+  @BeforeClass
+  public void setUp()
+      throws Exception {
+    FileUtils.deleteDirectory(INDEX_DIR);
+    _aggregatedValuePerSegment = new BigDecimal(0);
+    List<GenericRow> records = new ArrayList<>(NUM_RECORDS);
+    for (int i = 0; i < NUM_RECORDS; i++) {
+      int value = RANDOM.nextInt(MAX_VALUE);
+      GenericRow record = new GenericRow();
+      record.putValue(INT_COLUMN, value);
+      record.putValue(LONG_COLUMN, (long) value);
+      record.putValue(FLOAT_COLUMN, (float) value);
+      record.putValue(DOUBLE_COLUMN, (double) value);
+      String stringValue = Integer.toString(value);
+      record.putValue(STRING_COLUMN, stringValue);
+      // NOTE: Create fixed-length bytes so that dictionary can be generated
+      BigDecimal bigDecimalValueLeft = new BigDecimal(Double.toString(RANDOM.nextDouble()));
+      BigDecimal bigDecimalValueRight = new BigDecimal(Double.toString(RANDOM.nextDouble()));
+      BigDecimal bigDecimalValue = bigDecimalValueLeft.multiply(bigDecimalValueRight);
+
+      _aggregatedValuePerSegment = _aggregatedValuePerSegment.add(bigDecimalValue);
+      byte[] bytesValue = DataTypeConversionFunctions.bigDecimalToBytes(bigDecimalValue.toString());
+      record.putValue(BYTES_COLUMN, bytesValue);
+      records.add(record);
+    }
+
+    SegmentGeneratorConfig segmentGeneratorConfig = new SegmentGeneratorConfig(TABLE_CONFIG, SCHEMA);
+    segmentGeneratorConfig.setTableName(RAW_TABLE_NAME);
+    segmentGeneratorConfig.setSegmentName(SEGMENT_NAME);
+    segmentGeneratorConfig.setOutDir(INDEX_DIR.getPath());
+
+    SegmentIndexCreationDriverImpl driver = new SegmentIndexCreationDriverImpl();
+    driver.init(segmentGeneratorConfig, new GenericRowRecordReader(records));
+    driver.build();
+
+    ImmutableSegment immutableSegment = ImmutableSegmentLoader.load(new File(INDEX_DIR, SEGMENT_NAME), ReadMode.mmap);
+    _indexSegment = immutableSegment;
+    _indexSegments = Arrays.asList(immutableSegment, immutableSegment);
+  }
+
+  @Test
+  public void testAggregationOnly() {
+    String query = "SELECT SUMPRECISION(bytesColumn) FROM testTable";
+
+    // Inner segment
+    Operator operator = getOperatorForPqlQuery(query);
+    assertTrue(operator instanceof AggregationOperator);
+    IntermediateResultsBlock resultsBlock = ((AggregationOperator) operator).nextBlock();
+    QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), NUM_RECORDS, 0, NUM_RECORDS,
+        NUM_RECORDS);
+    List<Object> aggregationResult = resultsBlock.getAggregationResult();
+
+    operator = getOperatorForPqlQueryWithFilter(query);
+    assertTrue(operator instanceof AggregationOperator);
+    IntermediateResultsBlock resultsBlockWithFilter = ((AggregationOperator) operator).nextBlock();
+    QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), NUM_RECORDS, 0, NUM_RECORDS,
+        NUM_RECORDS);
+    List<Object> aggregationResultWithFilter = resultsBlockWithFilter.getAggregationResult();
+
+    assertNotNull(aggregationResult);
+    assertNotNull(aggregationResultWithFilter);
+    assertEquals(aggregationResult, aggregationResultWithFilter);
+    assertEquals(aggregationResult.get(0), _aggregatedValuePerSegment);
+  }
+
+  @Test
+  public void testAggregationWithPrecision() {
+    String query = "SELECT SUMPRECISION(bytesColumn, 6) FROM testTable";
+
+    // Inner segment
+    Operator operator = getOperatorForPqlQuery(query);
+    assertTrue(operator instanceof AggregationOperator);
+    IntermediateResultsBlock resultsBlock = ((AggregationOperator) operator).nextBlock();
+    QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), NUM_RECORDS, 0, NUM_RECORDS,
+        NUM_RECORDS);
+    List<Object> aggregationResult = resultsBlock.getAggregationResult();
+
+    operator = getOperatorForPqlQueryWithFilter(query);
+    assertTrue(operator instanceof AggregationOperator);
+    IntermediateResultsBlock resultsBlockWithFilter = ((AggregationOperator) operator).nextBlock();
+    QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), NUM_RECORDS, 0, NUM_RECORDS,
+        NUM_RECORDS);
+    List<Object> aggregationResultWithFilter = resultsBlockWithFilter.getAggregationResult();
+
+    assertNotNull(aggregationResult);
+    assertNotNull(aggregationResultWithFilter);
+    assertEquals(aggregationResult, aggregationResultWithFilter);
+    assertTrue(_aggregatedValuePerSegment.subtract((BigDecimal) aggregationResult.get(0)).abs().doubleValue() <= 0.1);
+  }
+
+  @Test
+  public void testAggregationWithPrecisionAndScale() {
+    String query = "SELECT SUMPRECISION(bytesColumn, 10, 3) FROM testTable";
+
+    // Inner segment
+    Operator operator = getOperatorForPqlQuery(query);
+    assertTrue(operator instanceof AggregationOperator);
+    IntermediateResultsBlock resultsBlock = ((AggregationOperator) operator).nextBlock();
+    QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), NUM_RECORDS, 0, NUM_RECORDS,
+        NUM_RECORDS);
+    List<Object> aggregationResult = resultsBlock.getAggregationResult();
+
+    operator = getOperatorForPqlQueryWithFilter(query);
+    assertTrue(operator instanceof AggregationOperator);
+    IntermediateResultsBlock resultsBlockWithFilter = ((AggregationOperator) operator).nextBlock();
+    QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), NUM_RECORDS, 0, NUM_RECORDS,
+        NUM_RECORDS);
+    List<Object> aggregationResultWithFilter = resultsBlockWithFilter.getAggregationResult();
+
+    assertNotNull(aggregationResult);
+    assertNotNull(aggregationResultWithFilter);
+    assertEquals(aggregationResult, aggregationResultWithFilter);
+    assertTrue(_aggregatedValuePerSegment.subtract((BigDecimal) aggregationResult.get(0)).abs().doubleValue() <= 0.1);
+  }
+
+  @AfterClass
+  public void tearDown()
+      throws IOException {
+    _indexSegment.destroy();
+    FileUtils.deleteDirectory(INDEX_DIR);
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
