This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new b71f38a  Infer the data type for LiteralTransformFunction (#7332)
b71f38a is described below

commit b71f38a2b2ed7ff715b03f4b780105653f726969
Author: Xiaotian (Jackie) Jiang <17555551+jackie-ji...@users.noreply.github.com>
AuthorDate: Fri Aug 20 16:30:54 2021 -0700

    Infer the data type for LiteralTransformFunction (#7332)
    
    - Infer the data type for `LiteralTransformFunction` so that the parent transform function can read the correct data type
    - Fix the data type handling for `CaseTransformFunction`, and use the `STRING` type to handle non-numeric types
    - For `BinaryOperatorTransformFunction`, use the data type of the lhs expression as the main data type, and convert the rhs to match it, e.g.
      - stringCol > 123 will use string comparison
      - intCol > '123' will use integer comparison
---
 .../function/BinaryOperatorTransformFunction.java  |  65 ++----------
 .../transform/function/CaseTransformFunction.java  | 118 ++++-----------------
 .../function/LiteralTransformFunction.java         |  68 +++++++++---
 .../function/CaseTransformFunctionTest.java        |  75 ++++++++-----
 .../function/LiteralTransformFunctionTest.java     |  20 ++--
 5 files changed, 135 insertions(+), 211 deletions(-)

diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/BinaryOperatorTransformFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/BinaryOperatorTransformFunction.java
index df21c26..a610ace 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/BinaryOperatorTransformFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/BinaryOperatorTransformFunction.java
@@ -44,8 +44,8 @@ public abstract class BinaryOperatorTransformFunction extends 
BaseTransformFunct
   @Override
   public void init(List<TransformFunction> arguments, Map<String, DataSource> 
dataSourceMap) {
     // Check that there are exact 2 arguments
-    Preconditions
-        .checkArgument(arguments.size() == 2, "Exact 2 arguments are required 
for binary operator transform function");
+    Preconditions.checkArgument(arguments.size() == 2,
+        "Exact 2 arguments are required for binary operator transform 
function");
     _leftTransformFunction = arguments.get(0);
     _rightTransformFunction = arguments.get(1);
     _leftStoredType = 
_leftTransformFunction.getResultMetadata().getDataType().getStoredType();
@@ -93,7 +93,7 @@ public abstract class BinaryOperatorTransformFunction extends 
BaseTransformFunct
           case FLOAT:
             float[] rightFloatValues = 
_rightTransformFunction.transformToFloatValuesSV(projectionBlock);
             for (int i = 0; i < length; i++) {
-              _results[i] = getIntResult(Float.compare(leftIntValues[i], 
rightFloatValues[i]));
+              _results[i] = getIntResult(Double.compare(leftIntValues[i], 
rightFloatValues[i]));
             }
             break;
           case DOUBLE:
@@ -261,62 +261,9 @@ public abstract class BinaryOperatorTransformFunction 
extends BaseTransformFunct
         break;
       case STRING:
         String[] leftStringValues = 
_leftTransformFunction.transformToStringValuesSV(projectionBlock);
-        switch (_rightStoredType) {
-          case INT:
-            int[] rightIntValues = 
_rightTransformFunction.transformToIntValuesSV(projectionBlock);
-            for (int i = 0; i < length; i++) {
-              try {
-                _results[i] =
-                    getIntResult(new 
BigDecimal(leftStringValues[i]).compareTo(BigDecimal.valueOf(rightIntValues[i])));
-              } catch (NumberFormatException e) {
-                _results[i] = 0;
-              }
-            }
-            break;
-          case LONG:
-            long[] rightLongValues = 
_rightTransformFunction.transformToLongValuesSV(projectionBlock);
-            for (int i = 0; i < length; i++) {
-              try {
-                _results[i] =
-                    getIntResult(new 
BigDecimal(leftStringValues[i]).compareTo(BigDecimal.valueOf(rightLongValues[i])));
-              } catch (NumberFormatException e) {
-                _results[i] = 0;
-              }
-            }
-            break;
-          case FLOAT:
-            float[] rightFloatValues = 
_rightTransformFunction.transformToFloatValuesSV(projectionBlock);
-            for (int i = 0; i < length; i++) {
-              try {
-                _results[i] = getIntResult(
-                    new 
BigDecimal(leftStringValues[i]).compareTo(BigDecimal.valueOf(rightFloatValues[i])));
-              } catch (NumberFormatException e) {
-                _results[i] = 0;
-              }
-            }
-            break;
-          case DOUBLE:
-            double[] rightDoubleValues = 
_rightTransformFunction.transformToDoubleValuesSV(projectionBlock);
-            for (int i = 0; i < length; i++) {
-              try {
-                _results[i] = getIntResult(
-                    new 
BigDecimal(leftStringValues[i]).compareTo(BigDecimal.valueOf(rightDoubleValues[i])));
-              } catch (NumberFormatException e) {
-                _results[i] = 0;
-              }
-            }
-            break;
-          case STRING:
-            String[] rightStringValues = 
_rightTransformFunction.transformToStringValuesSV(projectionBlock);
-            for (int i = 0; i < length; i++) {
-              _results[i] = 
getIntResult(leftStringValues[i].compareTo(rightStringValues[i]));
-            }
-            break;
-          default:
-            throw new IllegalStateException(String.format(
-                "Unsupported data type for comparison: [Left Transform 
Function [%s] result type is [%s], Right Transform Function [%s] result type is 
[%s]]",
-                _leftTransformFunction.getName(), _leftStoredType, 
_rightTransformFunction.getName(),
-                _rightStoredType));
+        String[] rightStringValues = 
_rightTransformFunction.transformToStringValuesSV(projectionBlock);
+        for (int i = 0; i < length; i++) {
+          _results[i] = 
getIntResult(leftStringValues[i].compareTo(rightStringValues[i]));
         }
         break;
       case BYTES:
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunction.java
index e04db32..7de3a15 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunction.java
@@ -104,147 +104,71 @@ public class CaseTransformFunction extends 
BaseTransformFunction {
       Preconditions.checkState(thenStatementResultMetadata.isSingleValue(),
           String.format("Unsupported multi-value expression in the THEN clause 
of index: %d", i));
       DataType thenStatementDataType = 
thenStatementResultMetadata.getDataType();
+
+      // Upcast the data type to cover all the data types in THEN and ELSE 
clauses if they don't match
+      // For numeric types:
+      // - INT & LONG -> LONG
+      // - INT & FLOAT/DOUBLE -> DOUBLE
+      // - LONG & FLOAT/DOUBLE -> DOUBLE (might lose precision)
+      // - FLOAT & DOUBLE -> DOUBLE
+      // Use STRING to handle non-numeric types
+      if (thenStatementDataType == dataType) {
+        continue;
+      }
       switch (dataType) {
         case INT:
-          if (thenStatement instanceof LiteralTransformFunction) {
-            dataType = 
LiteralTransformFunction.inferLiteralDataType((LiteralTransformFunction) 
thenStatement);
-            break;
-          }
           switch (thenStatementDataType) {
-            case INT:
             case LONG:
+              dataType = DataType.LONG;
+              break;
             case FLOAT:
             case DOUBLE:
-            case STRING:
-              dataType = thenStatementDataType;
+              dataType = DataType.DOUBLE;
               break;
             default:
-              throw new IllegalStateException(String
-                  .format("Incompatible expression type: %s in the THEN clause 
of index: %d, main type: %s",
-                      thenStatementDataType, i, dataType));
+              dataType = DataType.STRING;
+              break;
           }
           break;
         case LONG:
-          if (thenStatement instanceof LiteralTransformFunction) {
-            DataType literalDataType =
-                
LiteralTransformFunction.inferLiteralDataType((LiteralTransformFunction) 
thenStatement);
-            switch (literalDataType) {
-              case INT:
-              case LONG:
-                break;
-              case FLOAT:
-              case DOUBLE:
-                dataType = DataType.DOUBLE;
-                break;
-              default:
-                dataType = literalDataType;
-            }
-            break;
-          }
           switch (thenStatementDataType) {
             case INT:
-            case LONG:
               break;
             case FLOAT:
             case DOUBLE:
               dataType = DataType.DOUBLE;
               break;
-            case STRING:
+            default:
               dataType = DataType.STRING;
               break;
-            default:
-              throw new IllegalStateException(String
-                  .format("Incompatible expression type: %s in the THEN clause 
of index: %d, main type: %s",
-                      thenStatementDataType, i, dataType));
           }
           break;
         case FLOAT:
-          if (thenStatement instanceof LiteralTransformFunction) {
-            DataType literalDataType =
-                
LiteralTransformFunction.inferLiteralDataType((LiteralTransformFunction) 
thenStatement);
-            switch (literalDataType) {
-              case INT:
-              case FLOAT:
-                break;
-              case LONG:
-              case DOUBLE:
-                dataType = DataType.DOUBLE;
-                break;
-              default:
-                dataType = literalDataType;
-            }
-            break;
-          }
           switch (thenStatementDataType) {
             case INT:
-            case FLOAT:
-              break;
             case LONG:
             case DOUBLE:
               dataType = DataType.DOUBLE;
               break;
-            case STRING:
+            default:
               dataType = DataType.STRING;
               break;
-            default:
-              throw new IllegalStateException(String
-                  .format("Incompatible expression type: %s in the THEN clause 
of index: %d, main type: %s",
-                      thenStatementDataType, i, dataType));
           }
           break;
         case DOUBLE:
-          if (thenStatement instanceof LiteralTransformFunction) {
-            DataType literalDataType =
-                
LiteralTransformFunction.inferLiteralDataType((LiteralTransformFunction) 
thenStatement);
-            switch (literalDataType) {
-              case INT:
-              case LONG:
-              case FLOAT:
-              case DOUBLE:
-                break;
-              default:
-                dataType = literalDataType;
-            }
-            break;
-          }
           switch (thenStatementDataType) {
             case INT:
             case FLOAT:
             case LONG:
-            case DOUBLE:
-              break;
-            case STRING:
-              dataType = thenStatementDataType;
               break;
             default:
-              throw new IllegalStateException(String
-                  .format("Incompatible expression type: %s in the THEN clause 
of index: %d, main type: %s",
-                      thenStatementDataType, i, dataType));
-          }
-          break;
-        case STRING:
-          if (thenStatement instanceof LiteralTransformFunction) {
-            break;
-          }
-          switch (thenStatementDataType) {
-            case INT:
-            case FLOAT:
-            case LONG:
-            case DOUBLE:
-            case STRING:
+              dataType = DataType.STRING;
               break;
-            default:
-              throw new IllegalStateException(String
-                  .format("Incompatible expression type: %s in the THEN clause 
of index: %d, main type: %s",
-                      thenStatementDataType, i, dataType));
           }
           break;
         default:
-          if (thenStatementDataType != dataType) {
-            throw new IllegalStateException(String
-                .format("Incompatible expression type: %s in the THEN clause 
of index: %d, main type: %s",
-                    thenStatementDataType, i, dataType));
-          }
+          dataType = DataType.STRING;
+          break;
       }
     }
     return new TransformResultMetadata(dataType, true, false);
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/LiteralTransformFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/LiteralTransformFunction.java
index fdf6580..dd4606a 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/LiteralTransformFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/LiteralTransformFunction.java
@@ -18,7 +18,9 @@
  */
 package org.apache.pinot.core.operator.transform.function;
 
+import com.google.common.annotations.VisibleForTesting;
 import java.math.BigDecimal;
+import java.sql.Timestamp;
 import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
@@ -28,16 +30,19 @@ import 
org.apache.pinot.core.operator.transform.TransformResultMetadata;
 import org.apache.pinot.core.plan.DocIdSetPlanNode;
 import org.apache.pinot.segment.spi.datasource.DataSource;
 import org.apache.pinot.segment.spi.index.reader.Dictionary;
-import org.apache.pinot.spi.data.FieldSpec;
+import org.apache.pinot.spi.data.FieldSpec.DataType;
 import org.apache.pinot.spi.utils.BytesUtils;
 
 
 /**
  * The <code>LiteralTransformFunction</code> class is a special transform 
function which is a wrapper on top of a
- * LITERAL, and only supports {@link #getLiteral()}.
+ * LITERAL. The data type is inferred from the literal string.
+ * TODO: Preserve the type of the literal instead of inferring the type from 
the string
  */
 public class LiteralTransformFunction implements TransformFunction {
   private final String _literal;
+  private final DataType _dataType;
+
   private int[] _intResult;
   private long[] _longResult;
   private float[] _floatResult;
@@ -47,24 +52,45 @@ public class LiteralTransformFunction implements 
TransformFunction {
 
   public LiteralTransformFunction(String literal) {
     _literal = literal;
+    _dataType = inferLiteralDataType(literal);
   }
 
-  public static FieldSpec.DataType 
inferLiteralDataType(LiteralTransformFunction transformFunction) {
-    String literal = transformFunction.getLiteral();
+  @VisibleForTesting
+  static DataType inferLiteralDataType(String literal) {
+    // Try to interpret the literal as number
     try {
-      Number literalNum = NumberUtils.createNumber(literal);
-      if (literalNum instanceof Integer) {
-        return FieldSpec.DataType.INT;
-      } else if (literalNum instanceof Long) {
-        return FieldSpec.DataType.LONG;
-      } else if (literalNum instanceof Float) {
-        return FieldSpec.DataType.FLOAT;
-      } else if (literalNum instanceof Double) {
-        return FieldSpec.DataType.DOUBLE;
+      Number number = NumberUtils.createNumber(literal);
+      if (number instanceof Integer) {
+        return DataType.INT;
+      } else if (number instanceof Long) {
+        return DataType.LONG;
+      } else if (number instanceof Float) {
+        return DataType.FLOAT;
+      } else if (number instanceof Double) {
+        return DataType.DOUBLE;
+      } else {
+        return DataType.STRING;
       }
     } catch (Exception e) {
+      // Ignored
+    }
+
+    // Try to interpret the literal as BOOLEAN
+    // NOTE: Intentionally use equals() instead of equalsIgnoreCase() here 
because boolean literal will always be parsed
+    //       into lowercase string. We don't want to parse string "TRUE" as 
boolean.
+    if (literal.equals("true") || literal.equals("false")) {
+      return DataType.BOOLEAN;
+    }
+
+    // Try to interpret the literal as TIMESTAMP
+    try {
+      Timestamp.valueOf(literal);
+      return DataType.TIMESTAMP;
+    } catch (Exception e) {
+      // Ignored
     }
-    return FieldSpec.DataType.STRING;
+
+    return DataType.STRING;
   }
 
   public String getLiteral() {
@@ -82,7 +108,7 @@ public class LiteralTransformFunction implements 
TransformFunction {
 
   @Override
   public TransformResultMetadata getResultMetadata() {
-    return BaseTransformFunction.STRING_SV_NO_DICTIONARY_METADATA;
+    return new TransformResultMetadata(_dataType, true, false);
   }
 
   @Override
@@ -104,7 +130,11 @@ public class LiteralTransformFunction implements 
TransformFunction {
   public int[] transformToIntValuesSV(ProjectionBlock projectionBlock) {
     if (_intResult == null) {
       _intResult = new int[DocIdSetPlanNode.MAX_DOC_PER_CALL];
-      Arrays.fill(_intResult, Integer.parseInt(_literal));
+      if (_dataType != DataType.BOOLEAN) {
+        Arrays.fill(_intResult, new BigDecimal(_literal).intValue());
+      } else {
+        Arrays.fill(_intResult, _literal.equals("true") ? 1 : 0);
+      }
     }
     return _intResult;
   }
@@ -113,7 +143,11 @@ public class LiteralTransformFunction implements 
TransformFunction {
   public long[] transformToLongValuesSV(ProjectionBlock projectionBlock) {
     if (_longResult == null) {
       _longResult = new long[DocIdSetPlanNode.MAX_DOC_PER_CALL];
-      Arrays.fill(_longResult, new BigDecimal(_literal).longValue());
+      if (_dataType != DataType.TIMESTAMP) {
+        Arrays.fill(_longResult, new BigDecimal(_literal).longValue());
+      } else {
+        Arrays.fill(_longResult, Timestamp.valueOf(_literal).getTime());
+      }
     }
     return _longResult;
   }
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunctionTest.java
index 35b1011..33b0496 100644
--- a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunctionTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunctionTest.java
@@ -18,22 +18,29 @@
  */
 package org.apache.pinot.core.operator.transform.function;
 
+import java.util.Arrays;
 import java.util.Random;
 import org.apache.pinot.common.function.TransformFunctionType;
 import org.apache.pinot.common.request.context.ExpressionContext;
 import org.apache.pinot.common.request.context.RequestContextUtils;
+import org.apache.pinot.spi.data.FieldSpec.DataType;
 import org.testng.Assert;
 import org.testng.annotations.Test;
 
 
 public class CaseTransformFunctionTest extends BaseTransformFunctionTest {
-
-  private final int INDEX_TO_COMPARE = new 
Random(System.currentTimeMillis()).nextInt(NUM_ROWS);
-  private final TransformFunctionType[] BINARY_OPERATOR_TRANSFORM_FUNCTIONS =
+  private static final int INDEX_TO_COMPARE = new 
Random(System.currentTimeMillis()).nextInt(NUM_ROWS);
+  private static final TransformFunctionType[] 
BINARY_OPERATOR_TRANSFORM_FUNCTIONS =
       new TransformFunctionType[]{TransformFunctionType.EQUALS, 
TransformFunctionType.NOT_EQUALS, TransformFunctionType.GREATER_THAN, 
TransformFunctionType.GREATER_THAN_OR_EQUAL, TransformFunctionType.LESS_THAN, 
TransformFunctionType.LESS_THAN_OR_EQUAL};
 
   @Test
   public void testCaseTransformFunctionWithIntResults() {
+    int[] expectedIntResults = new int[NUM_ROWS];
+    Arrays.fill(expectedIntResults, 100);
+    testCaseQueryWithIntResults("true", expectedIntResults);
+    Arrays.fill(expectedIntResults, 10);
+    testCaseQueryWithIntResults("false", expectedIntResults);
+
     for (TransformFunctionType functionType : 
BINARY_OPERATOR_TRANSFORM_FUNCTIONS) {
       testCaseQueryWithIntResults(String.format("%s(%s, %s)", 
functionType.getName(), INT_SV_COLUMN,
           String.format("%d", _intSVValues[INDEX_TO_COMPARE])), 
getExpectedIntResults(INT_SV_COLUMN, functionType));
@@ -42,50 +49,61 @@ public class CaseTransformFunctionTest extends 
BaseTransformFunctionTest {
       testCaseQueryWithIntResults(String.format("%s(%s, %s)", 
functionType.getName(), FLOAT_SV_COLUMN,
           String.format("%f", _floatSVValues[INDEX_TO_COMPARE])), 
getExpectedIntResults(FLOAT_SV_COLUMN, functionType));
       testCaseQueryWithIntResults(String.format("%s(%s, %s)", 
functionType.getName(), DOUBLE_SV_COLUMN,
-          String.format("%.20f", _doubleSVValues[INDEX_TO_COMPARE])),
+              String.format("%.20f", _doubleSVValues[INDEX_TO_COMPARE])),
           getExpectedIntResults(DOUBLE_SV_COLUMN, functionType));
       testCaseQueryWithIntResults(String.format("%s(%s, %s)", 
functionType.getName(), STRING_SV_COLUMN,
-          String.format("'%s'", _stringSVValues[INDEX_TO_COMPARE])),
+              String.format("'%s'", _stringSVValues[INDEX_TO_COMPARE])),
           getExpectedIntResults(STRING_SV_COLUMN, functionType));
     }
   }
 
   @Test
-  public void testCaseTransformFunctionWithDoubleResults() {
+  public void testCaseTransformFunctionWithFloatResults() {
+    float[] expectedFloatResults = new float[NUM_ROWS];
+    Arrays.fill(expectedFloatResults, 100);
+    testCaseQueryWithFloatResults("true", expectedFloatResults);
+    Arrays.fill(expectedFloatResults, 10);
+    testCaseQueryWithFloatResults("false", expectedFloatResults);
+
     for (TransformFunctionType functionType : 
BINARY_OPERATOR_TRANSFORM_FUNCTIONS) {
-      testCaseQueryWithDoubleResults(String.format("%s(%s, %s)", 
functionType.getName(), INT_SV_COLUMN,
-          String.format("%d", _intSVValues[INDEX_TO_COMPARE])), 
getExpectedDoubleResults(INT_SV_COLUMN, functionType));
-      testCaseQueryWithDoubleResults(String.format("%s(%s, %s)", 
functionType.getName(), LONG_SV_COLUMN,
-          String.format("%d", _longSVValues[INDEX_TO_COMPARE])),
-          getExpectedDoubleResults(LONG_SV_COLUMN, functionType));
-      testCaseQueryWithDoubleResults(String.format("%s(%s, %s)", 
functionType.getName(), FLOAT_SV_COLUMN,
-          String.format("%f", _floatSVValues[INDEX_TO_COMPARE])),
-          getExpectedDoubleResults(FLOAT_SV_COLUMN, functionType));
-      testCaseQueryWithDoubleResults(String.format("%s(%s, %s)", 
functionType.getName(), DOUBLE_SV_COLUMN,
-          String.format("%.20f", _doubleSVValues[INDEX_TO_COMPARE])),
-          getExpectedDoubleResults(DOUBLE_SV_COLUMN, functionType));
-      testCaseQueryWithDoubleResults(String.format("%s(%s, %s)", 
functionType.getName(), STRING_SV_COLUMN,
-          String.format("'%s'", _stringSVValues[INDEX_TO_COMPARE])),
-          getExpectedDoubleResults(STRING_SV_COLUMN, functionType));
+      testCaseQueryWithFloatResults(String.format("%s(%s, %s)", 
functionType.getName(), INT_SV_COLUMN,
+          String.format("%d", _intSVValues[INDEX_TO_COMPARE])), 
getExpectedFloatResults(INT_SV_COLUMN, functionType));
+      testCaseQueryWithFloatResults(String.format("%s(%s, %s)", 
functionType.getName(), LONG_SV_COLUMN,
+          String.format("%d", _longSVValues[INDEX_TO_COMPARE])), 
getExpectedFloatResults(LONG_SV_COLUMN, functionType));
+      testCaseQueryWithFloatResults(String.format("%s(%s, %s)", 
functionType.getName(), FLOAT_SV_COLUMN,
+              String.format("%f", _floatSVValues[INDEX_TO_COMPARE])),
+          getExpectedFloatResults(FLOAT_SV_COLUMN, functionType));
+      testCaseQueryWithFloatResults(String.format("%s(%s, %s)", 
functionType.getName(), DOUBLE_SV_COLUMN,
+              String.format("%.20f", _doubleSVValues[INDEX_TO_COMPARE])),
+          getExpectedFloatResults(DOUBLE_SV_COLUMN, functionType));
+      testCaseQueryWithFloatResults(String.format("%s(%s, %s)", 
functionType.getName(), STRING_SV_COLUMN,
+              String.format("'%s'", _stringSVValues[INDEX_TO_COMPARE])),
+          getExpectedFloatResults(STRING_SV_COLUMN, functionType));
     }
   }
 
   @Test
   public void testCaseTransformFunctionWithStringResults() {
+    String[] expectedStringResults = new String[NUM_ROWS];
+    Arrays.fill(expectedStringResults, "aaa");
+    testCaseQueryWithStringResults("true", expectedStringResults);
+    Arrays.fill(expectedStringResults, "bbb");
+    testCaseQueryWithStringResults("false", expectedStringResults);
+
     for (TransformFunctionType functionType : 
BINARY_OPERATOR_TRANSFORM_FUNCTIONS) {
       testCaseQueryWithStringResults(String.format("%s(%s, %s)", 
functionType.getName(), INT_SV_COLUMN,
           String.format("%d", _intSVValues[INDEX_TO_COMPARE])), 
getExpectedStringResults(INT_SV_COLUMN, functionType));
       testCaseQueryWithStringResults(String.format("%s(%s, %s)", 
functionType.getName(), LONG_SV_COLUMN,
-          String.format("%d", _longSVValues[INDEX_TO_COMPARE])),
+              String.format("%d", _longSVValues[INDEX_TO_COMPARE])),
           getExpectedStringResults(LONG_SV_COLUMN, functionType));
       testCaseQueryWithStringResults(String.format("%s(%s, %s)", 
functionType.getName(), FLOAT_SV_COLUMN,
-          String.format("%f", _floatSVValues[INDEX_TO_COMPARE])),
+              String.format("%f", _floatSVValues[INDEX_TO_COMPARE])),
           getExpectedStringResults(FLOAT_SV_COLUMN, functionType));
       testCaseQueryWithStringResults(String.format("%s(%s, %s)", 
functionType.getName(), DOUBLE_SV_COLUMN,
-          String.format("%.20f", _doubleSVValues[INDEX_TO_COMPARE])),
+              String.format("%.20f", _doubleSVValues[INDEX_TO_COMPARE])),
           getExpectedStringResults(DOUBLE_SV_COLUMN, functionType));
       testCaseQueryWithStringResults(String.format("%s(%s, %s)", 
functionType.getName(), STRING_SV_COLUMN,
-          String.format("'%s'", _stringSVValues[INDEX_TO_COMPARE])),
+              String.format("'%s'", _stringSVValues[INDEX_TO_COMPARE])),
           getExpectedStringResults(STRING_SV_COLUMN, functionType));
     }
   }
@@ -96,15 +114,17 @@ public class CaseTransformFunctionTest extends 
BaseTransformFunctionTest {
     TransformFunction transformFunction = 
TransformFunctionFactory.get(expression, _dataSourceMap);
     Assert.assertTrue(transformFunction instanceof CaseTransformFunction);
     Assert.assertEquals(transformFunction.getName(), 
CaseTransformFunction.FUNCTION_NAME);
+    Assert.assertEquals(transformFunction.getResultMetadata().getDataType(), 
DataType.INT);
     testTransformFunction(transformFunction, expectedValues);
   }
 
-  private void testCaseQueryWithDoubleResults(String predicate, double[] 
expectedValues) {
+  private void testCaseQueryWithFloatResults(String predicate, float[] 
expectedValues) {
     ExpressionContext expression =
         RequestContextUtils.getExpressionFromSQL(String.format("CASE WHEN %s 
THEN 100.0 ELSE 10.0 END", predicate));
     TransformFunction transformFunction = 
TransformFunctionFactory.get(expression, _dataSourceMap);
     Assert.assertTrue(transformFunction instanceof CaseTransformFunction);
     Assert.assertEquals(transformFunction.getName(), 
CaseTransformFunction.FUNCTION_NAME);
+    Assert.assertEquals(transformFunction.getResultMetadata().getDataType(), 
DataType.FLOAT);
     testTransformFunction(transformFunction, expectedValues);
   }
 
@@ -114,6 +134,7 @@ public class CaseTransformFunctionTest extends 
BaseTransformFunctionTest {
     TransformFunction transformFunction = 
TransformFunctionFactory.get(expression, _dataSourceMap);
     Assert.assertTrue(transformFunction instanceof CaseTransformFunction);
     Assert.assertEquals(transformFunction.getName(), 
CaseTransformFunction.FUNCTION_NAME);
+    Assert.assertEquals(transformFunction.getResultMetadata().getDataType(), 
DataType.STRING);
     testTransformFunction(transformFunction, expectedValues);
   }
 
@@ -246,8 +267,8 @@ public class CaseTransformFunctionTest extends 
BaseTransformFunctionTest {
     return result;
   }
 
-  private double[] getExpectedDoubleResults(String column, 
TransformFunctionType type) {
-    double[] result = new double[NUM_ROWS];
+  private float[] getExpectedFloatResults(String column, TransformFunctionType 
type) {
+    float[] result = new float[NUM_ROWS];
     for (int i = 0; i < NUM_ROWS; i++) {
       switch (column) {
         case INT_SV_COLUMN:
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/LiteralTransformFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/LiteralTransformFunctionTest.java
index 8a383c6..529245d 100644
--- a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/LiteralTransformFunctionTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/LiteralTransformFunctionTest.java
@@ -18,7 +18,7 @@
  */
 package org.apache.pinot.core.operator.transform.function;
 
-import org.apache.pinot.spi.data.FieldSpec;
+import org.apache.pinot.spi.data.FieldSpec.DataType;
 import org.testng.Assert;
 import org.testng.annotations.Test;
 
@@ -27,15 +27,13 @@ public class LiteralTransformFunctionTest {
 
   @Test
   public void testLiteralTransformFunction() {
-    Assert.assertEquals(LiteralTransformFunction.inferLiteralDataType(new 
LiteralTransformFunction("abc")),
-        FieldSpec.DataType.STRING);
-    Assert.assertEquals(LiteralTransformFunction.inferLiteralDataType(new 
LiteralTransformFunction("123")),
-        FieldSpec.DataType.INT);
-    Assert.assertEquals(LiteralTransformFunction.inferLiteralDataType(new 
LiteralTransformFunction("2147483649")),
-        FieldSpec.DataType.LONG);
-    Assert.assertEquals(LiteralTransformFunction.inferLiteralDataType(new 
LiteralTransformFunction("1.2")),
-        FieldSpec.DataType.FLOAT);
-    Assert.assertEquals(LiteralTransformFunction.inferLiteralDataType(new 
LiteralTransformFunction("41241241.2412")),
-        FieldSpec.DataType.DOUBLE);
+    Assert.assertEquals(LiteralTransformFunction.inferLiteralDataType("abc"), 
DataType.STRING);
+    Assert.assertEquals(LiteralTransformFunction.inferLiteralDataType("123"), 
DataType.INT);
+    
Assert.assertEquals(LiteralTransformFunction.inferLiteralDataType("2147483649"),
 DataType.LONG);
+    Assert.assertEquals(LiteralTransformFunction.inferLiteralDataType("1.2"), 
DataType.FLOAT);
+    
Assert.assertEquals(LiteralTransformFunction.inferLiteralDataType("41241241.2412"),
 DataType.DOUBLE);
+    Assert.assertEquals(LiteralTransformFunction.inferLiteralDataType("true"), 
DataType.BOOLEAN);
+    
Assert.assertEquals(LiteralTransformFunction.inferLiteralDataType("false"), 
DataType.BOOLEAN);
+    
Assert.assertEquals(LiteralTransformFunction.inferLiteralDataType("2020-02-02 
20:20:20.20"), DataType.TIMESTAMP);
   }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org
For additional commands, e-mail: commits-h...@pinot.apache.org

Reply via email to