This is an automated email from the ASF dual-hosted git repository.

rongr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 6bb387a10a Support array gen in literal evaluation (#12278)
6bb387a10a is described below

commit 6bb387a10a311af8ae166e93b5ba629fdb8cc4b8
Author: Xiang Fu <[email protected]>
AuthorDate: Thu Jan 18 07:53:22 2024 -0800

    Support array gen in literal evaluation (#12278)
---
 .../pinot/common/function/FunctionRegistry.java    | 46 ++++++++----
 .../common/function/scalar/ArrayFunctions.java     | 59 ++++++++++++++++
 .../common/request/context/LiteralContext.java     | 20 ++++++
 .../rewriter/CompileTimeFunctionsInvoker.java      |  9 ++-
 .../function/ArrayLiteralTransformFunction.java    | 72 +++++++++++++++++++
 .../function/TransformFunctionFactory.java         |  8 ++-
 .../function/HistogramAggregationFunction.java     | 33 +++++++--
 .../function/InbuiltFunctionEvaluatorTest.java     |  2 +-
 .../apache/pinot/queries/HistogramQueriesTest.java |  4 +-
 .../pinot/integration/tests/custom/ArrayTest.java  | 82 +++++++++++++++++++++-
 .../rel/rules/PinotEvaluateLiteralRule.java        |  8 ++-
 .../runtime/operator/operands/FunctionOperand.java | 21 ++++--
 .../local/function/InbuiltFunctionEvaluator.java   |  6 ++
 .../pinot/spi/annotations/ScalarFunction.java      |  5 ++
 14 files changed, 341 insertions(+), 34 deletions(-)

diff --git 
a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java
 
b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java
index 97fa972bee..deb1673d8b 100644
--- 
a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java
+++ 
b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java
@@ -52,6 +52,8 @@ public class FunctionRegistry {
   // This FUNCTION_MAP is used by Calcite function catalog to look up function 
by function signature.
   private static final NameMultimap<Function> FUNCTION_MAP = new 
NameMultimap<>();
 
+  private static final int VAR_ARG_KEY = -1;
+
   /**
    * Registers the scalar functions via reflection.
    * NOTE: In order to plugin methods using reflection, the methods should be 
inside a class that includes ".function."
@@ -69,12 +71,14 @@ public class FunctionRegistry {
         // Annotated function names
         String[] scalarFunctionNames = scalarFunction.names();
         boolean nullableParameters = scalarFunction.nullableParameters();
+        boolean isPlaceholder = scalarFunction.isPlaceholder();
+        boolean isVarArg = scalarFunction.isVarArg();
         if (scalarFunctionNames.length > 0) {
           for (String name : scalarFunctionNames) {
-            FunctionRegistry.registerFunction(name, method, 
nullableParameters, scalarFunction.isPlaceholder());
+            FunctionRegistry.registerFunction(name, method, 
nullableParameters, isPlaceholder, isVarArg);
           }
         } else {
-          FunctionRegistry.registerFunction(method, nullableParameters, 
scalarFunction.isPlaceholder());
+          FunctionRegistry.registerFunction(method, nullableParameters, 
isPlaceholder, isVarArg);
         }
       }
     }
@@ -93,31 +97,40 @@ public class FunctionRegistry {
   /**
    * Registers a method with the name of the method.
    */
-  public static void registerFunction(Method method, boolean 
nullableParameters, boolean isPlaceholder) {
-    registerFunction(method.getName(), method, nullableParameters, 
isPlaceholder);
+  public static void registerFunction(Method method, boolean 
nullableParameters, boolean isPlaceholder,
+      boolean isVarArg) {
+    registerFunction(method.getName(), method, nullableParameters, 
isPlaceholder, isVarArg);
   }
 
   /**
    * Registers a method with the given function name.
    */
   public static void registerFunction(String functionName, Method method, 
boolean nullableParameters,
-      boolean isPlaceholder) {
+      boolean isPlaceholder, boolean isVarArg) {
     if (!isPlaceholder) {
-      registerFunctionInfoMap(functionName, method, nullableParameters);
+      registerFunctionInfoMap(functionName, method, nullableParameters, 
isVarArg);
     }
-    registerCalciteNamedFunctionMap(functionName, method, nullableParameters);
+    registerCalciteNamedFunctionMap(functionName, method, nullableParameters, 
isVarArg);
   }
 
-  private static void registerFunctionInfoMap(String functionName, Method 
method, boolean nullableParameters) {
+  private static void registerFunctionInfoMap(String functionName, Method 
method, boolean nullableParameters,
+      boolean isVarArg) {
     FunctionInfo functionInfo = new FunctionInfo(method, 
method.getDeclaringClass(), nullableParameters);
     String canonicalName = canonicalize(functionName);
     Map<Integer, FunctionInfo> functionInfoMap = 
FUNCTION_INFO_MAP.computeIfAbsent(canonicalName, k -> new HashMap<>());
-    FunctionInfo existFunctionInfo = 
functionInfoMap.put(method.getParameterCount(), functionInfo);
-    Preconditions.checkState(existFunctionInfo == null || 
existFunctionInfo.getMethod() == functionInfo.getMethod(),
-        "Function: %s with %s parameters is already registered", functionName, 
method.getParameterCount());
+    if (isVarArg) {
+      FunctionInfo existFunctionInfo = functionInfoMap.put(VAR_ARG_KEY, 
functionInfo);
+      Preconditions.checkState(existFunctionInfo == null || 
existFunctionInfo.getMethod() == functionInfo.getMethod(),
+          "Function: %s with variable number of parameters is already 
registered", functionName);
+    } else {
+      FunctionInfo existFunctionInfo = 
functionInfoMap.put(method.getParameterCount(), functionInfo);
+      Preconditions.checkState(existFunctionInfo == null || 
existFunctionInfo.getMethod() == functionInfo.getMethod(),
+          "Function: %s with %s parameters is already registered", 
functionName, method.getParameterCount());
+    }
   }
 
-  private static void registerCalciteNamedFunctionMap(String functionName, 
Method method, boolean nullableParameters) {
+  private static void registerCalciteNamedFunctionMap(String functionName, 
Method method, boolean nullableParameters,
+      boolean isVarArg) {
     if (method.getAnnotation(Deprecated.class) == null) {
       FUNCTION_MAP.put(functionName, ScalarFunctionImpl.create(method));
     }
@@ -146,7 +159,14 @@ public class FunctionRegistry {
   @Nullable
   public static FunctionInfo getFunctionInfo(String functionName, int 
numParameters) {
     Map<Integer, FunctionInfo> functionInfoMap = 
FUNCTION_INFO_MAP.get(canonicalize(functionName));
-    return functionInfoMap != null ? functionInfoMap.get(numParameters) : null;
+    if (functionInfoMap != null) {
+      FunctionInfo functionInfo = functionInfoMap.get(numParameters);
+      if (functionInfo != null) {
+        return functionInfo;
+      }
+      return functionInfoMap.get(VAR_ARG_KEY);
+    }
+    return null;
   }
 
   private static String canonicalize(String functionName) {
diff --git 
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java
 
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java
index d8529e9842..32f115b51a 100644
--- 
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java
+++ 
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java
@@ -23,6 +23,7 @@ import it.unimi.dsi.fastutil.ints.IntLinkedOpenHashSet;
 import it.unimi.dsi.fastutil.ints.IntSet;
 import it.unimi.dsi.fastutil.objects.ObjectLinkedOpenHashSet;
 import it.unimi.dsi.fastutil.objects.ObjectSet;
+import java.math.BigDecimal;
 import java.util.Arrays;
 import org.apache.commons.lang3.ArrayUtils;
 import org.apache.pinot.spi.annotations.ScalarFunction;
@@ -226,4 +227,62 @@ public class ArrayFunctions {
   public static String arrayElementAtString(String[] arr, int idx) {
     return idx > 0 && idx <= arr.length ? arr[idx - 1] : 
NullValuePlaceHolder.STRING;
   }
+
+  @ScalarFunction(names = {"array", "arrayValueConstructor"}, isVarArg = true)
+  public static Object arrayValueConstructor(Object... arr) {
+    if (arr.length == 0) {
+      return arr;
+    }
+    Class<?> clazz = arr[0].getClass();
+    if (clazz == Integer.class) {
+      int[] intArr = new int[arr.length];
+      for (int i = 0; i < arr.length; i++) {
+        intArr[i] = (Integer) arr[i];
+      }
+      return intArr;
+    }
+    if (clazz == Long.class) {
+      long[] longArr = new long[arr.length];
+      for (int i = 0; i < arr.length; i++) {
+        longArr[i] = (Long) arr[i];
+      }
+      return longArr;
+    }
+    if (clazz == Float.class) {
+      float[] floatArr = new float[arr.length];
+      for (int i = 0; i < arr.length; i++) {
+        floatArr[i] = (Float) arr[i];
+      }
+      return floatArr;
+    }
+    if (clazz == Double.class) {
+      double[] doubleArr = new double[arr.length];
+      for (int i = 0; i < arr.length; i++) {
+        doubleArr[i] = (Double) arr[i];
+      }
+      return doubleArr;
+    }
+    if (clazz == Boolean.class) {
+      boolean[] boolArr = new boolean[arr.length];
+      for (int i = 0; i < arr.length; i++) {
+        boolArr[i] = (Boolean) arr[i];
+      }
+      return boolArr;
+    }
+    if (clazz == BigDecimal.class) {
+      BigDecimal[] bigDecimalArr = new BigDecimal[arr.length];
+      for (int i = 0; i < arr.length; i++) {
+        bigDecimalArr[i] = (BigDecimal) arr[i];
+      }
+      return bigDecimalArr;
+    }
+    if (clazz == String.class) {
+      String[] strArr = new String[arr.length];
+      for (int i = 0; i < arr.length; i++) {
+        strArr[i] = (String) arr[i];
+      }
+      return strArr;
+    }
+    return arr;
+  }
 }
diff --git 
a/pinot-common/src/main/java/org/apache/pinot/common/request/context/LiteralContext.java
 
b/pinot-common/src/main/java/org/apache/pinot/common/request/context/LiteralContext.java
index 802ec48a17..5e80a4f6de 100644
--- 
a/pinot-common/src/main/java/org/apache/pinot/common/request/context/LiteralContext.java
+++ 
b/pinot-common/src/main/java/org/apache/pinot/common/request/context/LiteralContext.java
@@ -102,6 +102,26 @@ public class LiteralContext {
         _type = DataType.BYTES;
         _value = literal.getBinaryValue();
         break;
+      case INT_ARRAY_VALUE:
+        _type = DataType.INT;
+        _value = literal.getIntArrayValue();
+        break;
+      case LONG_ARRAY_VALUE:
+        _type = DataType.LONG;
+        _value = literal.getLongArrayValue();
+        break;
+      case FLOAT_ARRAY_VALUE:
+        _type = DataType.FLOAT;
+        _value = literal.getFloatArrayValue();
+        break;
+      case DOUBLE_ARRAY_VALUE:
+        _type = DataType.DOUBLE;
+        _value = literal.getDoubleArrayValue();
+        break;
+      case STRING_ARRAY_VALUE:
+        _type = DataType.STRING;
+        _value = literal.getStringArrayValue();
+        break;
       case NULL_VALUE:
         _type = DataType.UNKNOWN;
         _value = null;
diff --git 
a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/CompileTimeFunctionsInvoker.java
 
b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/CompileTimeFunctionsInvoker.java
index 2cb89d5fe8..92bc6bb1b0 100644
--- 
a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/CompileTimeFunctionsInvoker.java
+++ 
b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/CompileTimeFunctionsInvoker.java
@@ -84,8 +84,13 @@ public class CompileTimeFunctionsInvoker implements 
QueryRewriter {
         }
         try {
           FunctionInvoker invoker = new FunctionInvoker(functionInfo);
-          invoker.convertTypes(arguments);
-          Object result = invoker.invoke(arguments);
+          Object result;
+          if (invoker.getMethod().isVarArgs()) {
+            result = invoker.invoke(new Object[] {arguments});
+          } else {
+            invoker.convertTypes(arguments);
+            result = invoker.invoke(arguments);
+          }
           return RequestUtils.getLiteralExpression(result);
         } catch (Exception e) {
           throw new SqlCompilationException(
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunction.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunction.java
index 7619b53a3e..b2065e20d3 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunction.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunction.java
@@ -24,6 +24,7 @@ import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
 import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.common.request.context.LiteralContext;
 import org.apache.pinot.core.operator.ColumnContext;
 import org.apache.pinot.core.operator.blocks.ValueBlock;
 import org.apache.pinot.core.operator.transform.TransformResultMetadata;
@@ -54,6 +55,77 @@ public class ArrayLiteralTransformFunction implements 
TransformFunction {
   private double[][] _doubleArrayResult;
   private String[][] _stringArrayResult;
 
+  public ArrayLiteralTransformFunction(LiteralContext literalContext) {
+    List literalArray = (List) literalContext.getValue();
+    Preconditions.checkNotNull(literalArray);
+    if (literalArray.isEmpty()) {
+      _dataType = DataType.UNKNOWN;
+      _intArrayLiteral = new int[0];
+      _longArrayLiteral = new long[0];
+      _floatArrayLiteral = new float[0];
+      _doubleArrayLiteral = new double[0];
+      _stringArrayLiteral = new String[0];
+      return;
+    }
+    _dataType = literalContext.getType();
+    switch (_dataType) {
+      case INT:
+        _intArrayLiteral = new int[literalArray.size()];
+        for (int i = 0; i < _intArrayLiteral.length; i++) {
+          _intArrayLiteral[i] = (int) literalArray.get(i);
+        }
+        _longArrayLiteral = null;
+        _floatArrayLiteral = null;
+        _doubleArrayLiteral = null;
+        _stringArrayLiteral = null;
+        break;
+      case LONG:
+        _longArrayLiteral = new long[literalArray.size()];
+        for (int i = 0; i < _longArrayLiteral.length; i++) {
+          _longArrayLiteral[i] = (long) literalArray.get(i);
+        }
+        _intArrayLiteral = null;
+        _floatArrayLiteral = null;
+        _doubleArrayLiteral = null;
+        _stringArrayLiteral = null;
+        break;
+      case FLOAT:
+        _floatArrayLiteral = new float[literalArray.size()];
+        for (int i = 0; i < _floatArrayLiteral.length; i++) {
+          _floatArrayLiteral[i] = (float) literalArray.get(i);
+        }
+        _intArrayLiteral = null;
+        _longArrayLiteral = null;
+        _doubleArrayLiteral = null;
+        _stringArrayLiteral = null;
+        break;
+      case DOUBLE:
+        _doubleArrayLiteral = new double[literalArray.size()];
+        for (int i = 0; i < _doubleArrayLiteral.length; i++) {
+          _doubleArrayLiteral[i] = (double) literalArray.get(i);
+        }
+        _intArrayLiteral = null;
+        _longArrayLiteral = null;
+        _floatArrayLiteral = null;
+        _stringArrayLiteral = null;
+        break;
+      case STRING:
+        _stringArrayLiteral = new String[literalArray.size()];
+        for (int i = 0; i < _stringArrayLiteral.length; i++) {
+          _stringArrayLiteral[i] = (String) literalArray.get(i);
+        }
+        _intArrayLiteral = null;
+        _longArrayLiteral = null;
+        _floatArrayLiteral = null;
+        _doubleArrayLiteral = null;
+        break;
+      default:
+        throw new IllegalStateException(
+            "Illegal data type for ArrayLiteralTransformFunction: " + 
_dataType + ", literal contexts: "
+                + Arrays.toString(literalArray.toArray()));
+    }
+  }
+
   public ArrayLiteralTransformFunction(List<ExpressionContext> 
literalContexts) {
     Preconditions.checkNotNull(literalContexts);
     if (literalContexts.isEmpty()) {
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
index 49541841ca..82afb6dbeb 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
@@ -31,6 +31,7 @@ import org.apache.pinot.common.function.FunctionRegistry;
 import org.apache.pinot.common.function.TransformFunctionType;
 import org.apache.pinot.common.request.context.ExpressionContext;
 import org.apache.pinot.common.request.context.FunctionContext;
+import org.apache.pinot.common.request.context.LiteralContext;
 import org.apache.pinot.common.utils.HashUtil;
 import org.apache.pinot.core.geospatial.transform.function.GeoToH3Function;
 import org.apache.pinot.core.geospatial.transform.function.StAreaFunction;
@@ -335,7 +336,12 @@ public class TransformFunctionFactory {
         String columnName = expression.getIdentifier();
         return new IdentifierTransformFunction(columnName, 
columnContextMap.get(columnName));
       case LITERAL:
-        return 
queryContext.getOrComputeSharedValue(LiteralTransformFunction.class, 
expression.getLiteral(),
+        LiteralContext literal = expression.getLiteral();
+        if (literal.getValue() != null && literal.getValue() instanceof 
ArrayList) {
+          return 
queryContext.getOrComputeSharedValue(ArrayLiteralTransformFunction.class, 
literal,
+              ArrayLiteralTransformFunction::new);
+        }
+        return 
queryContext.getOrComputeSharedValue(LiteralTransformFunction.class, literal,
             LiteralTransformFunction::new);
       default:
         throw new IllegalStateException();
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/HistogramAggregationFunction.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/HistogramAggregationFunction.java
index 78bf60eca9..078420bd60 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/HistogramAggregationFunction.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/HistogramAggregationFunction.java
@@ -59,17 +59,25 @@ public class HistogramAggregationFunction extends 
BaseSingleInputAggregationFunc
     if (numArguments == 2) {
       ExpressionContext arrayExpression = arguments.get(1);
       Preconditions.checkArgument(
-          (arrayExpression.getType() == ExpressionContext.Type.FUNCTION) && 
(arrayExpression.getFunction()
-              .getFunctionName().equals(ARRAY_CONSTRUCTOR)),
+          // ARRAY function
+          ((arrayExpression.getType() == ExpressionContext.Type.FUNCTION)
+              && 
(arrayExpression.getFunction().getFunctionName().equals(ARRAY_CONSTRUCTOR)))
+              || ((arrayExpression.getType() == ExpressionContext.Type.LITERAL)
+              && (arrayExpression.getLiteral().getValue() instanceof List)),
           "Please use the format of `Histogram(columnName, ARRAY[1,10,100])` 
to specify the bin edges");
-      _bucketEdges = parseVector(arrayExpression.getFunction().getArguments());
+      if (arrayExpression.getType() == ExpressionContext.Type.FUNCTION) {
+        _bucketEdges = 
parseVector(arrayExpression.getFunction().getArguments());
+      } else {
+        _bucketEdges = parseVectorLiteral((List) 
arrayExpression.getLiteral().getValue());
+      }
       _lower = _bucketEdges[0];
       _upper = _bucketEdges[_bucketEdges.length - 1];
     } else {
       _isEqualLength = true;
       _lower = arguments.get(1).getLiteral().getDoubleValue();
       _upper = arguments.get(2).getLiteral().getDoubleValue();
-      int numBins = arguments.get(3).getLiteral().getIntValue();;
+      int numBins = arguments.get(3).getLiteral().getIntValue();
+      ;
       Preconditions.checkArgument(_upper > _lower,
           "The right most edge must be greater than left most edge, given %s 
and %s", _lower, _upper);
       Preconditions.checkArgument(numBins > 0, "The number of bins must be 
greater than zero, given %s", numBins);
@@ -109,8 +117,23 @@ public class HistogramAggregationFunction extends 
BaseSingleInputAggregationFunc
     return ret;
   }
 
+  private double[] parseVectorLiteral(List arrayStr) {
+    int len = arrayStr.size();
+    Preconditions.checkArgument(len > 1, "The number of bin edges must be 
greater than 1");
+    double[] ret = new double[len];
+    for (int i = 0; i < len; i++) {
+      // TODO: Represent infinity as literal instead of identifier
+      ret[i] = Double.parseDouble(arrayStr.get(i).toString());
+      if (i > 0) {
+        Preconditions.checkState(ret[i] > ret[i - 1], "The bin edges must be 
strictly increasing");
+      }
+    }
+    return ret;
+  }
+
   /**
    * Find the bin id for the input value. Use division for equal-length bins, 
and binary search otherwise.
+   *
    * @param val input value
    * @return bin id
    */
@@ -135,7 +158,7 @@ public class HistogramAggregationFunction extends 
BaseSingleInputAggregationFunc
           i = mid;
         }
       }
-     id = i;
+      id = i;
     }
     return id;
   }
diff --git 
a/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java
 
b/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java
index 82be9bcf52..07ba2e00c7 100644
--- 
a/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java
+++ 
b/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java
@@ -131,7 +131,7 @@ public class InbuiltFunctionEvaluatorTest {
       throws Exception {
     MyFunc myFunc = new MyFunc();
     Method method = 
myFunc.getClass().getDeclaredMethod("appendToStringAndReturn", String.class);
-    FunctionRegistry.registerFunction(method, false, false);
+    FunctionRegistry.registerFunction(method, false, false, false);
     String expression = "appendToStringAndReturn('test ')";
     InbuiltFunctionEvaluator evaluator = new 
InbuiltFunctionEvaluator(expression);
     assertTrue(evaluator.getArguments().isEmpty());
diff --git 
a/pinot-core/src/test/java/org/apache/pinot/queries/HistogramQueriesTest.java 
b/pinot-core/src/test/java/org/apache/pinot/queries/HistogramQueriesTest.java
index 9fe29ec5d0..c400546c4f 100644
--- 
a/pinot-core/src/test/java/org/apache/pinot/queries/HistogramQueriesTest.java
+++ 
b/pinot-core/src/test/java/org/apache/pinot/queries/HistogramQueriesTest.java
@@ -312,7 +312,7 @@ public class HistogramQueriesTest extends BaseQueriesTest {
       operator.nextBlock();
     } catch (Exception e) {
       assertEquals(e.getMessage(),
-          "Invalid aggregation function: 
histogram(intColumn,arrayvalueconstructor('0')); Reason: The number of "
+          "Invalid aggregation function: histogram(intColumn,'[0]'); Reason: 
The number of "
               + "bin edges must be greater than 1");
     }
 
@@ -333,7 +333,7 @@ public class HistogramQueriesTest extends BaseQueriesTest {
       operator.nextBlock();
     } catch (Exception e) {
       assertEquals(e.getMessage(),
-          "Invalid aggregation function: 
histogram(intColumn,arrayvalueconstructor('0','0','1','2')); Reason: The "
+          "Invalid aggregation function: histogram(intColumn,'[0, 0, 1, 2]'); 
Reason: The "
               + "bin edges must be strictly increasing");
     }
 
diff --git 
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/ArrayTest.java
 
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/ArrayTest.java
index 9275d3ce90..19bf45b373 100644
--- 
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/ArrayTest.java
+++ 
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/ArrayTest.java
@@ -217,6 +217,22 @@ public class ArrayTest extends 
CustomDataQueryClusterIntegrationTest {
     Assert.assertEquals(row.get(0).get(2).asInt(), 3);
   }
 
+  @Test(dataProvider = "useBothQueryEngines")
+  public void testIntArrayLiteralWithoutFrom(boolean useMultiStageQueryEngine)
+      throws Exception {
+    setUseMultiStageQueryEngine(useMultiStageQueryEngine);
+    String query = "SELECT ARRAY[1,2,3] ";
+    JsonNode jsonNode = postQuery(query);
+    JsonNode rows = jsonNode.get("resultTable").get("rows");
+    Assert.assertEquals(rows.size(), 1);
+    JsonNode row = rows.get(0);
+    Assert.assertEquals(row.size(), 1);
+    Assert.assertEquals(row.get(0).size(), 3);
+    Assert.assertEquals(row.get(0).get(0).asInt(), 1);
+    Assert.assertEquals(row.get(0).get(1).asInt(), 2);
+    Assert.assertEquals(row.get(0).get(2).asInt(), 3);
+  }
+
   @Test(dataProvider = "useBothQueryEngines")
   public void testLongArrayLiteral(boolean useMultiStageQueryEngine)
       throws Exception {
@@ -236,6 +252,22 @@ public class ArrayTest extends 
CustomDataQueryClusterIntegrationTest {
     Assert.assertEquals(row.get(0).get(2).asLong(), 2147483650L);
   }
 
+  @Test(dataProvider = "useBothQueryEngines")
+  public void testLongArrayLiteralWithoutFrom(boolean useMultiStageQueryEngine)
+      throws Exception {
+    setUseMultiStageQueryEngine(useMultiStageQueryEngine);
+    String query = "SELECT ARRAY[2147483648,2147483649,2147483650]";
+    JsonNode jsonNode = postQuery(query);
+    JsonNode rows = jsonNode.get("resultTable").get("rows");
+    Assert.assertEquals(rows.size(), 1);
+    JsonNode row = rows.get(0);
+    Assert.assertEquals(row.size(), 1);
+    Assert.assertEquals(row.get(0).size(), 3);
+    Assert.assertEquals(row.get(0).get(0).asLong(), 2147483648L);
+    Assert.assertEquals(row.get(0).get(1).asLong(), 2147483649L);
+    Assert.assertEquals(row.get(0).get(2).asLong(), 2147483650L);
+  }
+
   @Test(dataProvider = "useBothQueryEngines")
   public void testFloatArrayLiteral(boolean useMultiStageQueryEngine)
       throws Exception {
@@ -255,6 +287,22 @@ public class ArrayTest extends 
CustomDataQueryClusterIntegrationTest {
     Assert.assertEquals(row.get(0).get(2).asDouble(), 0.3);
   }
 
+  @Test(dataProvider = "useBothQueryEngines")
+  public void testFloatArrayLiteralWithoutFrom(boolean 
useMultiStageQueryEngine)
+      throws Exception {
+    setUseMultiStageQueryEngine(useMultiStageQueryEngine);
+    String query = "SELECT ARRAY[0.1, 0.2, 0.3]";
+    JsonNode jsonNode = postQuery(query);
+    JsonNode rows = jsonNode.get("resultTable").get("rows");
+    Assert.assertEquals(rows.size(), 1);
+    JsonNode row = rows.get(0);
+    Assert.assertEquals(row.size(), 1);
+    Assert.assertEquals(row.get(0).size(), 3);
+    Assert.assertEquals(row.get(0).get(0).asDouble(), 0.1);
+    Assert.assertEquals(row.get(0).get(1).asDouble(), 0.2);
+    Assert.assertEquals(row.get(0).get(2).asDouble(), 0.3);
+  }
+
   @Test(dataProvider = "useBothQueryEngines")
   public void testDoubleArrayLiteral(boolean useMultiStageQueryEngine)
       throws Exception {
@@ -274,6 +322,22 @@ public class ArrayTest extends 
CustomDataQueryClusterIntegrationTest {
     Assert.assertEquals(row.get(0).get(2).asDouble(), 0.3);
   }
 
+  @Test(dataProvider = "useBothQueryEngines")
+  public void testDoubleArrayLiteralWithoutFrom(boolean 
useMultiStageQueryEngine)
+      throws Exception {
+    setUseMultiStageQueryEngine(useMultiStageQueryEngine);
+    String query = "SELECT ARRAY[CAST(0.1 AS DOUBLE), CAST(0.2 AS DOUBLE), 
CAST(0.3 AS DOUBLE)]";
+    JsonNode jsonNode = postQuery(query);
+    JsonNode rows = jsonNode.get("resultTable").get("rows");
+    Assert.assertEquals(rows.size(), 1);
+    JsonNode row = rows.get(0);
+    Assert.assertEquals(row.size(), 1);
+    Assert.assertEquals(row.get(0).size(), 3);
+    Assert.assertEquals(row.get(0).get(0).asDouble(), 0.1);
+    Assert.assertEquals(row.get(0).get(1).asDouble(), 0.2);
+    Assert.assertEquals(row.get(0).get(2).asDouble(), 0.3);
+  }
+
   @Test(dataProvider = "useBothQueryEngines")
   public void testStringArrayLiteral(boolean useMultiStageQueryEngine)
       throws Exception {
@@ -293,6 +357,22 @@ public class ArrayTest extends 
CustomDataQueryClusterIntegrationTest {
     Assert.assertEquals(row.get(0).get(2).asText(), "ccc");
   }
 
+  @Test(dataProvider = "useBothQueryEngines")
+  public void testStringArrayLiteralWithoutFrom(boolean 
useMultiStageQueryEngine)
+      throws Exception {
+    setUseMultiStageQueryEngine(useMultiStageQueryEngine);
+    String query = "SELECT ARRAY['a', 'bb', 'ccc']";
+    JsonNode jsonNode = postQuery(query);
+    JsonNode rows = jsonNode.get("resultTable").get("rows");
+    Assert.assertEquals(rows.size(), 1);
+    JsonNode row = rows.get(0);
+    Assert.assertEquals(row.size(), 1);
+    Assert.assertEquals(row.get(0).size(), 3);
+    Assert.assertEquals(row.get(0).get(0).asText(), "a");
+    Assert.assertEquals(row.get(0).get(1).asText(), "bb");
+    Assert.assertEquals(row.get(0).get(2).asText(), "ccc");
+  }
+
   @Override
   public String getTableName() {
     return DEFAULT_TABLE_NAME;
@@ -352,7 +432,7 @@ public class ArrayTest extends 
CustomDataQueryClusterIntegrationTest {
         fileWriter.append(recordCache.get((int) (i % (getCountStarResult() / 
10)), () -> {
               // create avro record
               GenericData.Record record = new GenericData.Record(avroSchema);
-              record.put(BOOLEAN_COLUMN, RANDOM.nextBoolean());
+              record.put(BOOLEAN_COLUMN, finalI % 4 == 0 || finalI % 4 == 1);
               record.put(INT_COLUMN, finalI);
               record.put(LONG_COLUMN, finalI);
               record.put(FLOAT_COLUMN, finalI + RANDOM.nextFloat());
diff --git 
a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotEvaluateLiteralRule.java
 
b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotEvaluateLiteralRule.java
index 02d44ade9b..ea0d531faa 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotEvaluateLiteralRule.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotEvaluateLiteralRule.java
@@ -159,8 +159,12 @@ public class PinotEvaluateLiteralRule {
     Object resultValue;
     try {
       FunctionInvoker invoker = new FunctionInvoker(functionInfo);
-      invoker.convertTypes(arguments);
-      resultValue = invoker.invoke(arguments);
+      if (functionInfo.getMethod().isVarArgs()) {
+        resultValue = invoker.invoke(new Object[] {arguments});
+      } else {
+        invoker.convertTypes(arguments);
+        resultValue = invoker.invoke(arguments);
+      }
       if (rexNodeType.getSqlTypeName() == SqlTypeName.ARRAY) {
         RelDataType componentType = rexNodeType.getComponentType();
         if (componentType != null) {
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java
index 9e777d7867..cccc065be3 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java
@@ -50,11 +50,13 @@ public class FunctionOperand implements TransformOperand {
     FunctionInfo functionInfo = 
FunctionRegistry.getFunctionInfo(canonicalName, numOperands);
     Preconditions.checkState(functionInfo != null, "Cannot find function with 
name: %s", canonicalName);
     _functionInvoker = new FunctionInvoker(functionInfo);
-    Class<?>[] parameterClasses = _functionInvoker.getParameterClasses();
-    PinotDataType[] parameterTypes = _functionInvoker.getParameterTypes();
-    for (int i = 0; i < numOperands; i++) {
-      Preconditions.checkState(parameterTypes[i] != null, "Unsupported 
parameter class: %s for method: %s",
-          parameterClasses[i], functionInfo.getMethod());
+    if (!_functionInvoker.getMethod().isVarArgs()) {
+      Class<?>[] parameterClasses = _functionInvoker.getParameterClasses();
+      PinotDataType[] parameterTypes = _functionInvoker.getParameterTypes();
+      for (int i = 0; i < numOperands; i++) {
+        Preconditions.checkState(parameterTypes[i] != null, "Unsupported 
parameter class: %s for method: %s",
+            parameterClasses[i], functionInfo.getMethod());
+      }
     }
     ColumnDataType functionInvokerResultType = 
FunctionUtils.getColumnDataType(_functionInvoker.getResultClass());
     // Handle unrecognized result class with STRING
@@ -80,8 +82,13 @@ public class FunctionOperand implements TransformOperand {
       _reusableOperandHolder[i] = value != null ? 
operand.getResultType().toExternal(value) : null;
     }
     // TODO: Optimize per record conversion
-    _functionInvoker.convertTypes(_reusableOperandHolder);
-    Object result = _functionInvoker.invoke(_reusableOperandHolder);
+    Object result;
+    if (_functionInvoker.getMethod().isVarArgs()) {
+      result = _functionInvoker.invoke(new Object[]{_reusableOperandHolder});
+    } else {
+      _functionInvoker.convertTypes(_reusableOperandHolder);
+      result = _functionInvoker.invoke(_reusableOperandHolder);
+    }
     return result != null ? 
TypeUtils.convert(_functionInvokerResultType.toInternal(result),
         _resultType.getStoredType()) : null;
   }
diff --git 
a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/function/InbuiltFunctionEvaluator.java
 
b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/function/InbuiltFunctionEvaluator.java
index 823dd23b88..8c3909e784 100644
--- 
a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/function/InbuiltFunctionEvaluator.java
+++ 
b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/function/InbuiltFunctionEvaluator.java
@@ -233,6 +233,9 @@ public class InbuiltFunctionEvaluator implements 
FunctionEvaluator {
             }
           }
         }
+        if (_functionInvoker.getMethod().isVarArgs()) {
+          return _functionInvoker.invoke(new Object[]{_arguments});
+        }
         _functionInvoker.convertTypes(_arguments);
         return _functionInvoker.invoke(_arguments);
       } catch (Exception e) {
@@ -256,6 +259,9 @@ public class InbuiltFunctionEvaluator implements 
FunctionEvaluator {
             }
           }
         }
+        if (_functionInvoker.getMethod().isVarArgs()) {
+          return _functionInvoker.invoke(new Object[]{_arguments});
+        }
         _functionInvoker.convertTypes(_arguments);
         return _functionInvoker.invoke(_arguments);
       } catch (Exception e) {
diff --git 
a/pinot-spi/src/main/java/org/apache/pinot/spi/annotations/ScalarFunction.java 
b/pinot-spi/src/main/java/org/apache/pinot/spi/annotations/ScalarFunction.java
index 46a743d52c..0a647a8792 100644
--- 
a/pinot-spi/src/main/java/org/apache/pinot/spi/annotations/ScalarFunction.java
+++ 
b/pinot-spi/src/main/java/org/apache/pinot/spi/annotations/ScalarFunction.java
@@ -59,4 +59,9 @@ public @interface ScalarFunction {
   boolean nullableParameters() default false;
 
   boolean isPlaceholder() default false;
+
+  /**
+   * Whether the scalar function takes a variable number of arguments.
+   */
+  boolean isVarArg() default false;
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to