This is an automated email from the ASF dual-hosted git repository. jackie pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push: new b71f38a Infer the data type for LiteralTransformFunction (#7332) b71f38a is described below commit b71f38a2b2ed7ff715b03f4b780105653f726969 Author: Xiaotian (Jackie) Jiang <17555551+jackie-ji...@users.noreply.github.com> AuthorDate: Fri Aug 20 16:30:54 2021 -0700 Infer the data type for LiteralTransformFunction (#7332) - Infer the data type for `LiteralTransformFunction` so that the parent transform function can read the correct data type - Fix the data type handling for `CaseTransformFunction` and use `STRING` type to handle non-numeric types - For `BinaryOperatorTransformFunction`, use the lhs expression as the main data type, and match rhs data type with lhs, e.g. - stringCol > 123 will use string comparison - intCol > '123' will use integer comparison --- .../function/BinaryOperatorTransformFunction.java | 65 ++---------- .../transform/function/CaseTransformFunction.java | 118 ++++----------------- .../function/LiteralTransformFunction.java | 68 +++++++++--- .../function/CaseTransformFunctionTest.java | 75 ++++++++----- .../function/LiteralTransformFunctionTest.java | 20 ++-- 5 files changed, 135 insertions(+), 211 deletions(-) diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/BinaryOperatorTransformFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/BinaryOperatorTransformFunction.java index df21c26..a610ace 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/BinaryOperatorTransformFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/BinaryOperatorTransformFunction.java @@ -44,8 +44,8 @@ public abstract class BinaryOperatorTransformFunction extends BaseTransformFunct @Override public void init(List<TransformFunction> arguments, Map<String, DataSource> dataSourceMap) { // Check that there are exact 2 arguments - Preconditions - .checkArgument(arguments.size() == 2, "Exact 2 arguments are required for binary operator transform function"); + Preconditions.checkArgument(arguments.size() == 2, + "Exact 2 arguments are required for binary operator transform function"); _leftTransformFunction = arguments.get(0); _rightTransformFunction = arguments.get(1); _leftStoredType = _leftTransformFunction.getResultMetadata().getDataType().getStoredType(); @@ -93,7 +93,7 @@ public abstract class BinaryOperatorTransformFunction extends BaseTransformFunct case FLOAT: float[] rightFloatValues = _rightTransformFunction.transformToFloatValuesSV(projectionBlock); for (int i = 0; i < length; i++) { - _results[i] = getIntResult(Float.compare(leftIntValues[i], rightFloatValues[i])); + _results[i] = getIntResult(Double.compare(leftIntValues[i], rightFloatValues[i])); } break; case DOUBLE: @@ -261,62 +261,9 @@ public abstract class BinaryOperatorTransformFunction extends BaseTransformFunct break; case STRING: String[] leftStringValues = _leftTransformFunction.transformToStringValuesSV(projectionBlock); - switch (_rightStoredType) { - case INT: - int[] rightIntValues = _rightTransformFunction.transformToIntValuesSV(projectionBlock); - for (int i = 0; i < length; i++) { - try { - _results[i] = - getIntResult(new BigDecimal(leftStringValues[i]).compareTo(BigDecimal.valueOf(rightIntValues[i]))); - } catch (NumberFormatException e) { - _results[i] = 0; - } - } - break; - case LONG: - long[] rightLongValues = _rightTransformFunction.transformToLongValuesSV(projectionBlock); - for (int i = 0; i < length; i++) { - try { - _results[i] = - getIntResult(new BigDecimal(leftStringValues[i]).compareTo(BigDecimal.valueOf(rightLongValues[i]))); - } catch (NumberFormatException e) { - _results[i] = 0; - } - } - break; - case FLOAT: - float[] rightFloatValues = _rightTransformFunction.transformToFloatValuesSV(projectionBlock); - for (int i = 0; i < length; i++) { - try { - _results[i] = getIntResult( - new BigDecimal(leftStringValues[i]).compareTo(BigDecimal.valueOf(rightFloatValues[i]))); - } catch (NumberFormatException e) { - _results[i] = 0; - } - } - break; - case DOUBLE: - double[] rightDoubleValues = _rightTransformFunction.transformToDoubleValuesSV(projectionBlock); - for (int i = 0; i < length; i++) { - try { - _results[i] = getIntResult( - new BigDecimal(leftStringValues[i]).compareTo(BigDecimal.valueOf(rightDoubleValues[i]))); - } catch (NumberFormatException e) { - _results[i] = 0; - } - } - break; - case STRING: - String[] rightStringValues = _rightTransformFunction.transformToStringValuesSV(projectionBlock); - for (int i = 0; i < length; i++) { - _results[i] = getIntResult(leftStringValues[i].compareTo(rightStringValues[i])); - } - break; - default: - throw new IllegalStateException(String.format( - "Unsupported data type for comparison: [Left Transform Function [%s] result type is [%s], Right Transform Function [%s] result type is [%s]]", - _leftTransformFunction.getName(), _leftStoredType, _rightTransformFunction.getName(), - _rightStoredType)); + String[] rightStringValues = _rightTransformFunction.transformToStringValuesSV(projectionBlock); + for (int i = 0; i < length; i++) { + _results[i] = getIntResult(leftStringValues[i].compareTo(rightStringValues[i])); } break; case BYTES: diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunction.java index e04db32..7de3a15 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunction.java @@ -104,147 +104,71 @@ public class CaseTransformFunction extends BaseTransformFunction { Preconditions.checkState(thenStatementResultMetadata.isSingleValue(), String.format("Unsupported multi-value expression in the THEN clause of index: %d", i)); DataType thenStatementDataType = thenStatementResultMetadata.getDataType(); + + // Upcast the data type to cover all the data types in THEN and ELSE clauses if they don't match + // For numeric types: + // - INT & LONG -> LONG + // - INT & FLOAT/DOUBLE -> DOUBLE + // - LONG & FLOAT/DOUBLE -> DOUBLE (might lose precision) + // - FLOAT & DOUBLE -> DOUBLE + // Use STRING to handle non-numeric types + if (thenStatementDataType == dataType) { + continue; + } switch (dataType) { case INT: - if (thenStatement instanceof LiteralTransformFunction) { - dataType = LiteralTransformFunction.inferLiteralDataType((LiteralTransformFunction) thenStatement); - break; - } switch (thenStatementDataType) { - case INT: case LONG: + dataType = DataType.LONG; + break; case FLOAT: case DOUBLE: - case STRING: - dataType = thenStatementDataType; + dataType = DataType.DOUBLE; break; default: - throw new IllegalStateException(String - .format("Incompatible expression type: %s in the THEN clause of index: %d, main type: %s", - thenStatementDataType, i, dataType)); + dataType = DataType.STRING; + break; } break; case LONG: - if (thenStatement instanceof LiteralTransformFunction) { - DataType literalDataType = - LiteralTransformFunction.inferLiteralDataType((LiteralTransformFunction) thenStatement); - switch (literalDataType) { - case INT: - case LONG: - break; - case FLOAT: - case DOUBLE: - dataType = DataType.DOUBLE; - break; - default: - dataType = literalDataType; - } - break; - } switch (thenStatementDataType) { case INT: - case LONG: break; case FLOAT: case DOUBLE: dataType = DataType.DOUBLE; break; - case STRING: + default: dataType = DataType.STRING; break; - default: - throw new IllegalStateException(String - .format("Incompatible expression type: %s in the THEN clause of index: %d, main type: %s", - thenStatementDataType, i, dataType)); } break; case FLOAT: - if (thenStatement instanceof LiteralTransformFunction) { - DataType literalDataType = - LiteralTransformFunction.inferLiteralDataType((LiteralTransformFunction) thenStatement); - switch (literalDataType) { - case INT: - case FLOAT: - break; - case LONG: - case DOUBLE: - dataType = DataType.DOUBLE; - break; - default: - dataType = literalDataType; - } - break; - } switch (thenStatementDataType) { case INT: - case FLOAT: - break; case LONG: case DOUBLE: dataType = DataType.DOUBLE; break; - case STRING: + default: dataType = DataType.STRING; break; - default: - throw new IllegalStateException(String - .format("Incompatible expression type: %s in the THEN clause of index: %d, main type: %s", - thenStatementDataType, i, dataType)); } break; case DOUBLE: - if (thenStatement instanceof LiteralTransformFunction) { - DataType literalDataType = - LiteralTransformFunction.inferLiteralDataType((LiteralTransformFunction) thenStatement); - switch (literalDataType) { - case INT: - case LONG: - case FLOAT: - case DOUBLE: - break; - default: - dataType = literalDataType; - } - break; - } switch (thenStatementDataType) { case INT: case FLOAT: case LONG: - case DOUBLE: - break; - case STRING: - dataType = thenStatementDataType; break; default: - throw new IllegalStateException(String - .format("Incompatible expression type: %s in the THEN clause of index: %d, main type: %s", - thenStatementDataType, i, dataType)); - } - break; - case STRING: - if (thenStatement instanceof LiteralTransformFunction) { - break; - } - switch (thenStatementDataType) { - case INT: - case FLOAT: - case LONG: - case DOUBLE: - case STRING: + dataType = DataType.STRING; break; - default: - throw new IllegalStateException(String - .format("Incompatible expression type: %s in the THEN clause of index: %d, main type: %s", - thenStatementDataType, i, dataType)); } break; default: - if (thenStatementDataType != dataType) { - throw new IllegalStateException(String - .format("Incompatible expression type: %s in the THEN clause of index: %d, main type: %s", - thenStatementDataType, i, dataType)); - } + dataType = DataType.STRING; + break; } } return new TransformResultMetadata(dataType, true, false); diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/LiteralTransformFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/LiteralTransformFunction.java index fdf6580..dd4606a 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/LiteralTransformFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/LiteralTransformFunction.java @@ -18,7 +18,9 @@ */ package org.apache.pinot.core.operator.transform.function; +import com.google.common.annotations.VisibleForTesting; import java.math.BigDecimal; +import java.sql.Timestamp; import java.util.Arrays; import java.util.List; import java.util.Map; @@ -28,16 +30,19 @@ import org.apache.pinot.core.operator.transform.TransformResultMetadata; import org.apache.pinot.core.plan.DocIdSetPlanNode; import org.apache.pinot.segment.spi.datasource.DataSource; import org.apache.pinot.segment.spi.index.reader.Dictionary; -import org.apache.pinot.spi.data.FieldSpec; +import org.apache.pinot.spi.data.FieldSpec.DataType; import org.apache.pinot.spi.utils.BytesUtils; /** * The <code>LiteralTransformFunction</code> class is a special transform function which is a wrapper on top of a - * LITERAL, and only supports {@link #getLiteral()}. + * LITERAL. The data type is inferred from the literal string. + * TODO: Preserve the type of the literal instead of inferring the type from the string */ public class LiteralTransformFunction implements TransformFunction { private final String _literal; + private final DataType _dataType; + private int[] _intResult; private long[] _longResult; private float[] _floatResult; @@ -47,24 +52,45 @@ public class LiteralTransformFunction implements TransformFunction { public LiteralTransformFunction(String literal) { _literal = literal; + _dataType = inferLiteralDataType(literal); } - public static FieldSpec.DataType inferLiteralDataType(LiteralTransformFunction transformFunction) { - String literal = transformFunction.getLiteral(); + @VisibleForTesting + static DataType inferLiteralDataType(String literal) { + // Try to interpret the literal as number try { - Number literalNum = NumberUtils.createNumber(literal); - if (literalNum instanceof Integer) { - return FieldSpec.DataType.INT; - } else if (literalNum instanceof Long) { - return FieldSpec.DataType.LONG; - } else if (literalNum instanceof Float) { - return FieldSpec.DataType.FLOAT; - } else if (literalNum instanceof Double) { - return FieldSpec.DataType.DOUBLE; + Number number = NumberUtils.createNumber(literal); + if (number instanceof Integer) { + return DataType.INT; + } else if (number instanceof Long) { + return DataType.LONG; + } else if (number instanceof Float) { + return DataType.FLOAT; + } else if (number instanceof Double) { + return DataType.DOUBLE; + } else { + return DataType.STRING; } } catch (Exception e) { + // Ignored + } + + // Try to interpret the literal as BOOLEAN + // NOTE: Intentionally use equals() instead of equalsIgnoreCase() here because boolean literal will always be parsed + // into lowercase string. We don't want to parse string "TRUE" as boolean. + if (literal.equals("true") || literal.equals("false")) { + return DataType.BOOLEAN; + } + + // Try to interpret the literal as TIMESTAMP + try { + Timestamp.valueOf(literal); + return DataType.TIMESTAMP; + } catch (Exception e) { + // Ignored } - return FieldSpec.DataType.STRING; + + return DataType.STRING; } public String getLiteral() { @@ -82,7 +108,7 @@ public class LiteralTransformFunction implements TransformFunction { @Override public TransformResultMetadata getResultMetadata() { - return BaseTransformFunction.STRING_SV_NO_DICTIONARY_METADATA; + return new TransformResultMetadata(_dataType, true, false); } @Override @@ -104,7 +130,11 @@ public class LiteralTransformFunction implements TransformFunction { public int[] transformToIntValuesSV(ProjectionBlock projectionBlock) { if (_intResult == null) { _intResult = new int[DocIdSetPlanNode.MAX_DOC_PER_CALL]; - Arrays.fill(_intResult, Integer.parseInt(_literal)); + if (_dataType != DataType.BOOLEAN) { + Arrays.fill(_intResult, new BigDecimal(_literal).intValue()); + } else { + Arrays.fill(_intResult, _literal.equals("true") ? 1 : 0); + } } return _intResult; } @@ -113,7 +143,11 @@ public class LiteralTransformFunction implements TransformFunction { public long[] transformToLongValuesSV(ProjectionBlock projectionBlock) { if (_longResult == null) { _longResult = new long[DocIdSetPlanNode.MAX_DOC_PER_CALL]; - Arrays.fill(_longResult, new BigDecimal(_literal).longValue()); + if (_dataType != DataType.TIMESTAMP) { + Arrays.fill(_longResult, new BigDecimal(_literal).longValue()); + } else { + Arrays.fill(_longResult, Timestamp.valueOf(_literal).getTime()); + } } return _longResult; } diff --git a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunctionTest.java index 35b1011..33b0496 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunctionTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunctionTest.java @@ -18,22 +18,29 @@ */ package org.apache.pinot.core.operator.transform.function; +import java.util.Arrays; import java.util.Random; import org.apache.pinot.common.function.TransformFunctionType; import org.apache.pinot.common.request.context.ExpressionContext; import org.apache.pinot.common.request.context.RequestContextUtils; +import org.apache.pinot.spi.data.FieldSpec.DataType; import org.testng.Assert; import org.testng.annotations.Test; public class CaseTransformFunctionTest extends BaseTransformFunctionTest { - - private final int INDEX_TO_COMPARE = new Random(System.currentTimeMillis()).nextInt(NUM_ROWS); - private final TransformFunctionType[] BINARY_OPERATOR_TRANSFORM_FUNCTIONS = + private static final int INDEX_TO_COMPARE = new Random(System.currentTimeMillis()).nextInt(NUM_ROWS); + private static final TransformFunctionType[] BINARY_OPERATOR_TRANSFORM_FUNCTIONS = new TransformFunctionType[]{TransformFunctionType.EQUALS, TransformFunctionType.NOT_EQUALS, TransformFunctionType.GREATER_THAN, TransformFunctionType.GREATER_THAN_OR_EQUAL, TransformFunctionType.LESS_THAN, TransformFunctionType.LESS_THAN_OR_EQUAL}; @Test public void testCaseTransformFunctionWithIntResults() { + int[] expectedIntResults = new int[NUM_ROWS]; + Arrays.fill(expectedIntResults, 100); + testCaseQueryWithIntResults("true", expectedIntResults); + Arrays.fill(expectedIntResults, 10); + testCaseQueryWithIntResults("false", expectedIntResults); + for (TransformFunctionType functionType : BINARY_OPERATOR_TRANSFORM_FUNCTIONS) { testCaseQueryWithIntResults(String.format("%s(%s, %s)", functionType.getName(), INT_SV_COLUMN, String.format("%d", _intSVValues[INDEX_TO_COMPARE])), getExpectedIntResults(INT_SV_COLUMN, functionType)); @@ -42,50 +49,61 @@ public class CaseTransformFunctionTest extends BaseTransformFunctionTest { testCaseQueryWithIntResults(String.format("%s(%s, %s)", functionType.getName(), FLOAT_SV_COLUMN, String.format("%f", _floatSVValues[INDEX_TO_COMPARE])), getExpectedIntResults(FLOAT_SV_COLUMN, functionType)); testCaseQueryWithIntResults(String.format("%s(%s, %s)", functionType.getName(), DOUBLE_SV_COLUMN, - String.format("%.20f", _doubleSVValues[INDEX_TO_COMPARE])), + String.format("%.20f", _doubleSVValues[INDEX_TO_COMPARE])), getExpectedIntResults(DOUBLE_SV_COLUMN, functionType)); testCaseQueryWithIntResults(String.format("%s(%s, %s)", functionType.getName(), STRING_SV_COLUMN, - String.format("'%s'", _stringSVValues[INDEX_TO_COMPARE])), + String.format("'%s'", _stringSVValues[INDEX_TO_COMPARE])), getExpectedIntResults(STRING_SV_COLUMN, functionType)); } } @Test - public void testCaseTransformFunctionWithDoubleResults() { + public void testCaseTransformFunctionWithFloatResults() { + float[] expectedFloatResults = new float[NUM_ROWS]; + Arrays.fill(expectedFloatResults, 100); + testCaseQueryWithFloatResults("true", expectedFloatResults); + Arrays.fill(expectedFloatResults, 10); + testCaseQueryWithFloatResults("false", expectedFloatResults); + for (TransformFunctionType functionType : BINARY_OPERATOR_TRANSFORM_FUNCTIONS) { - testCaseQueryWithDoubleResults(String.format("%s(%s, %s)", functionType.getName(), INT_SV_COLUMN, - String.format("%d", _intSVValues[INDEX_TO_COMPARE])), getExpectedDoubleResults(INT_SV_COLUMN, functionType)); - testCaseQueryWithDoubleResults(String.format("%s(%s, %s)", functionType.getName(), LONG_SV_COLUMN, - String.format("%d", _longSVValues[INDEX_TO_COMPARE])), - getExpectedDoubleResults(LONG_SV_COLUMN, functionType)); - testCaseQueryWithDoubleResults(String.format("%s(%s, %s)", functionType.getName(), FLOAT_SV_COLUMN, - String.format("%f", _floatSVValues[INDEX_TO_COMPARE])), - getExpectedDoubleResults(FLOAT_SV_COLUMN, functionType)); - testCaseQueryWithDoubleResults(String.format("%s(%s, %s)", functionType.getName(), DOUBLE_SV_COLUMN, - String.format("%.20f", _doubleSVValues[INDEX_TO_COMPARE])), - getExpectedDoubleResults(DOUBLE_SV_COLUMN, functionType)); - testCaseQueryWithDoubleResults(String.format("%s(%s, %s)", functionType.getName(), STRING_SV_COLUMN, - String.format("'%s'", _stringSVValues[INDEX_TO_COMPARE])), - getExpectedDoubleResults(STRING_SV_COLUMN, functionType)); + testCaseQueryWithFloatResults(String.format("%s(%s, %s)", functionType.getName(), INT_SV_COLUMN, + String.format("%d", _intSVValues[INDEX_TO_COMPARE])), getExpectedFloatResults(INT_SV_COLUMN, functionType)); + testCaseQueryWithFloatResults(String.format("%s(%s, %s)", functionType.getName(), LONG_SV_COLUMN, + String.format("%d", _longSVValues[INDEX_TO_COMPARE])), getExpectedFloatResults(LONG_SV_COLUMN, functionType)); + testCaseQueryWithFloatResults(String.format("%s(%s, %s)", functionType.getName(), FLOAT_SV_COLUMN, + String.format("%f", _floatSVValues[INDEX_TO_COMPARE])), + getExpectedFloatResults(FLOAT_SV_COLUMN, functionType)); + testCaseQueryWithFloatResults(String.format("%s(%s, %s)", functionType.getName(), DOUBLE_SV_COLUMN, + String.format("%.20f", _doubleSVValues[INDEX_TO_COMPARE])), + getExpectedFloatResults(DOUBLE_SV_COLUMN, functionType)); + testCaseQueryWithFloatResults(String.format("%s(%s, %s)", functionType.getName(), STRING_SV_COLUMN, + String.format("'%s'", _stringSVValues[INDEX_TO_COMPARE])), + getExpectedFloatResults(STRING_SV_COLUMN, functionType)); } } @Test public void testCaseTransformFunctionWithStringResults() { + String[] expectedStringResults = new String[NUM_ROWS]; + Arrays.fill(expectedStringResults, "aaa"); + testCaseQueryWithStringResults("true", expectedStringResults); + Arrays.fill(expectedStringResults, "bbb"); + testCaseQueryWithStringResults("false", expectedStringResults); + for (TransformFunctionType functionType : BINARY_OPERATOR_TRANSFORM_FUNCTIONS) { testCaseQueryWithStringResults(String.format("%s(%s, %s)", functionType.getName(), INT_SV_COLUMN, String.format("%d", _intSVValues[INDEX_TO_COMPARE])), getExpectedStringResults(INT_SV_COLUMN, functionType)); testCaseQueryWithStringResults(String.format("%s(%s, %s)", functionType.getName(), LONG_SV_COLUMN, - String.format("%d", _longSVValues[INDEX_TO_COMPARE])), + String.format("%d", _longSVValues[INDEX_TO_COMPARE])), getExpectedStringResults(LONG_SV_COLUMN, functionType)); testCaseQueryWithStringResults(String.format("%s(%s, %s)", functionType.getName(), FLOAT_SV_COLUMN, - String.format("%f", _floatSVValues[INDEX_TO_COMPARE])), + String.format("%f", _floatSVValues[INDEX_TO_COMPARE])), getExpectedStringResults(FLOAT_SV_COLUMN, functionType)); testCaseQueryWithStringResults(String.format("%s(%s, %s)", functionType.getName(), DOUBLE_SV_COLUMN, - String.format("%.20f", _doubleSVValues[INDEX_TO_COMPARE])), + String.format("%.20f", _doubleSVValues[INDEX_TO_COMPARE])), getExpectedStringResults(DOUBLE_SV_COLUMN, functionType)); testCaseQueryWithStringResults(String.format("%s(%s, %s)", functionType.getName(), STRING_SV_COLUMN, - String.format("'%s'", _stringSVValues[INDEX_TO_COMPARE])), + String.format("'%s'", _stringSVValues[INDEX_TO_COMPARE])), getExpectedStringResults(STRING_SV_COLUMN, functionType)); } } @@ -96,15 +114,17 @@ public class CaseTransformFunctionTest extends BaseTransformFunctionTest { TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); Assert.assertTrue(transformFunction instanceof CaseTransformFunction); Assert.assertEquals(transformFunction.getName(), CaseTransformFunction.FUNCTION_NAME); + Assert.assertEquals(transformFunction.getResultMetadata().getDataType(), DataType.INT); testTransformFunction(transformFunction, expectedValues); } - private void testCaseQueryWithDoubleResults(String predicate, double[] expectedValues) { + private void testCaseQueryWithFloatResults(String predicate, float[] expectedValues) { ExpressionContext expression = RequestContextUtils.getExpressionFromSQL(String.format("CASE WHEN %s THEN 100.0 ELSE 10.0 END", predicate)); TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); Assert.assertTrue(transformFunction instanceof CaseTransformFunction); Assert.assertEquals(transformFunction.getName(), CaseTransformFunction.FUNCTION_NAME); + Assert.assertEquals(transformFunction.getResultMetadata().getDataType(), DataType.FLOAT); testTransformFunction(transformFunction, expectedValues); } @@ -114,6 +134,7 @@ public class CaseTransformFunctionTest extends BaseTransformFunctionTest { TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); Assert.assertTrue(transformFunction instanceof CaseTransformFunction); Assert.assertEquals(transformFunction.getName(), CaseTransformFunction.FUNCTION_NAME); + Assert.assertEquals(transformFunction.getResultMetadata().getDataType(), DataType.STRING); testTransformFunction(transformFunction, expectedValues); } @@ -246,8 +267,8 @@ public class CaseTransformFunctionTest extends BaseTransformFunctionTest { return result; } - private double[] getExpectedDoubleResults(String column, TransformFunctionType type) { - double[] result = new double[NUM_ROWS]; + private float[] getExpectedFloatResults(String column, TransformFunctionType type) { + float[] result = new float[NUM_ROWS]; for (int i = 0; i < NUM_ROWS; i++) { switch (column) { case INT_SV_COLUMN: diff --git a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/LiteralTransformFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/LiteralTransformFunctionTest.java index 8a383c6..529245d 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/LiteralTransformFunctionTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/LiteralTransformFunctionTest.java @@ -18,7 +18,7 @@ */ package org.apache.pinot.core.operator.transform.function; -import org.apache.pinot.spi.data.FieldSpec; +import org.apache.pinot.spi.data.FieldSpec.DataType; import org.testng.Assert; import org.testng.annotations.Test; @@ -27,15 +27,13 @@ public class LiteralTransformFunctionTest { @Test public void testLiteralTransformFunction() { - Assert.assertEquals(LiteralTransformFunction.inferLiteralDataType(new LiteralTransformFunction("abc")), - FieldSpec.DataType.STRING); - Assert.assertEquals(LiteralTransformFunction.inferLiteralDataType(new LiteralTransformFunction("123")), - FieldSpec.DataType.INT); - Assert.assertEquals(LiteralTransformFunction.inferLiteralDataType(new LiteralTransformFunction("2147483649")), - FieldSpec.DataType.LONG); - Assert.assertEquals(LiteralTransformFunction.inferLiteralDataType(new LiteralTransformFunction("1.2")), - FieldSpec.DataType.FLOAT); - Assert.assertEquals(LiteralTransformFunction.inferLiteralDataType(new LiteralTransformFunction("41241241.2412")), - FieldSpec.DataType.DOUBLE); + Assert.assertEquals(LiteralTransformFunction.inferLiteralDataType("abc"), DataType.STRING); + Assert.assertEquals(LiteralTransformFunction.inferLiteralDataType("123"), DataType.INT); + Assert.assertEquals(LiteralTransformFunction.inferLiteralDataType("2147483649"), DataType.LONG); + Assert.assertEquals(LiteralTransformFunction.inferLiteralDataType("1.2"), DataType.FLOAT); + Assert.assertEquals(LiteralTransformFunction.inferLiteralDataType("41241241.2412"), DataType.DOUBLE); + Assert.assertEquals(LiteralTransformFunction.inferLiteralDataType("true"), DataType.BOOLEAN); + Assert.assertEquals(LiteralTransformFunction.inferLiteralDataType("false"), DataType.BOOLEAN); + Assert.assertEquals(LiteralTransformFunction.inferLiteralDataType("2020-02-02 20:20:20.20"), DataType.TIMESTAMP); } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org For additional commands, e-mail: commits-h...@pinot.apache.org