This is an automated email from the ASF dual-hosted git repository.
kishoreg pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-pinot.git
The following commit(s) were added to refs/heads/master by this push:
new deb3891 Add support for Decimal with Precision Sum aggregation (#6053)
deb3891 is described below
commit deb389182209db4a18761be9e20d7dcbf037b16b
Author: Kartik Khare <[email protected]>
AuthorDate: Fri Oct 2 00:57:32 2020 +0530
Add support for Decimal with Precision Sum aggregation (#6053)
* Add support for big decimal
* Add transform function to factory
* Add support for decimal with precision addition
* Add license header
* Add Sum Precision aggregation function
* Remove add with precision transform function
* Add license header
* Refactor: Correct import of Scalar function
* Add function to convert normal string to bigdecimal bytes
* Add test for big decimal
* Add test for big decimal precision
* Add support for scale along with precision
* Add license header
* Add base64 encode functions
* typo fix
* Move arguments logic inside constructor
* set scale and precision only in final result
* Reduce scale bytes from 4 to 2
* Add java docs for sum with precision function
* Rename sumWithPrecision to sumPrecision
* Adding methods to directly take big decimal input
Co-authored-by: Kartik Khare <[email protected]>
---
.../common/function/AggregationFunctionType.java | 1 +
.../scalar/DataTypeConversionFunctions.java | 142 +++++++++++++
.../apache/pinot/core/common/ObjectSerDeUtils.java | 31 ++-
.../function/AggregationFunctionFactory.java | 2 +
.../function/SumPrecisionAggregationFunction.java | 180 +++++++++++++++++
.../function/AggregationFunctionFactoryTest.java | 7 +
.../apache/pinot/queries/SumWithPrecisionTest.java | 221 +++++++++++++++++++++
7 files changed, 581 insertions(+), 3 deletions(-)
diff --git
a/pinot-common/src/main/java/org/apache/pinot/common/function/AggregationFunctionType.java
b/pinot-common/src/main/java/org/apache/pinot/common/function/AggregationFunctionType.java
index aeae907..37517b2 100644
---
a/pinot-common/src/main/java/org/apache/pinot/common/function/AggregationFunctionType.java
+++
b/pinot-common/src/main/java/org/apache/pinot/common/function/AggregationFunctionType.java
@@ -30,6 +30,7 @@ public enum AggregationFunctionType {
MIN("min"),
MAX("max"),
SUM("sum"),
+ SUMPRECISION("sumPrecision"),
AVG("avg"),
MINMAXRANGE("minMaxRange"),
DISTINCTCOUNT("distinctCount"),
diff --git
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DataTypeConversionFunctions.java
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DataTypeConversionFunctions.java
new file mode 100644
index 0000000..d9ec3b1
--- /dev/null
+++
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DataTypeConversionFunctions.java
@@ -0,0 +1,142 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.common.function.scalar;
+
+import java.math.BigDecimal;
+import java.math.BigInteger;
+import java.util.Base64;
+import org.apache.pinot.spi.annotations.ScalarFunction;
+
+
+/**
+ * Contains functions to convert a value of one data type to another data type.
+ */
+public class DataTypeConversionFunctions {
+  private DataTypeConversionFunctions() {
+  }
+
+  /**
+   * Converts the string representation of a big decimal to bytes.
+   * Only scales that fit in a signed 16-bit integer ([-32768, 32767]) are supported.
+   *
+   * @param number big decimal number as a string, e.g. '1234.12121'
+   * @return 2 bytes of the scale followed by the two's-complement bytes of the unscaled value, in BIG ENDIAN order
+   * @throws NumberFormatException if {@code number} is not a valid big decimal representation
+   * @throws IllegalArgumentException if the scale does not fit in 2 bytes
+   */
+  @ScalarFunction
+  public static byte[] bigDecimalToBytes(String number) {
+    return bigDecimalToBytes(new BigDecimal(number));
+  }
+
+  /**
+   * Converts the bytes produced by {@link #bigDecimalToBytes(String)} back to the string form of the big decimal.
+   *
+   * @param bytes 2 bytes of the scale followed by the bytes of the unscaled value, in BIG ENDIAN order
+   * @return the big decimal as rendered by {@link BigDecimal#toString()} (which may use scientific notation)
+   */
+  @ScalarFunction
+  public static String bytesToBigDecimal(byte[] bytes) {
+    return bytesToBigDecimalObject(bytes).toString();
+  }
+
+  /**
+   * Converts a hex string to a byte array. Both upper and lower case hex digits are accepted.
+   *
+   * @param hex a plain hex string with an even number of digits, e.g. 'f0e1a3b2'
+   * @return byte array representation of the hex string
+   * @throws IllegalArgumentException if the string has an odd length or contains a non-hex character
+   */
+  @ScalarFunction
+  public static byte[] hexToBytes(String hex) {
+    int len = hex.length();
+    if (len % 2 != 0) {
+      throw new IllegalArgumentException("Hex string must have an even number of digits, got length: " + len);
+    }
+    byte[] data = new byte[len / 2];
+    for (int i = 0; i < len; i += 2) {
+      data[i / 2] = (byte) ((hexDigit(hex, i) << 4) | hexDigit(hex, i + 1));
+    }
+    return data;
+  }
+
+  /**
+   * Converts a byte array to a lower case hex string; the inverse of {@link #hexToBytes(String)}.
+   *
+   * @param bytes any byte array
+   * @return plain hex string with two digits per byte and no separators, e.g. 'f012be3c'
+   */
+  @ScalarFunction
+  public static String bytesToHex(byte[] bytes) {
+    char[] hexChars = new char[bytes.length * 2];
+    for (int i = 0; i < bytes.length; i++) {
+      int value = bytes[i] & 0xFF;
+      hexChars[2 * i] = Character.forDigit(value >>> 4, 16);
+      hexChars[2 * i + 1] = Character.forDigit(value & 0x0F, 16);
+    }
+    return new String(hexChars);
+  }
+
+  /**
+   * Converts the bytes produced by {@link #bigDecimalToBytes(BigDecimal)} back to a big decimal.
+   *
+   * @param bytes 2 bytes of the signed scale followed by the bytes of the unscaled value, in BIG ENDIAN order
+   * @return big decimal object
+   * @throws IllegalArgumentException if fewer than 3 bytes are provided
+   */
+  public static BigDecimal bytesToBigDecimalObject(byte[] bytes) {
+    if (bytes.length < 3) {
+      throw new IllegalArgumentException(
+          "Big decimal bytes must hold 2 scale bytes and at least 1 value byte, got length: " + bytes.length);
+    }
+    // Mask the low byte so it is not sign-extended; the cast to short restores the sign of the 16-bit scale.
+    // Without this, scales >= 128 and negative scales (e.g. 1E+3) decode to the wrong value.
+    int scale = (short) ((bytes[0] << 8) | (bytes[1] & 0xFF));
+    byte[] unscaledBytes = new byte[bytes.length - 2];
+    System.arraycopy(bytes, 2, unscaledBytes, 0, unscaledBytes.length);
+    return new BigDecimal(new BigInteger(unscaledBytes), scale);
+  }
+
+  /**
+   * Converts a big decimal to bytes.
+   * Only scales that fit in a signed 16-bit integer ([-32768, 32767]) are supported.
+   *
+   * @param bigDecimal big decimal number object
+   * @return 2 bytes of the scale followed by the two's-complement bytes of the unscaled value, in BIG ENDIAN order
+   * @throws IllegalArgumentException if the scale does not fit in 2 bytes
+   */
+  public static byte[] bigDecimalToBytes(BigDecimal bigDecimal) {
+    int scale = bigDecimal.scale();
+    if (scale < Short.MIN_VALUE || scale > Short.MAX_VALUE) {
+      // Writing only the low 16 bits would silently corrupt the value on decode
+      throw new IllegalArgumentException("Big decimal scale must fit in 2 bytes, got: " + scale);
+    }
+    byte[] unscaledBytes = bigDecimal.unscaledValue().toByteArray();
+    byte[] result = new byte[unscaledBytes.length + 2];
+    result[0] = (byte) (scale >>> 8);
+    result[1] = (byte) scale;
+    System.arraycopy(unscaledBytes, 0, result, 2, unscaledBytes.length);
+    return result;
+  }
+
+  /**
+   * Encodes a byte array to base64 using {@link Base64}.
+   *
+   * @param input original byte array
+   * @return base64 encoded byte array (ASCII characters)
+   */
+  @ScalarFunction
+  public static byte[] base64Encode(byte[] input) {
+    return Base64.getEncoder().encode(input);
+  }
+
+  /**
+   * Decodes a base64 encoded string to bytes using {@link Base64}.
+   *
+   * @param input base64 encoded string
+   * @return decoded byte array
+   * @throws IllegalArgumentException if {@code input} is not valid base64
+   */
+  @ScalarFunction
+  public static byte[] base64Decode(String input) {
+    return Base64.getDecoder().decode(input);
+  }
+
+  /**
+   * Returns the value of the hex digit at {@code index}, rejecting non-hex characters instead of using -1.
+   */
+  private static int hexDigit(String hex, int index) {
+    int digit = Character.digit(hex.charAt(index), 16);
+    if (digit < 0) {
+      throw new IllegalArgumentException("Invalid hex character '" + hex.charAt(index) + "' at index " + index);
+    }
+    return digit;
+  }
+}
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/common/ObjectSerDeUtils.java
b/pinot-core/src/main/java/org/apache/pinot/core/common/ObjectSerDeUtils.java
index 45ed146..0367f29 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/common/ObjectSerDeUtils.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/common/ObjectSerDeUtils.java
@@ -40,6 +40,7 @@ import it.unimi.dsi.fastutil.objects.ObjectSet;
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
+import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
@@ -49,6 +50,7 @@ import java.util.Map;
import java.util.Set;
import org.apache.datasketches.memory.Memory;
import org.apache.datasketches.theta.Sketch;
+import org.apache.pinot.common.function.scalar.DataTypeConversionFunctions;
import org.apache.pinot.common.utils.StringUtil;
import org.apache.pinot.core.geospatial.serde.GeometrySerializer;
import org.apache.pinot.core.query.aggregation.function.customobject.AvgPair;
@@ -94,8 +96,8 @@ public class ObjectSerDeUtils {
StringSet(18),
BytesSet(19),
IdSet(20),
- List(21);
-
+ List(21),
+ BigDecimal(22);
private final int _value;
ObjectType(int value) {
@@ -113,6 +115,8 @@ public class ObjectSerDeUtils {
return ObjectType.Long;
} else if (value instanceof Double) {
return ObjectType.Double;
+ } else if (value instanceof BigDecimal) {
+ return ObjectType.BigDecimal;
} else if (value instanceof DoubleArrayList) {
return ObjectType.DoubleArrayList;
} else if (value instanceof AvgPair) {
@@ -850,6 +854,26 @@ public class ObjectSerDeUtils {
}
};
+  /**
+   * Ser/de for {@link BigDecimal} values, using the scale + unscaled-value encoding of
+   * {@link DataTypeConversionFunctions#bigDecimalToBytes(BigDecimal)}.
+   */
+  public static final ObjectSerDe<BigDecimal> BIGDECIMAL_SER_DE = new ObjectSerDe<BigDecimal>() {
+
+    @Override
+    public byte[] serialize(BigDecimal value) {
+      // Encode directly: formatting to a string and re-parsing yields the same value and scale, only slower
+      return DataTypeConversionFunctions.bigDecimalToBytes(value);
+    }
+
+    @Override
+    public BigDecimal deserialize(byte[] bytes) {
+      return DataTypeConversionFunctions.bytesToBigDecimalObject(bytes);
+    }
+
+    @Override
+    public BigDecimal deserialize(ByteBuffer byteBuffer) {
+      // Consumes all remaining bytes of the buffer
+      byte[] bytes = new byte[byteBuffer.remaining()];
+      byteBuffer.get(bytes);
+      return deserialize(bytes);
+    }
+  };
+
// NOTE: DO NOT change the order, it has to be the same order as the
ObjectType
//@formatter:off
private static final ObjectSerDe[] SER_DES = {
@@ -874,7 +898,8 @@ public class ObjectSerDeUtils {
STRING_SET_SER_DE,
BYTES_SET_SER_DE,
ID_SET_SER_DE,
- LIST_SER_DE
+ LIST_SER_DE,
+ BIGDECIMAL_SER_DE
};
//@formatter:on
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
index d74db3e..9093a8c 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
@@ -117,6 +117,8 @@ public class AggregationFunctionFactory {
return new MaxAggregationFunction(firstArgument);
case SUM:
return new SumAggregationFunction(firstArgument);
+ case SUMPRECISION:
+ return new SumPrecisionAggregationFunction(arguments);
case AVG:
return new AvgAggregationFunction(firstArgument);
case MINMAXRANGE:
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumPrecisionAggregationFunction.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumPrecisionAggregationFunction.java
new file mode 100644
index 0000000..8f4897d
--- /dev/null
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumPrecisionAggregationFunction.java
@@ -0,0 +1,180 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.aggregation.function;
+
+import java.math.BigDecimal;
+import java.math.MathContext;
+import java.math.RoundingMode;
+import java.util.List;
+import java.util.Map;
+import org.apache.pinot.common.function.AggregationFunctionType;
+import org.apache.pinot.common.function.scalar.DataTypeConversionFunctions;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.core.common.BlockValSet;
+import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
+import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
+import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
+import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
+import org.apache.pinot.core.query.request.context.ExpressionContext;
+
+
+/**
+ * Aggregation function that sums big decimal values exactly, optionally applying a precision and a scale to the
+ * final result.
+ * The function can be used as SUMPRECISION(column), SUMPRECISION(column, 10) or SUMPRECISION(column, 10, 2).
+ * Following arguments are supported:
+ * <ul>
+ *   <li>bytes column - a column which contains big decimal values serialized as bytes (see
+ *   {@link DataTypeConversionFunctions#bigDecimalToBytes(BigDecimal)})</li>
+ *   <li>precision (optional) - number of significant digits of the final result; 0 (the default) means
+ *   unlimited</li>
+ *   <li>scale (optional, requires precision) - scale of the final result, rounded HALF_EVEN</li>
+ * </ul>
+ * Intermediate (per-segment and merged) sums are exact; precision and scale are applied only in
+ * {@link #extractFinalResult(BigDecimal)}.
+ */
+public class SumPrecisionAggregationFunction extends BaseSingleInputAggregationFunction<BigDecimal, BigDecimal> {
+  // Rounding applied to the final result; unlimited precision when no precision argument is given
+  private final MathContext _mathContext;
+  // Scale applied to the final result, or null to keep the scale produced by the sum
+  private final Integer _scale;
+
+  /**
+   * @param arguments the bytes column, optionally followed by literal precision and scale arguments
+   * @throws IllegalArgumentException if more than 3 arguments are given, or precision is negative
+   * @throws NumberFormatException if precision or scale is not an integer literal
+   */
+  public SumPrecisionAggregationFunction(List<ExpressionContext> arguments) {
+    super(arguments.get(0));
+    int numArguments = arguments.size();
+    if (numArguments > 3) {
+      throw new IllegalArgumentException(
+          "SUMPRECISION expects at most 3 arguments (column, precision, scale), got: " + numArguments);
+    }
+    _mathContext =
+        numArguments >= 2 ? new MathContext(Integer.parseInt(arguments.get(1).getLiteral())) : MathContext.UNLIMITED;
+    _scale = numArguments == 3 ? Integer.valueOf(arguments.get(2).getLiteral()) : null;
+  }
+
+  @Override
+  public AggregationFunctionType getType() {
+    return AggregationFunctionType.SUMPRECISION;
+  }
+
+  @Override
+  public AggregationResultHolder createAggregationResultHolder() {
+    return new ObjectAggregationResultHolder();
+  }
+
+  @Override
+  public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int maxCapacity) {
+    return new ObjectGroupByResultHolder(initialCapacity, maxCapacity);
+  }
+
+  @Override
+  public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+      Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    byte[][] valueArray = blockValSetMap.get(_expression).getBytesValuesSV();
+    BigDecimal sumValue = getDefaultResult(aggregationResultHolder);
+    for (int i = 0; i < length; i++) {
+      sumValue = sumValue.add(DataTypeConversionFunctions.bytesToBigDecimalObject(valueArray[i]));
+    }
+    aggregationResultHolder.setValue(sumValue);
+  }
+
+  @Override
+  public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
+      Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    byte[][] valueArray = blockValSetMap.get(_expression).getBytesValuesSV();
+    for (int i = 0; i < length; i++) {
+      int groupKey = groupKeyArray[i];
+      BigDecimal value = DataTypeConversionFunctions.bytesToBigDecimalObject(valueArray[i]);
+      groupByResultHolder.setValueForKey(groupKey, getDefaultResult(groupByResultHolder, groupKey).add(value));
+    }
+  }
+
+  @Override
+  public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
+      Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    byte[][] valueArray = blockValSetMap.get(_expression).getBytesValuesSV();
+    for (int i = 0; i < length; i++) {
+      // Decode once per row rather than once per group key
+      BigDecimal value = DataTypeConversionFunctions.bytesToBigDecimalObject(valueArray[i]);
+      for (int groupKey : groupKeysArray[i]) {
+        groupByResultHolder.setValueForKey(groupKey, getDefaultResult(groupByResultHolder, groupKey).add(value));
+      }
+    }
+  }
+
+  @Override
+  public BigDecimal extractAggregationResult(AggregationResultHolder aggregationResultHolder) {
+    return getDefaultResult(aggregationResultHolder);
+  }
+
+  @Override
+  public BigDecimal extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) {
+    return getDefaultResult(groupByResultHolder, groupKey);
+  }
+
+  /**
+   * Merges two exact intermediate sums; no precision or scale is applied here.
+   */
+  @Override
+  public BigDecimal merge(BigDecimal intermediateResult1, BigDecimal intermediateResult2) {
+    return intermediateResult1.add(intermediateResult2);
+  }
+
+  @Override
+  public boolean isIntermediateResultComparable() {
+    return true;
+  }
+
+  @Override
+  public DataSchema.ColumnDataType getIntermediateResultColumnType() {
+    return DataSchema.ColumnDataType.OBJECT;
+  }
+
+  @Override
+  public DataSchema.ColumnDataType getFinalResultColumnType() {
+    return DataSchema.ColumnDataType.STRING;
+  }
+
+  /**
+   * Rounds the exact sum to the configured precision, then applies the configured scale (if any).
+   */
+  @Override
+  public BigDecimal extractFinalResult(BigDecimal intermediateResult) {
+    // round(MathContext) is equivalent to re-parsing the string form with the MathContext, minus the round-trip
+    return setScale(intermediateResult.round(_mathContext));
+  }
+
+  /**
+   * Returns the current sum in the holder, initializing it to zero if nothing has been aggregated yet.
+   */
+  public BigDecimal getDefaultResult(AggregationResultHolder aggregationResultHolder) {
+    BigDecimal result = aggregationResultHolder.getResult();
+    if (result == null) {
+      result = BigDecimal.ZERO;
+      aggregationResultHolder.setValue(result);
+    }
+    return result;
+  }
+
+  /**
+   * Returns the current sum for {@code groupKey}, initializing it to zero if nothing has been aggregated yet.
+   */
+  public BigDecimal getDefaultResult(GroupByResultHolder groupByResultHolder, int groupKey) {
+    BigDecimal result = groupByResultHolder.getResult(groupKey);
+    if (result == null) {
+      result = BigDecimal.ZERO;
+      groupByResultHolder.setValueForKey(groupKey, result);
+    }
+    return result;
+  }
+
+  /**
+   * Applies the configured scale with HALF_EVEN rounding, or returns the value unchanged when no scale is set.
+   */
+  private BigDecimal setScale(BigDecimal value) {
+    return _scale != null ? value.setScale(_scale, RoundingMode.HALF_EVEN) : value;
+  }
+}
diff --git
a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactoryTest.java
b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactoryTest.java
index 476c086..f286c0b 100644
---
a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactoryTest.java
+++
b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactoryTest.java
@@ -65,6 +65,13 @@ public class AggregationFunctionFactoryTest {
assertEquals(aggregationFunction.getColumnName(), "sum_column");
assertEquals(aggregationFunction.getResultColumnName(),
function.toString());
+ function = getFunction("SuMPreCIsiON");
+ aggregationFunction =
AggregationFunctionFactory.getAggregationFunction(function,
DUMMY_QUERY_CONTEXT);
+ assertTrue(aggregationFunction instanceof SumPrecisionAggregationFunction);
+ assertEquals(aggregationFunction.getType(),
AggregationFunctionType.SUMPRECISION);
+ assertEquals(aggregationFunction.getColumnName(), "sumPrecision_column");
+ assertEquals(aggregationFunction.getResultColumnName(),
function.toString());
+
function = getFunction("AvG");
aggregationFunction =
AggregationFunctionFactory.getAggregationFunction(function,
DUMMY_QUERY_CONTEXT);
assertTrue(aggregationFunction instanceof AvgAggregationFunction);
diff --git
a/pinot-core/src/test/java/org/apache/pinot/queries/SumWithPrecisionTest.java
b/pinot-core/src/test/java/org/apache/pinot/queries/SumWithPrecisionTest.java
new file mode 100644
index 0000000..8754bc4
--- /dev/null
+++
b/pinot-core/src/test/java/org/apache/pinot/queries/SumWithPrecisionTest.java
@@ -0,0 +1,221 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.queries;
+
+import java.io.File;
+import java.io.IOException;
+import java.math.BigDecimal;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Random;
+import org.apache.commons.io.FileUtils;
+import org.apache.pinot.common.function.scalar.DataTypeConversionFunctions;
+import org.apache.pinot.common.segment.ReadMode;
+import org.apache.pinot.core.common.Operator;
+import org.apache.pinot.core.data.readers.GenericRowRecordReader;
+import org.apache.pinot.core.indexsegment.IndexSegment;
+import org.apache.pinot.core.indexsegment.generator.SegmentGeneratorConfig;
+import org.apache.pinot.core.indexsegment.immutable.ImmutableSegment;
+import org.apache.pinot.core.indexsegment.immutable.ImmutableSegmentLoader;
+import org.apache.pinot.core.operator.blocks.IntermediateResultsBlock;
+import org.apache.pinot.core.operator.query.AggregationOperator;
+import
org.apache.pinot.core.segment.creator.impl.SegmentIndexCreationDriverImpl;
+import org.apache.pinot.spi.config.table.TableConfig;
+import org.apache.pinot.spi.config.table.TableType;
+import org.apache.pinot.spi.data.FieldSpec;
+import org.apache.pinot.spi.data.Schema;
+import org.apache.pinot.spi.data.readers.GenericRow;
+import org.apache.pinot.spi.utils.builder.TableConfigBuilder;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertNotNull;
+import static org.testng.Assert.assertTrue;
+
+
+/**
+ * Query tests for the SUMPRECISION aggregation function over a BYTES column whose values are big decimals
+ * encoded with DataTypeConversionFunctions.bigDecimalToBytes(String).
+ */
+public class SumWithPrecisionTest extends BaseSingleValueQueriesTest {
+ private static final File INDEX_DIR = new File(FileUtils.getTempDirectory(),
"SumWithPrecisionTest");
+ private static final String RAW_TABLE_NAME = "testTable";
+ private static final String SEGMENT_NAME = "testSegment";
+ private static final Random RANDOM = new Random();
+
+ private static final int NUM_RECORDS = 2000;
+ private static final int MAX_VALUE = 1000;
+
+ private static final String INT_COLUMN = "intColumn";
+ private static final String LONG_COLUMN = "longColumn";
+ private static final String FLOAT_COLUMN = "floatColumn";
+ private static final String DOUBLE_COLUMN = "doubleColumn";
+ private static final String STRING_COLUMN = "stringColumn";
+ private static final String BYTES_COLUMN = "bytesColumn";
+ private static final Schema SCHEMA =
+ new Schema.SchemaBuilder().addSingleValueDimension(INT_COLUMN,
FieldSpec.DataType.INT)
+ .addSingleValueDimension(LONG_COLUMN, FieldSpec.DataType.LONG)
+ .addSingleValueDimension(FLOAT_COLUMN, FieldSpec.DataType.FLOAT)
+ .addSingleValueDimension(DOUBLE_COLUMN, FieldSpec.DataType.DOUBLE)
+ .addSingleValueDimension(STRING_COLUMN, FieldSpec.DataType.STRING)
+ .addSingleValueDimension(BYTES_COLUMN,
FieldSpec.DataType.BYTES).build();
+ private static final TableConfig TABLE_CONFIG =
+ new
TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME).build();
+
+ // Exact sum of all big decimal values written to the segment, computed in setUp()
+ private BigDecimal _aggregatedValuePerSegment;
+ private IndexSegment _indexSegment;
+ private List<IndexSegment> _indexSegments;
+
+ @Override
+ protected String getFilter() {
+ // NOTE: intColumn values are generated in [0, MAX_VALUE), so this filter matches every record; the filtered
+ // query must therefore produce the same result as the unfiltered one.
+ return " WHERE intColumn >= 0";
+ }
+
+ @Override
+ protected IndexSegment getIndexSegment() {
+ return _indexSegment;
+ }
+
+ @Override
+ protected List<IndexSegment> getIndexSegments() {
+ return _indexSegments;
+ }
+
+ /**
+ * Builds one segment of NUM_RECORDS random rows and records the exact sum of the BYTES column values.
+ * The segment list reuses the same segment twice.
+ */
+ @BeforeClass
+ public void setUp()
+ throws Exception {
+ FileUtils.deleteDirectory(INDEX_DIR);
+ _aggregatedValuePerSegment = new BigDecimal(0);
+ List<GenericRow> records = new ArrayList<>(NUM_RECORDS);
+ for (int i = 0; i < NUM_RECORDS; i++) {
+ int value = RANDOM.nextInt(MAX_VALUE);
+ GenericRow record = new GenericRow();
+ record.putValue(INT_COLUMN, value);
+ record.putValue(LONG_COLUMN, (long) value);
+ record.putValue(FLOAT_COLUMN, (float) value);
+ record.putValue(DOUBLE_COLUMN, (double) value);
+ String stringValue = Integer.toString(value);
+ record.putValue(STRING_COLUMN, stringValue);
+ // Product of two random doubles gives a big decimal with many fractional digits (variable-length bytes)
+ BigDecimal bigDecimalValueLeft = new
BigDecimal(Double.toString(RANDOM.nextDouble()));
+ BigDecimal bigDecimalValueRight = new
BigDecimal(Double.toString(RANDOM.nextDouble()));
+ BigDecimal bigDecimalValue =
bigDecimalValueLeft.multiply(bigDecimalValueRight);
+
+ _aggregatedValuePerSegment =
_aggregatedValuePerSegment.add(bigDecimalValue);
+ byte[] bytesValue =
DataTypeConversionFunctions.bigDecimalToBytes(bigDecimalValue.toString());
+ record.putValue(BYTES_COLUMN, bytesValue);
+ records.add(record);
+ }
+
+ SegmentGeneratorConfig segmentGeneratorConfig = new
SegmentGeneratorConfig(TABLE_CONFIG, SCHEMA);
+ segmentGeneratorConfig.setTableName(RAW_TABLE_NAME);
+ segmentGeneratorConfig.setSegmentName(SEGMENT_NAME);
+ segmentGeneratorConfig.setOutDir(INDEX_DIR.getPath());
+
+ SegmentIndexCreationDriverImpl driver = new
SegmentIndexCreationDriverImpl();
+ driver.init(segmentGeneratorConfig, new GenericRowRecordReader(records));
+ driver.build();
+
+ ImmutableSegment immutableSegment = ImmutableSegmentLoader.load(new
File(INDEX_DIR, SEGMENT_NAME), ReadMode.mmap);
+ _indexSegment = immutableSegment;
+ _indexSegments = Arrays.asList(immutableSegment, immutableSegment);
+ }
+
+ /**
+ * SUMPRECISION without precision/scale: the inner-segment result must equal the exact sum.
+ */
+ @Test
+ public void testAggregationOnly() {
+ String query = "SELECT SUMPRECISION(bytesColumn) FROM testTable";
+
+ // Inner segment
+ Operator operator = getOperatorForPqlQuery(query);
+ assertTrue(operator instanceof AggregationOperator);
+ IntermediateResultsBlock resultsBlock = ((AggregationOperator)
operator).nextBlock();
+
QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(),
NUM_RECORDS, 0, NUM_RECORDS,
+ NUM_RECORDS);
+ List<Object> aggregationResult = resultsBlock.getAggregationResult();
+
+ operator = getOperatorForPqlQueryWithFilter(query);
+ assertTrue(operator instanceof AggregationOperator);
+ IntermediateResultsBlock resultsBlockWithFilter = ((AggregationOperator)
operator).nextBlock();
+
QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(),
NUM_RECORDS, 0, NUM_RECORDS,
+ NUM_RECORDS);
+ List<Object> aggregationResultWithFilter =
resultsBlockWithFilter.getAggregationResult();
+
+ assertNotNull(aggregationResult);
+ assertNotNull(aggregationResultWithFilter);
+ assertEquals(aggregationResult, aggregationResultWithFilter);
+ // BigDecimal.equals is scale-sensitive, so this also checks the sum's scale
+ assertEquals(aggregationResult.get(0), _aggregatedValuePerSegment);
+ }
+
+ /**
+ * SUMPRECISION with a precision argument.
+ */
+ @Test
+ public void testAggregationWithPrecision() {
+ String query = "SELECT SUMPRECISION(bytesColumn, 6) FROM testTable";
+
+ // Inner segment
+ Operator operator = getOperatorForPqlQuery(query);
+ assertTrue(operator instanceof AggregationOperator);
+ IntermediateResultsBlock resultsBlock = ((AggregationOperator)
operator).nextBlock();
+
QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(),
NUM_RECORDS, 0, NUM_RECORDS,
+ NUM_RECORDS);
+ List<Object> aggregationResult = resultsBlock.getAggregationResult();
+
+ operator = getOperatorForPqlQueryWithFilter(query);
+ assertTrue(operator instanceof AggregationOperator);
+ IntermediateResultsBlock resultsBlockWithFilter = ((AggregationOperator)
operator).nextBlock();
+
QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(),
NUM_RECORDS, 0, NUM_RECORDS,
+ NUM_RECORDS);
+ List<Object> aggregationResultWithFilter =
resultsBlockWithFilter.getAggregationResult();
+
+ assertNotNull(aggregationResult);
+ assertNotNull(aggregationResultWithFilter);
+ assertEquals(aggregationResult, aggregationResultWithFilter);
+ // NOTE(review): this looks like the per-segment intermediate result, while precision is only applied when the
+ // final result is extracted — so this assertion may not exercise the precision argument at all; confirm.
+ assertTrue(_aggregatedValuePerSegment.subtract((BigDecimal)
aggregationResult.get(0)).abs().doubleValue() <= 0.1);
+ }
+
+ /**
+ * SUMPRECISION with both precision and scale arguments.
+ */
+ @Test
+ public void testAggregationWithPrecisionAndScale() {
+ String query = "SELECT SUMPRECISION(bytesColumn, 10, 3) FROM testTable";
+
+ // Inner segment
+ Operator operator = getOperatorForPqlQuery(query);
+ assertTrue(operator instanceof AggregationOperator);
+ IntermediateResultsBlock resultsBlock = ((AggregationOperator)
operator).nextBlock();
+
QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(),
NUM_RECORDS, 0, NUM_RECORDS,
+ NUM_RECORDS);
+ List<Object> aggregationResult = resultsBlock.getAggregationResult();
+
+ operator = getOperatorForPqlQueryWithFilter(query);
+ assertTrue(operator instanceof AggregationOperator);
+ IntermediateResultsBlock resultsBlockWithFilter = ((AggregationOperator)
operator).nextBlock();
+
QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(),
NUM_RECORDS, 0, NUM_RECORDS,
+ NUM_RECORDS);
+ List<Object> aggregationResultWithFilter =
resultsBlockWithFilter.getAggregationResult();
+
+ assertNotNull(aggregationResult);
+ assertNotNull(aggregationResultWithFilter);
+ assertEquals(aggregationResult, aggregationResultWithFilter);
+ // NOTE(review): as above, precision/scale appear to be applied only to the final result, not to this
+ // intermediate one; consider asserting on the final (broker-level) result instead — confirm.
+ assertTrue(_aggregatedValuePerSegment.subtract((BigDecimal)
aggregationResult.get(0)).abs().doubleValue() <= 0.1);
+ }
+
+ @AfterClass
+ public void tearDown()
+ throws IOException {
+ _indexSegment.destroy();
+ FileUtils.deleteDirectory(INDEX_DIR);
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]