This is an automated email from the ASF dual-hosted git repository.
rongr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new 6bb387a10a Support array gen in literal evaluation (#12278)
6bb387a10a is described below
commit 6bb387a10a311af8ae166e93b5ba629fdb8cc4b8
Author: Xiang Fu <[email protected]>
AuthorDate: Thu Jan 18 07:53:22 2024 -0800
Support array gen in literal evaluation (#12278)
---
.../pinot/common/function/FunctionRegistry.java | 46 ++++++++----
.../common/function/scalar/ArrayFunctions.java | 59 ++++++++++++++++
.../common/request/context/LiteralContext.java | 20 ++++++
.../rewriter/CompileTimeFunctionsInvoker.java | 9 ++-
.../function/ArrayLiteralTransformFunction.java | 72 +++++++++++++++++++
.../function/TransformFunctionFactory.java | 8 ++-
.../function/HistogramAggregationFunction.java | 33 +++++++--
.../function/InbuiltFunctionEvaluatorTest.java | 2 +-
.../apache/pinot/queries/HistogramQueriesTest.java | 4 +-
.../pinot/integration/tests/custom/ArrayTest.java | 82 +++++++++++++++++++++-
.../rel/rules/PinotEvaluateLiteralRule.java | 8 ++-
.../runtime/operator/operands/FunctionOperand.java | 21 ++++--
.../local/function/InbuiltFunctionEvaluator.java | 6 ++
.../pinot/spi/annotations/ScalarFunction.java | 5 ++
14 files changed, 341 insertions(+), 34 deletions(-)
diff --git
a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java
b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java
index 97fa972bee..deb1673d8b 100644
---
a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java
+++
b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java
@@ -52,6 +52,8 @@ public class FunctionRegistry {
// This FUNCTION_MAP is used by Calcite function catalog to look up function
by function signature.
private static final NameMultimap<Function> FUNCTION_MAP = new
NameMultimap<>();
+ private static final int VAR_ARG_KEY = -1;
+
/**
* Registers the scalar functions via reflection.
* NOTE: In order to plugin methods using reflection, the methods should be
inside a class that includes ".function."
@@ -69,12 +71,14 @@ public class FunctionRegistry {
// Annotated function names
String[] scalarFunctionNames = scalarFunction.names();
boolean nullableParameters = scalarFunction.nullableParameters();
+ boolean isPlaceholder = scalarFunction.isPlaceholder();
+ boolean isVarArg = scalarFunction.isVarArg();
if (scalarFunctionNames.length > 0) {
for (String name : scalarFunctionNames) {
- FunctionRegistry.registerFunction(name, method,
nullableParameters, scalarFunction.isPlaceholder());
+ FunctionRegistry.registerFunction(name, method,
nullableParameters, isPlaceholder, isVarArg);
}
} else {
- FunctionRegistry.registerFunction(method, nullableParameters,
scalarFunction.isPlaceholder());
+ FunctionRegistry.registerFunction(method, nullableParameters,
isPlaceholder, isVarArg);
}
}
}
@@ -93,31 +97,40 @@ public class FunctionRegistry {
/**
* Registers a method with the name of the method.
*/
- public static void registerFunction(Method method, boolean
nullableParameters, boolean isPlaceholder) {
- registerFunction(method.getName(), method, nullableParameters,
isPlaceholder);
+ public static void registerFunction(Method method, boolean
nullableParameters, boolean isPlaceholder,
+ boolean isVarArg) {
+ registerFunction(method.getName(), method, nullableParameters,
isPlaceholder, isVarArg);
}
/**
* Registers a method with the given function name.
*/
public static void registerFunction(String functionName, Method method,
boolean nullableParameters,
- boolean isPlaceholder) {
+ boolean isPlaceholder, boolean isVarArg) {
if (!isPlaceholder) {
- registerFunctionInfoMap(functionName, method, nullableParameters);
+ registerFunctionInfoMap(functionName, method, nullableParameters,
isVarArg);
}
- registerCalciteNamedFunctionMap(functionName, method, nullableParameters);
+ registerCalciteNamedFunctionMap(functionName, method, nullableParameters,
isVarArg);
}
- private static void registerFunctionInfoMap(String functionName, Method
method, boolean nullableParameters) {
+ private static void registerFunctionInfoMap(String functionName, Method
method, boolean nullableParameters,
+ boolean isVarArg) {
FunctionInfo functionInfo = new FunctionInfo(method,
method.getDeclaringClass(), nullableParameters);
String canonicalName = canonicalize(functionName);
Map<Integer, FunctionInfo> functionInfoMap =
FUNCTION_INFO_MAP.computeIfAbsent(canonicalName, k -> new HashMap<>());
- FunctionInfo existFunctionInfo =
functionInfoMap.put(method.getParameterCount(), functionInfo);
- Preconditions.checkState(existFunctionInfo == null ||
existFunctionInfo.getMethod() == functionInfo.getMethod(),
- "Function: %s with %s parameters is already registered", functionName,
method.getParameterCount());
+ if (isVarArg) {
+ FunctionInfo existFunctionInfo = functionInfoMap.put(VAR_ARG_KEY,
functionInfo);
+ Preconditions.checkState(existFunctionInfo == null ||
existFunctionInfo.getMethod() == functionInfo.getMethod(),
+ "Function: %s with variable number of parameters is already
registered", functionName);
+ } else {
+ FunctionInfo existFunctionInfo =
functionInfoMap.put(method.getParameterCount(), functionInfo);
+ Preconditions.checkState(existFunctionInfo == null ||
existFunctionInfo.getMethod() == functionInfo.getMethod(),
+ "Function: %s with %s parameters is already registered",
functionName, method.getParameterCount());
+ }
}
- private static void registerCalciteNamedFunctionMap(String functionName,
Method method, boolean nullableParameters) {
+ private static void registerCalciteNamedFunctionMap(String functionName,
Method method, boolean nullableParameters,
+ boolean isVarArg) {
if (method.getAnnotation(Deprecated.class) == null) {
FUNCTION_MAP.put(functionName, ScalarFunctionImpl.create(method));
}
@@ -146,7 +159,14 @@ public class FunctionRegistry {
@Nullable
public static FunctionInfo getFunctionInfo(String functionName, int
numParameters) {
Map<Integer, FunctionInfo> functionInfoMap =
FUNCTION_INFO_MAP.get(canonicalize(functionName));
- return functionInfoMap != null ? functionInfoMap.get(numParameters) : null;
+ if (functionInfoMap != null) {
+ FunctionInfo functionInfo = functionInfoMap.get(numParameters);
+ if (functionInfo != null) {
+ return functionInfo;
+ }
+ return functionInfoMap.get(VAR_ARG_KEY);
+ }
+ return null;
}
private static String canonicalize(String functionName) {
diff --git
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java
index d8529e9842..32f115b51a 100644
---
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java
+++
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java
@@ -23,6 +23,7 @@ import it.unimi.dsi.fastutil.ints.IntLinkedOpenHashSet;
import it.unimi.dsi.fastutil.ints.IntSet;
import it.unimi.dsi.fastutil.objects.ObjectLinkedOpenHashSet;
import it.unimi.dsi.fastutil.objects.ObjectSet;
+import java.math.BigDecimal;
import java.util.Arrays;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.pinot.spi.annotations.ScalarFunction;
@@ -226,4 +227,62 @@ public class ArrayFunctions {
public static String arrayElementAtString(String[] arr, int idx) {
return idx > 0 && idx <= arr.length ? arr[idx - 1] :
NullValuePlaceHolder.STRING;
}
+
+ @ScalarFunction(names = {"array", "arrayValueConstructor"}, isVarArg = true)
+ public static Object arrayValueConstructor(Object... arr) {
+ if (arr.length == 0) {
+ return arr;
+ }
+ Class<?> clazz = arr[0].getClass();
+ if (clazz == Integer.class) {
+ int[] intArr = new int[arr.length];
+ for (int i = 0; i < arr.length; i++) {
+ intArr[i] = (Integer) arr[i];
+ }
+ return intArr;
+ }
+ if (clazz == Long.class) {
+ long[] longArr = new long[arr.length];
+ for (int i = 0; i < arr.length; i++) {
+ longArr[i] = (Long) arr[i];
+ }
+ return longArr;
+ }
+ if (clazz == Float.class) {
+ float[] floatArr = new float[arr.length];
+ for (int i = 0; i < arr.length; i++) {
+ floatArr[i] = (Float) arr[i];
+ }
+ return floatArr;
+ }
+ if (clazz == Double.class) {
+ double[] doubleArr = new double[arr.length];
+ for (int i = 0; i < arr.length; i++) {
+ doubleArr[i] = (Double) arr[i];
+ }
+ return doubleArr;
+ }
+ if (clazz == Boolean.class) {
+ boolean[] boolArr = new boolean[arr.length];
+ for (int i = 0; i < arr.length; i++) {
+ boolArr[i] = (Boolean) arr[i];
+ }
+ return boolArr;
+ }
+ if (clazz == BigDecimal.class) {
+ BigDecimal[] bigDecimalArr = new BigDecimal[arr.length];
+ for (int i = 0; i < arr.length; i++) {
+ bigDecimalArr[i] = (BigDecimal) arr[i];
+ }
+ return bigDecimalArr;
+ }
+ if (clazz == String.class) {
+ String[] strArr = new String[arr.length];
+ for (int i = 0; i < arr.length; i++) {
+ strArr[i] = (String) arr[i];
+ }
+ return strArr;
+ }
+ return arr;
+ }
}
diff --git
a/pinot-common/src/main/java/org/apache/pinot/common/request/context/LiteralContext.java
b/pinot-common/src/main/java/org/apache/pinot/common/request/context/LiteralContext.java
index 802ec48a17..5e80a4f6de 100644
---
a/pinot-common/src/main/java/org/apache/pinot/common/request/context/LiteralContext.java
+++
b/pinot-common/src/main/java/org/apache/pinot/common/request/context/LiteralContext.java
@@ -102,6 +102,26 @@ public class LiteralContext {
_type = DataType.BYTES;
_value = literal.getBinaryValue();
break;
+ case INT_ARRAY_VALUE:
+ _type = DataType.INT;
+ _value = literal.getIntArrayValue();
+ break;
+ case LONG_ARRAY_VALUE:
+ _type = DataType.LONG;
+ _value = literal.getLongArrayValue();
+ break;
+ case FLOAT_ARRAY_VALUE:
+ _type = DataType.FLOAT;
+ _value = literal.getFloatArrayValue();
+ break;
+ case DOUBLE_ARRAY_VALUE:
+ _type = DataType.DOUBLE;
+ _value = literal.getDoubleArrayValue();
+ break;
+ case STRING_ARRAY_VALUE:
+ _type = DataType.STRING;
+ _value = literal.getStringArrayValue();
+ break;
case NULL_VALUE:
_type = DataType.UNKNOWN;
_value = null;
diff --git
a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/CompileTimeFunctionsInvoker.java
b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/CompileTimeFunctionsInvoker.java
index 2cb89d5fe8..92bc6bb1b0 100644
---
a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/CompileTimeFunctionsInvoker.java
+++
b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/CompileTimeFunctionsInvoker.java
@@ -84,8 +84,13 @@ public class CompileTimeFunctionsInvoker implements
QueryRewriter {
}
try {
FunctionInvoker invoker = new FunctionInvoker(functionInfo);
- invoker.convertTypes(arguments);
- Object result = invoker.invoke(arguments);
+ Object result;
+ if (invoker.getMethod().isVarArgs()) {
+ result = invoker.invoke(new Object[] {arguments});
+ } else {
+ invoker.convertTypes(arguments);
+ result = invoker.invoke(arguments);
+ }
return RequestUtils.getLiteralExpression(result);
} catch (Exception e) {
throw new SqlCompilationException(
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunction.java
b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunction.java
index 7619b53a3e..b2065e20d3 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunction.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunction.java
@@ -24,6 +24,7 @@ import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.common.request.context.LiteralContext;
import org.apache.pinot.core.operator.ColumnContext;
import org.apache.pinot.core.operator.blocks.ValueBlock;
import org.apache.pinot.core.operator.transform.TransformResultMetadata;
@@ -54,6 +55,77 @@ public class ArrayLiteralTransformFunction implements
TransformFunction {
private double[][] _doubleArrayResult;
private String[][] _stringArrayResult;
+ public ArrayLiteralTransformFunction(LiteralContext literalContext) {
+ List literalArray = (List) literalContext.getValue();
+ Preconditions.checkNotNull(literalArray);
+ if (literalArray.isEmpty()) {
+ _dataType = DataType.UNKNOWN;
+ _intArrayLiteral = new int[0];
+ _longArrayLiteral = new long[0];
+ _floatArrayLiteral = new float[0];
+ _doubleArrayLiteral = new double[0];
+ _stringArrayLiteral = new String[0];
+ return;
+ }
+ _dataType = literalContext.getType();
+ switch (_dataType) {
+ case INT:
+ _intArrayLiteral = new int[literalArray.size()];
+ for (int i = 0; i < _intArrayLiteral.length; i++) {
+ _intArrayLiteral[i] = (int) literalArray.get(i);
+ }
+ _longArrayLiteral = null;
+ _floatArrayLiteral = null;
+ _doubleArrayLiteral = null;
+ _stringArrayLiteral = null;
+ break;
+ case LONG:
+ _longArrayLiteral = new long[literalArray.size()];
+ for (int i = 0; i < _longArrayLiteral.length; i++) {
+ _longArrayLiteral[i] = (long) literalArray.get(i);
+ }
+ _intArrayLiteral = null;
+ _floatArrayLiteral = null;
+ _doubleArrayLiteral = null;
+ _stringArrayLiteral = null;
+ break;
+ case FLOAT:
+ _floatArrayLiteral = new float[literalArray.size()];
+ for (int i = 0; i < _floatArrayLiteral.length; i++) {
+ _floatArrayLiteral[i] = (float) literalArray.get(i);
+ }
+ _intArrayLiteral = null;
+ _longArrayLiteral = null;
+ _doubleArrayLiteral = null;
+ _stringArrayLiteral = null;
+ break;
+ case DOUBLE:
+ _doubleArrayLiteral = new double[literalArray.size()];
+ for (int i = 0; i < _doubleArrayLiteral.length; i++) {
+ _doubleArrayLiteral[i] = (double) literalArray.get(i);
+ }
+ _intArrayLiteral = null;
+ _longArrayLiteral = null;
+ _floatArrayLiteral = null;
+ _stringArrayLiteral = null;
+ break;
+ case STRING:
+ _stringArrayLiteral = new String[literalArray.size()];
+ for (int i = 0; i < _stringArrayLiteral.length; i++) {
+ _stringArrayLiteral[i] = (String) literalArray.get(i);
+ }
+ _intArrayLiteral = null;
+ _longArrayLiteral = null;
+ _floatArrayLiteral = null;
+ _doubleArrayLiteral = null;
+ break;
+ default:
+ throw new IllegalStateException(
+ "Illegal data type for ArrayLiteralTransformFunction: " +
_dataType + ", literal contexts: "
+ + Arrays.toString(literalArray.toArray()));
+ }
+ }
+
public ArrayLiteralTransformFunction(List<ExpressionContext>
literalContexts) {
Preconditions.checkNotNull(literalContexts);
if (literalContexts.isEmpty()) {
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
index 49541841ca..82afb6dbeb 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
@@ -31,6 +31,7 @@ import org.apache.pinot.common.function.FunctionRegistry;
import org.apache.pinot.common.function.TransformFunctionType;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.request.context.FunctionContext;
+import org.apache.pinot.common.request.context.LiteralContext;
import org.apache.pinot.common.utils.HashUtil;
import org.apache.pinot.core.geospatial.transform.function.GeoToH3Function;
import org.apache.pinot.core.geospatial.transform.function.StAreaFunction;
@@ -335,7 +336,12 @@ public class TransformFunctionFactory {
String columnName = expression.getIdentifier();
return new IdentifierTransformFunction(columnName,
columnContextMap.get(columnName));
case LITERAL:
- return
queryContext.getOrComputeSharedValue(LiteralTransformFunction.class,
expression.getLiteral(),
+ LiteralContext literal = expression.getLiteral();
+ if (literal.getValue() != null && literal.getValue() instanceof
ArrayList) {
+ return
queryContext.getOrComputeSharedValue(ArrayLiteralTransformFunction.class,
literal,
+ ArrayLiteralTransformFunction::new);
+ }
+ return
queryContext.getOrComputeSharedValue(LiteralTransformFunction.class, literal,
LiteralTransformFunction::new);
default:
throw new IllegalStateException();
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/HistogramAggregationFunction.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/HistogramAggregationFunction.java
index 78bf60eca9..078420bd60 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/HistogramAggregationFunction.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/HistogramAggregationFunction.java
@@ -59,17 +59,25 @@ public class HistogramAggregationFunction extends
BaseSingleInputAggregationFunc
if (numArguments == 2) {
ExpressionContext arrayExpression = arguments.get(1);
Preconditions.checkArgument(
- (arrayExpression.getType() == ExpressionContext.Type.FUNCTION) &&
(arrayExpression.getFunction()
- .getFunctionName().equals(ARRAY_CONSTRUCTOR)),
+ // ARRAY function
+ ((arrayExpression.getType() == ExpressionContext.Type.FUNCTION)
+ &&
(arrayExpression.getFunction().getFunctionName().equals(ARRAY_CONSTRUCTOR)))
+ || ((arrayExpression.getType() == ExpressionContext.Type.LITERAL)
+ && (arrayExpression.getLiteral().getValue() instanceof List)),
"Please use the format of `Histogram(columnName, ARRAY[1,10,100])`
to specify the bin edges");
- _bucketEdges = parseVector(arrayExpression.getFunction().getArguments());
+ if (arrayExpression.getType() == ExpressionContext.Type.FUNCTION) {
+ _bucketEdges =
parseVector(arrayExpression.getFunction().getArguments());
+ } else {
+ _bucketEdges = parseVectorLiteral((List)
arrayExpression.getLiteral().getValue());
+ }
_lower = _bucketEdges[0];
_upper = _bucketEdges[_bucketEdges.length - 1];
} else {
_isEqualLength = true;
_lower = arguments.get(1).getLiteral().getDoubleValue();
_upper = arguments.get(2).getLiteral().getDoubleValue();
- int numBins = arguments.get(3).getLiteral().getIntValue();;
+ int numBins = arguments.get(3).getLiteral().getIntValue();
+ ;
Preconditions.checkArgument(_upper > _lower,
"The right most edge must be greater than left most edge, given %s
and %s", _lower, _upper);
Preconditions.checkArgument(numBins > 0, "The number of bins must be
greater than zero, given %s", numBins);
@@ -109,8 +117,23 @@ public class HistogramAggregationFunction extends
BaseSingleInputAggregationFunc
return ret;
}
+ private double[] parseVectorLiteral(List arrayStr) {
+ int len = arrayStr.size();
+ Preconditions.checkArgument(len > 1, "The number of bin edges must be
greater than 1");
+ double[] ret = new double[len];
+ for (int i = 0; i < len; i++) {
+ // TODO: Represent infinity as literal instead of identifier
+ ret[i] = Double.parseDouble(arrayStr.get(i).toString());
+ if (i > 0) {
+ Preconditions.checkState(ret[i] > ret[i - 1], "The bin edges must be
strictly increasing");
+ }
+ }
+ return ret;
+ }
+
/**
* Find the bin id for the input value. Use division for equal-length bins,
and binary search otherwise.
+ *
* @param val input value
* @return bin id
*/
@@ -135,7 +158,7 @@ public class HistogramAggregationFunction extends
BaseSingleInputAggregationFunc
i = mid;
}
}
- id = i;
+ id = i;
}
return id;
}
diff --git
a/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java
b/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java
index 82be9bcf52..07ba2e00c7 100644
---
a/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java
+++
b/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java
@@ -131,7 +131,7 @@ public class InbuiltFunctionEvaluatorTest {
throws Exception {
MyFunc myFunc = new MyFunc();
Method method =
myFunc.getClass().getDeclaredMethod("appendToStringAndReturn", String.class);
- FunctionRegistry.registerFunction(method, false, false);
+ FunctionRegistry.registerFunction(method, false, false, false);
String expression = "appendToStringAndReturn('test ')";
InbuiltFunctionEvaluator evaluator = new
InbuiltFunctionEvaluator(expression);
assertTrue(evaluator.getArguments().isEmpty());
diff --git
a/pinot-core/src/test/java/org/apache/pinot/queries/HistogramQueriesTest.java
b/pinot-core/src/test/java/org/apache/pinot/queries/HistogramQueriesTest.java
index 9fe29ec5d0..c400546c4f 100644
---
a/pinot-core/src/test/java/org/apache/pinot/queries/HistogramQueriesTest.java
+++
b/pinot-core/src/test/java/org/apache/pinot/queries/HistogramQueriesTest.java
@@ -312,7 +312,7 @@ public class HistogramQueriesTest extends BaseQueriesTest {
operator.nextBlock();
} catch (Exception e) {
assertEquals(e.getMessage(),
- "Invalid aggregation function:
histogram(intColumn,arrayvalueconstructor('0')); Reason: The number of "
+ "Invalid aggregation function: histogram(intColumn,'[0]'); Reason:
The number of "
+ "bin edges must be greater than 1");
}
@@ -333,7 +333,7 @@ public class HistogramQueriesTest extends BaseQueriesTest {
operator.nextBlock();
} catch (Exception e) {
assertEquals(e.getMessage(),
- "Invalid aggregation function:
histogram(intColumn,arrayvalueconstructor('0','0','1','2')); Reason: The "
+ "Invalid aggregation function: histogram(intColumn,'[0, 0, 1, 2]');
Reason: The "
+ "bin edges must be strictly increasing");
}
diff --git
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/ArrayTest.java
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/ArrayTest.java
index 9275d3ce90..19bf45b373 100644
---
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/ArrayTest.java
+++
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/ArrayTest.java
@@ -217,6 +217,22 @@ public class ArrayTest extends
CustomDataQueryClusterIntegrationTest {
Assert.assertEquals(row.get(0).get(2).asInt(), 3);
}
+ @Test(dataProvider = "useBothQueryEngines")
+ public void testIntArrayLiteralWithoutFrom(boolean useMultiStageQueryEngine)
+ throws Exception {
+ setUseMultiStageQueryEngine(useMultiStageQueryEngine);
+ String query = "SELECT ARRAY[1,2,3] ";
+ JsonNode jsonNode = postQuery(query);
+ JsonNode rows = jsonNode.get("resultTable").get("rows");
+ Assert.assertEquals(rows.size(), 1);
+ JsonNode row = rows.get(0);
+ Assert.assertEquals(row.size(), 1);
+ Assert.assertEquals(row.get(0).size(), 3);
+ Assert.assertEquals(row.get(0).get(0).asInt(), 1);
+ Assert.assertEquals(row.get(0).get(1).asInt(), 2);
+ Assert.assertEquals(row.get(0).get(2).asInt(), 3);
+ }
+
@Test(dataProvider = "useBothQueryEngines")
public void testLongArrayLiteral(boolean useMultiStageQueryEngine)
throws Exception {
@@ -236,6 +252,22 @@ public class ArrayTest extends
CustomDataQueryClusterIntegrationTest {
Assert.assertEquals(row.get(0).get(2).asLong(), 2147483650L);
}
+ @Test(dataProvider = "useBothQueryEngines")
+ public void testLongArrayLiteralWithoutFrom(boolean useMultiStageQueryEngine)
+ throws Exception {
+ setUseMultiStageQueryEngine(useMultiStageQueryEngine);
+ String query = "SELECT ARRAY[2147483648,2147483649,2147483650]";
+ JsonNode jsonNode = postQuery(query);
+ JsonNode rows = jsonNode.get("resultTable").get("rows");
+ Assert.assertEquals(rows.size(), 1);
+ JsonNode row = rows.get(0);
+ Assert.assertEquals(row.size(), 1);
+ Assert.assertEquals(row.get(0).size(), 3);
+ Assert.assertEquals(row.get(0).get(0).asLong(), 2147483648L);
+ Assert.assertEquals(row.get(0).get(1).asLong(), 2147483649L);
+ Assert.assertEquals(row.get(0).get(2).asLong(), 2147483650L);
+ }
+
@Test(dataProvider = "useBothQueryEngines")
public void testFloatArrayLiteral(boolean useMultiStageQueryEngine)
throws Exception {
@@ -255,6 +287,22 @@ public class ArrayTest extends
CustomDataQueryClusterIntegrationTest {
Assert.assertEquals(row.get(0).get(2).asDouble(), 0.3);
}
+ @Test(dataProvider = "useBothQueryEngines")
+ public void testFloatArrayLiteralWithoutFrom(boolean
useMultiStageQueryEngine)
+ throws Exception {
+ setUseMultiStageQueryEngine(useMultiStageQueryEngine);
+ String query = "SELECT ARRAY[0.1, 0.2, 0.3]";
+ JsonNode jsonNode = postQuery(query);
+ JsonNode rows = jsonNode.get("resultTable").get("rows");
+ Assert.assertEquals(rows.size(), 1);
+ JsonNode row = rows.get(0);
+ Assert.assertEquals(row.size(), 1);
+ Assert.assertEquals(row.get(0).size(), 3);
+ Assert.assertEquals(row.get(0).get(0).asDouble(), 0.1);
+ Assert.assertEquals(row.get(0).get(1).asDouble(), 0.2);
+ Assert.assertEquals(row.get(0).get(2).asDouble(), 0.3);
+ }
+
@Test(dataProvider = "useBothQueryEngines")
public void testDoubleArrayLiteral(boolean useMultiStageQueryEngine)
throws Exception {
@@ -274,6 +322,22 @@ public class ArrayTest extends
CustomDataQueryClusterIntegrationTest {
Assert.assertEquals(row.get(0).get(2).asDouble(), 0.3);
}
+ @Test(dataProvider = "useBothQueryEngines")
+ public void testDoubleArrayLiteralWithoutFrom(boolean
useMultiStageQueryEngine)
+ throws Exception {
+ setUseMultiStageQueryEngine(useMultiStageQueryEngine);
+ String query = "SELECT ARRAY[CAST(0.1 AS DOUBLE), CAST(0.2 AS DOUBLE),
CAST(0.3 AS DOUBLE)]";
+ JsonNode jsonNode = postQuery(query);
+ JsonNode rows = jsonNode.get("resultTable").get("rows");
+ Assert.assertEquals(rows.size(), 1);
+ JsonNode row = rows.get(0);
+ Assert.assertEquals(row.size(), 1);
+ Assert.assertEquals(row.get(0).size(), 3);
+ Assert.assertEquals(row.get(0).get(0).asDouble(), 0.1);
+ Assert.assertEquals(row.get(0).get(1).asDouble(), 0.2);
+ Assert.assertEquals(row.get(0).get(2).asDouble(), 0.3);
+ }
+
@Test(dataProvider = "useBothQueryEngines")
public void testStringArrayLiteral(boolean useMultiStageQueryEngine)
throws Exception {
@@ -293,6 +357,22 @@ public class ArrayTest extends
CustomDataQueryClusterIntegrationTest {
Assert.assertEquals(row.get(0).get(2).asText(), "ccc");
}
+ @Test(dataProvider = "useBothQueryEngines")
+ public void testStringArrayLiteralWithoutFrom(boolean
useMultiStageQueryEngine)
+ throws Exception {
+ setUseMultiStageQueryEngine(useMultiStageQueryEngine);
+ String query = "SELECT ARRAY['a', 'bb', 'ccc']";
+ JsonNode jsonNode = postQuery(query);
+ JsonNode rows = jsonNode.get("resultTable").get("rows");
+ Assert.assertEquals(rows.size(), 1);
+ JsonNode row = rows.get(0);
+ Assert.assertEquals(row.size(), 1);
+ Assert.assertEquals(row.get(0).size(), 3);
+ Assert.assertEquals(row.get(0).get(0).asText(), "a");
+ Assert.assertEquals(row.get(0).get(1).asText(), "bb");
+ Assert.assertEquals(row.get(0).get(2).asText(), "ccc");
+ }
+
@Override
public String getTableName() {
return DEFAULT_TABLE_NAME;
@@ -352,7 +432,7 @@ public class ArrayTest extends
CustomDataQueryClusterIntegrationTest {
fileWriter.append(recordCache.get((int) (i % (getCountStarResult() /
10)), () -> {
// create avro record
GenericData.Record record = new GenericData.Record(avroSchema);
- record.put(BOOLEAN_COLUMN, RANDOM.nextBoolean());
+ record.put(BOOLEAN_COLUMN, finalI % 4 == 0 || finalI % 4 == 1);
record.put(INT_COLUMN, finalI);
record.put(LONG_COLUMN, finalI);
record.put(FLOAT_COLUMN, finalI + RANDOM.nextFloat());
diff --git
a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotEvaluateLiteralRule.java
b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotEvaluateLiteralRule.java
index 02d44ade9b..ea0d531faa 100644
---
a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotEvaluateLiteralRule.java
+++
b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotEvaluateLiteralRule.java
@@ -159,8 +159,12 @@ public class PinotEvaluateLiteralRule {
Object resultValue;
try {
FunctionInvoker invoker = new FunctionInvoker(functionInfo);
- invoker.convertTypes(arguments);
- resultValue = invoker.invoke(arguments);
+ if (functionInfo.getMethod().isVarArgs()) {
+ resultValue = invoker.invoke(new Object[] {arguments});
+ } else {
+ invoker.convertTypes(arguments);
+ resultValue = invoker.invoke(arguments);
+ }
if (rexNodeType.getSqlTypeName() == SqlTypeName.ARRAY) {
RelDataType componentType = rexNodeType.getComponentType();
if (componentType != null) {
diff --git
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java
index 9e777d7867..cccc065be3 100644
---
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java
+++
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java
@@ -50,11 +50,13 @@ public class FunctionOperand implements TransformOperand {
FunctionInfo functionInfo =
FunctionRegistry.getFunctionInfo(canonicalName, numOperands);
Preconditions.checkState(functionInfo != null, "Cannot find function with
name: %s", canonicalName);
_functionInvoker = new FunctionInvoker(functionInfo);
- Class<?>[] parameterClasses = _functionInvoker.getParameterClasses();
- PinotDataType[] parameterTypes = _functionInvoker.getParameterTypes();
- for (int i = 0; i < numOperands; i++) {
- Preconditions.checkState(parameterTypes[i] != null, "Unsupported
parameter class: %s for method: %s",
- parameterClasses[i], functionInfo.getMethod());
+ if (!_functionInvoker.getMethod().isVarArgs()) {
+ Class<?>[] parameterClasses = _functionInvoker.getParameterClasses();
+ PinotDataType[] parameterTypes = _functionInvoker.getParameterTypes();
+ for (int i = 0; i < numOperands; i++) {
+ Preconditions.checkState(parameterTypes[i] != null, "Unsupported
parameter class: %s for method: %s",
+ parameterClasses[i], functionInfo.getMethod());
+ }
}
ColumnDataType functionInvokerResultType =
FunctionUtils.getColumnDataType(_functionInvoker.getResultClass());
// Handle unrecognized result class with STRING
@@ -80,8 +82,13 @@ public class FunctionOperand implements TransformOperand {
_reusableOperandHolder[i] = value != null ?
operand.getResultType().toExternal(value) : null;
}
// TODO: Optimize per record conversion
- _functionInvoker.convertTypes(_reusableOperandHolder);
- Object result = _functionInvoker.invoke(_reusableOperandHolder);
+ Object result;
+ if (_functionInvoker.getMethod().isVarArgs()) {
+ result = _functionInvoker.invoke(new Object[]{_reusableOperandHolder});
+ } else {
+ _functionInvoker.convertTypes(_reusableOperandHolder);
+ result = _functionInvoker.invoke(_reusableOperandHolder);
+ }
return result != null ?
TypeUtils.convert(_functionInvokerResultType.toInternal(result),
_resultType.getStoredType()) : null;
}
diff --git
a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/function/InbuiltFunctionEvaluator.java
b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/function/InbuiltFunctionEvaluator.java
index 823dd23b88..8c3909e784 100644
---
a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/function/InbuiltFunctionEvaluator.java
+++
b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/function/InbuiltFunctionEvaluator.java
@@ -233,6 +233,9 @@ public class InbuiltFunctionEvaluator implements
FunctionEvaluator {
}
}
}
+ if (_functionInvoker.getMethod().isVarArgs()) {
+ return _functionInvoker.invoke(new Object[]{_arguments});
+ }
_functionInvoker.convertTypes(_arguments);
return _functionInvoker.invoke(_arguments);
} catch (Exception e) {
@@ -256,6 +259,9 @@ public class InbuiltFunctionEvaluator implements
FunctionEvaluator {
}
}
}
+ if (_functionInvoker.getMethod().isVarArgs()) {
+ return _functionInvoker.invoke(new Object[]{_arguments});
+ }
_functionInvoker.convertTypes(_arguments);
return _functionInvoker.invoke(_arguments);
} catch (Exception e) {
diff --git
a/pinot-spi/src/main/java/org/apache/pinot/spi/annotations/ScalarFunction.java
b/pinot-spi/src/main/java/org/apache/pinot/spi/annotations/ScalarFunction.java
index 46a743d52c..0a647a8792 100644
---
a/pinot-spi/src/main/java/org/apache/pinot/spi/annotations/ScalarFunction.java
+++
b/pinot-spi/src/main/java/org/apache/pinot/spi/annotations/ScalarFunction.java
@@ -59,4 +59,9 @@ public @interface ScalarFunction {
boolean nullableParameters() default false;
boolean isPlaceholder() default false;
+
+ /**
+ * Whether the scalar function takes a variable number of arguments.
+ */
+ boolean isVarArg() default false;
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]