This is an automated email from the ASF dual-hosted git repository.
xiangfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new 48490933de Support parsing ARRAY literal in multistage query engine
(#11268)
48490933de is described below
commit 48490933ded6f20d94139eedbb57971fe0cbcb61
Author: Xiang Fu <[email protected]>
AuthorDate: Mon Aug 7 13:10:35 2023 -0700
Support parsing ARRAY literal in multistage query engine (#11268)
---
.../common/function/TransformFunctionType.java | 2 +
.../org/apache/pinot/common/utils/DataSchema.java | 54 +++-
.../function/ArrayLiteralTransformFunction.java | 291 +++++++++++++++++++++
.../function/TransformFunctionFactory.java | 8 +
.../core/data/function/VectorFunctionsTest.java | 19 ++
.../ArrayLiteralTransformFunctionTest.java | 167 ++++++++++++
.../function/VectorTransformFunctionTest.java | 12 +-
.../integration/tests/VectorIntegrationTest.java | 86 +++++-
.../apache/calcite/sql/fun/PinotOperatorTable.java | 8 +-
.../planner/logical/RelToPlanNodeConverter.java | 21 +-
.../local/function/InbuiltFunctionEvaluator.java | 39 ++-
11 files changed, 677 insertions(+), 30 deletions(-)
diff --git
a/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java
b/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java
index 471f6b128a..f741ff223e 100644
---
a/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java
+++
b/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java
@@ -221,6 +221,8 @@ public enum TransformFunctionType {
VECTOR_NORM("vectorNorm", ReturnTypes.explicit(SqlTypeName.DOUBLE),
OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY)),
"vector_norm"),
+ ARRAY_VALUE_CONSTRUCTOR("arrayValueConstructor"),
+
// Trigonometry
SIN("sin"),
COS("cos"),
diff --git
a/pinot-common/src/main/java/org/apache/pinot/common/utils/DataSchema.java
b/pinot-common/src/main/java/org/apache/pinot/common/utils/DataSchema.java
index 354ba8cd3c..282a3d7416 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/utils/DataSchema.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/utils/DataSchema.java
@@ -24,6 +24,8 @@ import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyOrder;
import com.google.common.collect.Ordering;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
+import it.unimi.dsi.fastutil.floats.FloatArrayList;
+import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.longs.LongArrayList;
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
@@ -358,11 +360,11 @@ public class DataSchema {
case BYTES:
return ((ByteArray) value).getBytes();
case INT_ARRAY:
- return (int[]) value;
+ return toIntArray(value);
case LONG_ARRAY:
return toLongArray(value);
case FLOAT_ARRAY:
- return (float[]) value;
+ return toFloatArray(value);
case DOUBLE_ARRAY:
return toDoubleArray(value);
case STRING_ARRAY:
@@ -475,6 +477,38 @@ public class DataSchema {
}
}
+ /**
+ * Converts an array-like value into a {@code float[]}.
+ * <p>A {@code float[]} is returned as is; a {@link FloatArrayList} is copied out into an array holding exactly the
+ * elements stored in the list; {@code int[]} and {@code long[]} values are widened element by element; any other
+ * value is cast to {@code double[]} and narrowed element by element.
+ *
+ * @param value the value to convert
+ * @return the values as a {@code float[]}
+ * @throws ClassCastException if the value is none of the supported types
+ */
+ private static float[] toFloatArray(Object value) {
+ if (value instanceof float[]) {
+ return (float[]) value;
+ } else if (value instanceof FloatArrayList) {
+ // elements() returns the backing array, which can be longer than size(); toFloatArray() returns only the
+ // elements actually stored in the list.
+ return ((FloatArrayList) value).toFloatArray();
+ } else if (value instanceof int[]) {
+ int[] intValues = (int[]) value;
+ int length = intValues.length;
+ float[] floatValues = new float[length];
+ for (int i = 0; i < length; i++) {
+ floatValues[i] = intValues[i];
+ }
+ return floatValues;
+ } else if (value instanceof long[]) {
+ long[] longValues = (long[]) value;
+ int length = longValues.length;
+ float[] floatValues = new float[length];
+ for (int i = 0; i < length; i++) {
+ floatValues[i] = longValues[i];
+ }
+ return floatValues;
+ } else {
+ // Remaining supported input is double[]; the cast fails for anything else.
+ double[] doubleValues = (double[]) value;
+ int length = doubleValues.length;
+ float[] floatValues = new float[length];
+ for (int i = 0; i < length; i++) {
+ floatValues[i] = (float) doubleValues[i];
+ }
+ return floatValues;
+ }
+ }
+
private static long[] toLongArray(Object value) {
if (value instanceof long[]) {
return (long[]) value;
@@ -491,6 +525,22 @@ public class DataSchema {
}
}
+ /**
+ * Converts an array-like value into an {@code int[]}.
+ * <p>An {@code int[]} is returned as is; an {@link IntArrayList} is copied out into an array holding exactly the
+ * elements stored in the list; any other value is cast to {@code long[]} and each element is narrowed with an
+ * {@code (int)} cast.
+ *
+ * @param value the value to convert
+ * @return the values as an {@code int[]}
+ * @throws ClassCastException if the value is none of the supported types
+ */
+ private static int[] toIntArray(Object value) {
+ if (value instanceof int[]) {
+ return (int[]) value;
+ } else if (value instanceof IntArrayList) {
+ // elements() returns the backing array, which can be longer than size(); toIntArray() returns only the
+ // elements actually stored in the list.
+ return ((IntArrayList) value).toIntArray();
+ } else {
+ // Remaining supported input is long[]; the cast fails for anything else.
+ long[] longValues = (long[]) value;
+ int length = longValues.length;
+ int[] intValues = new int[length];
+ for (int i = 0; i < length; i++) {
+ intValues[i] = (int) longValues[i];
+ }
+ return intValues;
+ }
+ }
+
private static boolean[] toBooleanArray(Object value) {
int[] ints = (int[]) value;
boolean[] booleans = new boolean[ints.length];
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunction.java
b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunction.java
new file mode 100644
index 0000000000..6208ee1966
--- /dev/null
+++
b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunction.java
@@ -0,0 +1,291 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.operator.transform.function;
+
+import com.google.common.base.Preconditions;
+import java.math.BigDecimal;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.core.operator.ColumnContext;
+import org.apache.pinot.core.operator.blocks.ValueBlock;
+import org.apache.pinot.core.operator.transform.TransformResultMetadata;
+import org.apache.pinot.segment.spi.index.reader.Dictionary;
+import org.apache.pinot.spi.data.FieldSpec.DataType;
+import org.roaringbitmap.RoaringBitmap;
+
+
+/**
+ * The <code>ArrayLiteralTransformFunction</code> class is a special transform
function which is a wrapper on top of a
+ * list of LITERALs. The element data type is taken from the first literal in the list.
+ */
+public class ArrayLiteralTransformFunction implements TransformFunction {
+ public static final String FUNCTION_NAME = "arrayValueConstructor";
+
+ private final DataType _dataType;
+
+ private final int[] _intArrayLiteral;
+ private final long[] _longArrayLiteral;
+ private final float[] _floatArrayLiteral;
+ private final double[] _doubleArrayLiteral;
+ private final String[] _stringArrayLiteral;
+
+ // literals may be shared but values are intentionally not volatile as
assignment races are benign
+ private int[][] _intArrayResult;
+ private long[][] _longArrayResult;
+ private float[][] _floatArrayResult;
+ private double[][] _doubleArrayResult;
+ private String[][] _stringArrayResult;
+
+ public ArrayLiteralTransformFunction(List<ExpressionContext>
literalContexts) {
+ Preconditions.checkNotNull(literalContexts);
+ if (literalContexts.isEmpty()) {
+ _dataType = DataType.UNKNOWN;
+ _intArrayLiteral = new int[0];
+ _longArrayLiteral = new long[0];
+ _floatArrayLiteral = new float[0];
+ _doubleArrayLiteral = new double[0];
+ _stringArrayLiteral = new String[0];
+ return;
+ }
+ for (ExpressionContext literalContext : literalContexts) {
+ Preconditions.checkState(literalContext.getType() ==
ExpressionContext.Type.LITERAL,
+ "ArrayLiteralTransformFunction only takes literals as arguments,
found: %s", literalContext);
+ }
+ _dataType = literalContexts.get(0).getLiteral().getType();
+ switch (_dataType) {
+ case INT:
+ _intArrayLiteral = new int[literalContexts.size()];
+ for (int i = 0; i < _intArrayLiteral.length; i++) {
+ _intArrayLiteral[i] =
literalContexts.get(i).getLiteral().getIntValue();
+ }
+ _longArrayLiteral = null;
+ _floatArrayLiteral = null;
+ _doubleArrayLiteral = null;
+ _stringArrayLiteral = null;
+ break;
+ case LONG:
+ _longArrayLiteral = new long[literalContexts.size()];
+ for (int i = 0; i < _longArrayLiteral.length; i++) {
+ _longArrayLiteral[i] =
Long.parseLong(literalContexts.get(i).getLiteral().getStringValue());
+ }
+ _intArrayLiteral = null;
+ _floatArrayLiteral = null;
+ _doubleArrayLiteral = null;
+ _stringArrayLiteral = null;
+ break;
+ case FLOAT:
+ _floatArrayLiteral = new float[literalContexts.size()];
+ for (int i = 0; i < _floatArrayLiteral.length; i++) {
+ _floatArrayLiteral[i] =
Float.parseFloat(literalContexts.get(i).getLiteral().getStringValue());
+ }
+ _intArrayLiteral = null;
+ _longArrayLiteral = null;
+ _doubleArrayLiteral = null;
+ _stringArrayLiteral = null;
+ break;
+ case DOUBLE:
+ _doubleArrayLiteral = new double[literalContexts.size()];
+ for (int i = 0; i < _doubleArrayLiteral.length; i++) {
+ _doubleArrayLiteral[i] =
Double.parseDouble(literalContexts.get(i).getLiteral().getStringValue());
+ }
+ _intArrayLiteral = null;
+ _longArrayLiteral = null;
+ _floatArrayLiteral = null;
+ _stringArrayLiteral = null;
+ break;
+ case STRING:
+ _stringArrayLiteral = new String[literalContexts.size()];
+ for (int i = 0; i < _stringArrayLiteral.length; i++) {
+ _stringArrayLiteral[i] =
literalContexts.get(i).getLiteral().getStringValue();
+ }
+ _intArrayLiteral = null;
+ _longArrayLiteral = null;
+ _floatArrayLiteral = null;
+ _doubleArrayLiteral = null;
+ break;
+ default:
+ throw new IllegalStateException(
+ "Illegal data type for ArrayLiteralTransformFunction: " +
_dataType + ", literal contexts: "
+ + Arrays.toString(literalContexts.toArray()));
+ }
+ }
+
+ public int[] getIntArrayLiteral() {
+ return _intArrayLiteral;
+ }
+
+ public long[] getLongArrayLiteral() {
+ return _longArrayLiteral;
+ }
+
+ public float[] getFloatArrayLiteral() {
+ return _floatArrayLiteral;
+ }
+
+ public double[] getDoubleArrayLiteral() {
+ return _doubleArrayLiteral;
+ }
+
+ public String[] getStringArrayLiteral() {
+ return _stringArrayLiteral;
+ }
+
+ @Override
+ public String getName() {
+ return FUNCTION_NAME;
+ }
+
+ @Override
+ public void init(List<TransformFunction> arguments, Map<String,
ColumnContext> columnContextMap) {
+ }
+
+ @Override
+ public TransformResultMetadata getResultMetadata() {
+ return new TransformResultMetadata(_dataType, false, false);
+ }
+
+ @Override
+ public Dictionary getDictionary() {
+ return null;
+ }
+
+ @Override
+ public int[] transformToDictIdsSV(ValueBlock valueBlock) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int[][] transformToDictIdsMV(ValueBlock valueBlock) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int[] transformToIntValuesSV(ValueBlock valueBlock) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public long[] transformToLongValuesSV(ValueBlock valueBlock) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public float[] transformToFloatValuesSV(ValueBlock valueBlock) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public double[] transformToDoubleValuesSV(ValueBlock valueBlock) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public BigDecimal[] transformToBigDecimalValuesSV(ValueBlock valueBlock) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public String[] transformToStringValuesSV(ValueBlock valueBlock) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public byte[][] transformToBytesValuesSV(ValueBlock valueBlock) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int[][] transformToIntValuesMV(ValueBlock valueBlock) {
+ int numDocs = valueBlock.getNumDocs();
+ int[][] intArrayResult = _intArrayResult;
+ if (intArrayResult == null || intArrayResult.length < numDocs) {
+ intArrayResult = new int[numDocs][];
+ Arrays.fill(intArrayResult, _intArrayLiteral);
+ _intArrayResult = intArrayResult;
+ }
+ return intArrayResult;
+ }
+
+ @Override
+ public long[][] transformToLongValuesMV(ValueBlock valueBlock) {
+ int numDocs = valueBlock.getNumDocs();
+ long[][] longArrayResult = _longArrayResult;
+ if (longArrayResult == null || longArrayResult.length < numDocs) {
+ longArrayResult = new long[numDocs][];
+ Arrays.fill(longArrayResult, _longArrayLiteral);
+ _longArrayResult = longArrayResult;
+ }
+ return longArrayResult;
+ }
+
+ @Override
+ public float[][] transformToFloatValuesMV(ValueBlock valueBlock) {
+ int numDocs = valueBlock.getNumDocs();
+ float[][] floatArrayResult = _floatArrayResult;
+ if (floatArrayResult == null || floatArrayResult.length < numDocs) {
+ floatArrayResult = new float[numDocs][];
+ Arrays.fill(floatArrayResult, _floatArrayLiteral);
+ _floatArrayResult = floatArrayResult;
+ }
+ return floatArrayResult;
+ }
+
+ @Override
+ public double[][] transformToDoubleValuesMV(ValueBlock valueBlock) {
+ int numDocs = valueBlock.getNumDocs();
+ double[][] doubleArrayResult = _doubleArrayResult;
+ if (doubleArrayResult == null || doubleArrayResult.length < numDocs) {
+ doubleArrayResult = new double[numDocs][];
+ Arrays.fill(doubleArrayResult, _doubleArrayLiteral);
+ _doubleArrayResult = doubleArrayResult;
+ }
+ return doubleArrayResult;
+ }
+
+ @Override
+ public String[][] transformToStringValuesMV(ValueBlock valueBlock) {
+ int numDocs = valueBlock.getNumDocs();
+ String[][] stringArrayResult = _stringArrayResult;
+ if (stringArrayResult == null || stringArrayResult.length < numDocs) {
+ stringArrayResult = new String[numDocs][];
+ Arrays.fill(stringArrayResult, _stringArrayLiteral);
+ _stringArrayResult = stringArrayResult;
+ }
+ return stringArrayResult;
+ }
+
+ @Override
+ public byte[][][] transformToBytesValuesMV(ValueBlock valueBlock) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public RoaringBitmap getNullBitmap(ValueBlock valueBlock) {
+ // Treat all unknown type values as null regardless of the value.
+ if (_dataType != DataType.UNKNOWN) {
+ return null;
+ }
+ int length = valueBlock.getNumDocs();
+ RoaringBitmap bitmap = new RoaringBitmap();
+ bitmap.add(0L, length);
+ return bitmap;
+ }
+}
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
index 4e3ff24119..be2b54d128 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
@@ -153,6 +153,7 @@ public class TransformFunctionFactory {
typeToImplementation.put(TransformFunctionType.ARRAYMAX,
ArrayMaxTransformFunction.class);
typeToImplementation.put(TransformFunctionType.ARRAYMIN,
ArrayMinTransformFunction.class);
typeToImplementation.put(TransformFunctionType.ARRAYSUM,
ArraySumTransformFunction.class);
+ typeToImplementation.put(TransformFunctionType.ARRAY_VALUE_CONSTRUCTOR,
ArrayLiteralTransformFunction.class);
typeToImplementation.put(TransformFunctionType.GROOVY,
GroovyTransformFunction.class);
typeToImplementation.put(TransformFunctionType.CASE,
CaseTransformFunction.class);
@@ -281,6 +282,13 @@ public class TransformFunctionFactory {
List<ExpressionContext> arguments = function.getArguments();
int numArguments = arguments.size();
+ // Check if the function is the ArrayLiteralTransformFunction
+ if
(functionName.equalsIgnoreCase(ArrayLiteralTransformFunction.FUNCTION_NAME)) {
+ return
queryContext.getOrComputeSharedValue(ArrayLiteralTransformFunction.class,
+ expression.getFunction().getArguments(),
+ ArrayLiteralTransformFunction::new);
+ }
+
TransformFunction transformFunction;
Class<? extends TransformFunction> transformFunctionClass =
TRANSFORM_FUNCTION_MAP.get(functionName);
if (transformFunctionClass != null) {
diff --git
a/pinot-core/src/test/java/org/apache/pinot/core/data/function/VectorFunctionsTest.java
b/pinot-core/src/test/java/org/apache/pinot/core/data/function/VectorFunctionsTest.java
index 972c33ee43..6600b5c10f 100644
---
a/pinot-core/src/test/java/org/apache/pinot/core/data/function/VectorFunctionsTest.java
+++
b/pinot-core/src/test/java/org/apache/pinot/core/data/function/VectorFunctionsTest.java
@@ -108,6 +108,25 @@ public class VectorFunctionsTest {
inputs.add(new Object[]{"vectorDims(vector2)",
Lists.newArrayList("vector2"), row, 5});
inputs.add(new Object[]{"vectorNorm(vector1)",
Lists.newArrayList("vector1"), row, 0.741619857751291});
inputs.add(new Object[]{"vectorNorm(vector2)",
Lists.newArrayList("vector2"), row, 0.0});
+
+ inputs.add(new Object[]{
+ "cosineDistance(vector1, ARRAY[0.0,0.0,0.0,0.0,0.0])",
Lists.newArrayList("vector1"), row, Double.NaN
+ });
+ inputs.add(new Object[]{
+ "cosineDistance(vector1, ARRAY[0.0,0.0,0.0,0.0,0.0], 0.0)",
Lists.newArrayList("vector1"), row, 0.0
+ });
+ inputs.add(new Object[]{
+ "cosineDistance(vector1, ARRAY[0.0,0.0,0.0,0.0,0.0], 1.0)",
Lists.newArrayList("vector1"), row, 1.0
+ });
+ inputs.add(new Object[]{
+ "innerProduct(vector1, ARRAY[0.0,0.0,0.0,0.0,0.0])",
Lists.newArrayList("vector1"), row, 0.0
+ });
+ inputs.add(new Object[]{
+ "l2Distance(vector1, ARRAY[0.0,0.0,0.0,0.0,0.0])",
Lists.newArrayList("vector1"), row, 0.741619857751291
+ });
+ inputs.add(new Object[]{
+ "l1Distance(vector1, ARRAY[0.0,0.0,0.0,0.0,0.0])",
Lists.newArrayList("vector1"), row, 1.5000000223517418
+ });
return inputs.toArray(new Object[0][]);
}
}
diff --git
a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunctionTest.java
b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunctionTest.java
new file mode 100644
index 0000000000..005b8c3eeb
--- /dev/null
+++
b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunctionTest.java
@@ -0,0 +1,167 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.operator.transform.function;
+
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.pinot.common.request.Literal;
+import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.core.operator.blocks.ProjectionBlock;
+import org.apache.pinot.spi.data.FieldSpec.DataType;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.testng.Assert;
+import org.testng.annotations.AfterMethod;
+import org.testng.annotations.BeforeMethod;
+import org.testng.annotations.Test;
+
+import static org.mockito.Mockito.when;
+
+
+public class ArrayLiteralTransformFunctionTest {
+ private static final int NUM_DOCS = 100;
+ private AutoCloseable _mocks;
+
+ @Mock
+ private ProjectionBlock _projectionBlock;
+
+ @BeforeMethod
+ public void setUp() {
+ _mocks = MockitoAnnotations.openMocks(this);
+ when(_projectionBlock.getNumDocs()).thenReturn(NUM_DOCS);
+ }
+
+ @AfterMethod
+ public void tearDown()
+ throws Exception {
+ _mocks.close();
+ }
+
+ @Test
+ public void testIntArrayLiteralTransformFunction() {
+ List<ExpressionContext> arrayExpressions = new ArrayList<>();
+ for (int i = 0; i < 10; i++) {
+ arrayExpressions.add(ExpressionContext.forLiteralContext(DataType.INT,
i));
+ }
+
+ ArrayLiteralTransformFunction intArray = new
ArrayLiteralTransformFunction(arrayExpressions);
+ Assert.assertEquals(intArray.getResultMetadata().getDataType(),
DataType.INT);
+ Assert.assertEquals(intArray.getIntArrayLiteral(), new int[]{
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9
+ });
+ }
+
+ @Test
+ public void testLongArrayLiteralTransformFunction() {
+ List<ExpressionContext> arrayExpressions = new ArrayList<>();
+ for (int i = 0; i < 10; i++) {
+ arrayExpressions.add(ExpressionContext.forLiteralContext(DataType.LONG,
(long) i));
+ }
+
+ ArrayLiteralTransformFunction longArray = new
ArrayLiteralTransformFunction(arrayExpressions);
+ Assert.assertEquals(longArray.getResultMetadata().getDataType(),
DataType.LONG);
+ Assert.assertEquals(longArray.getLongArrayLiteral(), new long[]{
+ 0L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L
+ });
+ }
+
+ @Test
+ public void testFloatArrayLiteralTransformFunction() {
+ List<ExpressionContext> arrayExpressions = new ArrayList<>();
+ for (int i = 0; i < 10; i++) {
+ arrayExpressions.add(ExpressionContext.forLiteralContext(DataType.FLOAT,
(double) i));
+ }
+
+ ArrayLiteralTransformFunction floatArray = new
ArrayLiteralTransformFunction(arrayExpressions);
+ Assert.assertEquals(floatArray.getResultMetadata().getDataType(),
DataType.FLOAT);
+ Assert.assertEquals(floatArray.getFloatArrayLiteral(), new float[]{
+ 0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f
+ });
+ }
+
+ @Test
+ public void testDoubleArrayLiteralTransformFunction() {
+ List<ExpressionContext> arrayExpressions = new ArrayList<>();
+ for (int i = 0; i < 10; i++) {
+
arrayExpressions.add(ExpressionContext.forLiteralContext(DataType.DOUBLE,
(double) i));
+ }
+
+ ArrayLiteralTransformFunction doubleArray = new
ArrayLiteralTransformFunction(arrayExpressions);
+ Assert.assertEquals(doubleArray.getResultMetadata().getDataType(),
DataType.DOUBLE);
+ Assert.assertEquals(doubleArray.getDoubleArrayLiteral(), new double[]{
+ 0d, 1d, 2d, 3d, 4d, 5d, 6d, 7d, 8d, 9d
+ });
+ }
+
+ @Test
+ public void testStringArrayLiteralTransformFunction() {
+ List<ExpressionContext> arrayExpressions = new ArrayList<>();
+ for (int i = 0; i < 10; i++) {
+ arrayExpressions.add(
+ ExpressionContext.forLiteralContext(new
Literal(Literal._Fields.STRING_VALUE, String.valueOf(i))));
+ }
+
+ ArrayLiteralTransformFunction stringArray = new
ArrayLiteralTransformFunction(arrayExpressions);
+ Assert.assertEquals(stringArray.getResultMetadata().getDataType(),
DataType.STRING);
+ Assert.assertEquals(stringArray.getStringArrayLiteral(), new String[]{
+ "0", "1", "2", "3", "4", "5", "6", "7", "8", "9"
+ });
+ }
+
+ @Test
+ public void testEmptyArrayTransform() {
+ List<ExpressionContext> arrayExpressions = new ArrayList<>();
+ ArrayLiteralTransformFunction emptyLiteral = new
ArrayLiteralTransformFunction(arrayExpressions);
+ Assert.assertEquals(emptyLiteral.getIntArrayLiteral(), new int[0]);
+ Assert.assertEquals(emptyLiteral.getLongArrayLiteral(), new long[0]);
+ Assert.assertEquals(emptyLiteral.getFloatArrayLiteral(), new float[0]);
+ Assert.assertEquals(emptyLiteral.getDoubleArrayLiteral(), new double[0]);
+ Assert.assertEquals(emptyLiteral.getStringArrayLiteral(), new String[0]);
+
+ int[][] ints = emptyLiteral.transformToIntValuesMV(_projectionBlock);
+ Assert.assertEquals(ints.length, NUM_DOCS);
+ for (int i = 0; i < NUM_DOCS; i++) {
+ Assert.assertEquals(ints[i].length, 0);
+ }
+
+ long[][] longs = emptyLiteral.transformToLongValuesMV(_projectionBlock);
+ Assert.assertEquals(longs.length, NUM_DOCS);
+ for (int i = 0; i < NUM_DOCS; i++) {
+ Assert.assertEquals(longs[i].length, 0);
+ }
+
+ float[][] floats = emptyLiteral.transformToFloatValuesMV(_projectionBlock);
+ Assert.assertEquals(floats.length, NUM_DOCS);
+ for (int i = 0; i < NUM_DOCS; i++) {
+ Assert.assertEquals(floats[i].length, 0);
+ }
+
+ double[][] doubles =
emptyLiteral.transformToDoubleValuesMV(_projectionBlock);
+ Assert.assertEquals(doubles.length, NUM_DOCS);
+ for (int i = 0; i < NUM_DOCS; i++) {
+ Assert.assertEquals(doubles[i].length, 0);
+ }
+
+ String[][] strings =
emptyLiteral.transformToStringValuesMV(_projectionBlock);
+ Assert.assertEquals(strings.length, NUM_DOCS);
+ for (int i = 0; i < NUM_DOCS; i++) {
+ Assert.assertEquals(strings[i].length, 0);
+ }
+ }
+}
diff --git
a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/VectorTransformFunctionTest.java
b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/VectorTransformFunctionTest.java
index 8aed6e4698..23b3213f3e 100644
---
a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/VectorTransformFunctionTest.java
+++
b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/VectorTransformFunctionTest.java
@@ -59,6 +59,9 @@ public class VectorTransformFunctionTest extends
BaseTransformFunctionTest {
@DataProvider(name = "testVectorTransformFunctionDataProvider")
public Object[][] testVectorTransformFunctionDataProvider() {
+ String zeroVectorLiteral = "ARRAY[0.0"
+ + ",0.0".repeat(VECTOR_DIM_SIZE - 1)
+ + "]";
return new Object[][]{
new Object[]{"cosineDistance(vector1, vector2)", 0.1, 0.4},
new Object[]{"cosineDistance(vector1, vector2, 0)", 0.1, 0.4},
@@ -67,7 +70,14 @@ public class VectorTransformFunctionTest extends
BaseTransformFunctionTest {
new Object[]{"l1Distance(vector1, vector2)", 140, 210},
new Object[]{"l2Distance(vector1, vector2)", 8, 11},
new Object[]{"vectorNorm(vector1)", 10, 16},
- new Object[]{"vectorNorm(vector2)", 10, 16}
+ new Object[]{"vectorNorm(vector2)", 10, 16},
+
+ new Object[]{String.format("cosineDistance(vector1, %s, 0)",
zeroVectorLiteral), 0.0, 0.0},
+ new Object[]{String.format("innerProduct(vector1, %s)",
zeroVectorLiteral), 0.0, 0.0},
+ new Object[]{String.format("l1Distance(vector1, %s)",
zeroVectorLiteral), 0, VECTOR_DIM_SIZE},
+ new Object[]{String.format("l2Distance(vector1, %s)",
zeroVectorLiteral), 0, VECTOR_DIM_SIZE},
+ new Object[]{String.format("vectorDims(%s)", zeroVectorLiteral),
VECTOR_DIM_SIZE, VECTOR_DIM_SIZE},
+ new Object[]{String.format("vectorNorm(%s)", zeroVectorLiteral), 0.0,
0.0},
};
}
}
diff --git
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/VectorIntegrationTest.java
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/VectorIntegrationTest.java
index 48efe20490..dbfcd5a347 100644
---
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/VectorIntegrationTest.java
+++
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/VectorIntegrationTest.java
@@ -100,32 +100,96 @@ public class VectorIntegrationTest extends
BaseClusterIntegrationTest {
+ "vectorNorm(vector1), vectorNorm(vector2), "
+ "cosineDistance(vector1, zeroVector), "
+ "cosineDistance(vector1, zeroVector, 0) "
- + "FROM %s", DEFAULT_TABLE_NAME);
+ + "FROM %s LIMIT %d", DEFAULT_TABLE_NAME, getCountStarResult());
JsonNode jsonNode = postQuery(query);
for (int i = 0; i < getCountStarResult(); i++) {
- double cosineDistance =
jsonNode.get("resultTable").get("rows").get(0).get(0).asDouble();
+ double cosineDistance =
jsonNode.get("resultTable").get("rows").get(i).get(0).asDouble();
assertTrue(cosineDistance > 0.1 && cosineDistance < 0.4);
- double innerProduce =
jsonNode.get("resultTable").get("rows").get(0).get(1).asDouble();
+ double innerProduce =
jsonNode.get("resultTable").get("rows").get(i).get(1).asDouble();
assertTrue(innerProduce > 100 && innerProduce < 160);
- double l1Distance =
jsonNode.get("resultTable").get("rows").get(0).get(2).asDouble();
+ double l1Distance =
jsonNode.get("resultTable").get("rows").get(i).get(2).asDouble();
assertTrue(l1Distance > 140 && l1Distance < 210);
- double l2Distance =
jsonNode.get("resultTable").get("rows").get(0).get(3).asDouble();
+ double l2Distance =
jsonNode.get("resultTable").get("rows").get(i).get(3).asDouble();
assertTrue(l2Distance > 8 && l2Distance < 11);
- int vectorDimsVector1 =
jsonNode.get("resultTable").get("rows").get(0).get(4).asInt();
+ int vectorDimsVector1 =
jsonNode.get("resultTable").get("rows").get(i).get(4).asInt();
assertEquals(vectorDimsVector1, VECTOR_DIM_SIZE);
- int vectorDimsVector2 =
jsonNode.get("resultTable").get("rows").get(0).get(5).asInt();
+ int vectorDimsVector2 =
jsonNode.get("resultTable").get("rows").get(i).get(5).asInt();
assertEquals(vectorDimsVector2, VECTOR_DIM_SIZE);
- double vectorNormVector1 =
jsonNode.get("resultTable").get("rows").get(0).get(6).asInt();
+ double vectorNormVector1 =
jsonNode.get("resultTable").get("rows").get(i).get(6).asInt();
assertTrue(vectorNormVector1 > 10 && vectorNormVector1 < 16);
- double vectorNormVector2 =
jsonNode.get("resultTable").get("rows").get(0).get(7).asInt();
+ double vectorNormVector2 =
jsonNode.get("resultTable").get("rows").get(i).get(7).asInt();
assertTrue(vectorNormVector2 > 10 && vectorNormVector2 < 16);
- cosineDistance =
jsonNode.get("resultTable").get("rows").get(0).get(8).asDouble();
+ cosineDistance =
jsonNode.get("resultTable").get("rows").get(i).get(8).asDouble();
assertEquals(cosineDistance, Double.NaN);
- cosineDistance =
jsonNode.get("resultTable").get("rows").get(0).get(9).asDouble();
+ cosineDistance =
jsonNode.get("resultTable").get("rows").get(i).get(9).asDouble();
assertEquals(cosineDistance, 0.0);
}
}
+ @Test(dataProvider = "useBothQueryEngines")
+ public void testQueriesWithLiterals(boolean useMultiStageQueryEngine)
+ throws Exception {
+ setUseMultiStageQueryEngine(useMultiStageQueryEngine);
+ String zeroVectorStringLiteral = "ARRAY[0.0"
+ + ", 0.0".repeat(VECTOR_DIM_SIZE - 1)
+ + "]";
+ String oneVectorStringLiteral = "ARRAY[1.0"
+ + ", 1.0".repeat(VECTOR_DIM_SIZE - 1)
+ + "]";
+ String query =
+ String.format("SELECT "
+ + "cosineDistance(vector1, %s), "
+ + "innerProduct(vector1, %s), "
+ + "l1Distance(vector1, %s), "
+ + "l2Distance(vector1, %s), "
+ + "vectorDims(%s), "
+ + "vectorNorm(%s) "
+ + "FROM %s LIMIT %d",
+ zeroVectorStringLiteral, zeroVectorStringLiteral,
zeroVectorStringLiteral, zeroVectorStringLiteral,
+ zeroVectorStringLiteral, zeroVectorStringLiteral,
DEFAULT_TABLE_NAME, getCountStarResult());
+ JsonNode jsonNode = postQuery(query);
+ for (int i = 0; i < getCountStarResult(); i++) {
+ double cosineDistance =
jsonNode.get("resultTable").get("rows").get(i).get(0).asDouble();
+ assertEquals(cosineDistance, Double.NaN);
+ double innerProduce =
jsonNode.get("resultTable").get("rows").get(i).get(1).asDouble();
+ assertEquals(innerProduce, 0.0);
+ double l1Distance =
jsonNode.get("resultTable").get("rows").get(i).get(2).asDouble();
+ assertTrue(l1Distance > 100 && l1Distance < 300);
+ double l2Distance =
jsonNode.get("resultTable").get("rows").get(i).get(3).asDouble();
+ assertTrue(l2Distance > 10 && l2Distance < 16);
+ int vectorDimsVector =
jsonNode.get("resultTable").get("rows").get(i).get(4).asInt();
+ assertEquals(vectorDimsVector, VECTOR_DIM_SIZE);
+ double vectorNormVector =
jsonNode.get("resultTable").get("rows").get(i).get(5).asInt();
+ assertEquals(vectorNormVector, 0.0);
+ }
+
+ query =
+ String.format("SELECT "
+ + "cosineDistance(%s, %s), "
+ + "cosineDistance(%s, %s, 0.0), "
+ + "innerProduct(%s, %s), "
+ + "l1Distance(%s, %s), "
+ + "l2Distance(%s, %s)"
+ + "FROM %s LIMIT 1",
+ zeroVectorStringLiteral, oneVectorStringLiteral,
+ zeroVectorStringLiteral, oneVectorStringLiteral,
+ zeroVectorStringLiteral, oneVectorStringLiteral,
+ zeroVectorStringLiteral, oneVectorStringLiteral,
+ zeroVectorStringLiteral, oneVectorStringLiteral,
+ DEFAULT_TABLE_NAME);
+ jsonNode = postQuery(query);
+ double cosineDistance =
jsonNode.get("resultTable").get("rows").get(0).get(0).asDouble();
+ assertEquals(cosineDistance, Double.NaN);
+ cosineDistance =
jsonNode.get("resultTable").get("rows").get(0).get(1).asDouble();
+ assertEquals(cosineDistance, 0.0);
+ double innerProduce =
jsonNode.get("resultTable").get("rows").get(0).get(2).asDouble();
+ assertEquals(innerProduce, 0.0);
+ double l1Distance =
jsonNode.get("resultTable").get("rows").get(0).get(3).asDouble();
+ assertEquals(l1Distance, 512.0);
+ double l2Distance =
jsonNode.get("resultTable").get("rows").get(0).get(4).asDouble();
+ assertEquals(l2Distance, 22.627416997969522);
+ }
+
private File createAvroFile(long totalNumRecords)
throws IOException {
diff --git
a/pinot-query-planner/src/main/java/org/apache/calcite/sql/fun/PinotOperatorTable.java
b/pinot-query-planner/src/main/java/org/apache/calcite/sql/fun/PinotOperatorTable.java
index 2ee178419f..1a63a6eb07 100644
---
a/pinot-query-planner/src/main/java/org/apache/calcite/sql/fun/PinotOperatorTable.java
+++
b/pinot-query-planner/src/main/java/org/apache/calcite/sql/fun/PinotOperatorTable.java
@@ -49,8 +49,6 @@ public class PinotOperatorTable extends SqlStdOperatorTable {
private static @MonotonicNonNull PinotOperatorTable _instance;
- public static final SqlFunction COALESCE = new PinotSqlCoalesceFunction();
-
// TODO: clean up lazy init by using
Suppliers.memorized(this::computeInstance) and make getter wrapped around
// supplier instance. this should replace all lazy init static objects in
the codebase
public static synchronized PinotOperatorTable instance() {
@@ -75,6 +73,12 @@ public class PinotOperatorTable extends SqlStdOperatorTable {
* which are multistage enabled.
*/
public final void initNoDuplicate() {
+ // Pinot supports native COALESCE function, thus no need to create CASE
WHEN conversion.
+ register(new PinotSqlCoalesceFunction());
+ // Ensure ArrayValueConstructor is registered before ArrayQueryConstructor
+ register(ARRAY_VALUE_CONSTRUCTOR);
+
+ // TODO: reflection based registration is not ideal, we should use a
static list of operators and register them
// Use reflection to register the expressions stored in public fields.
for (Field field : getClass().getFields()) {
try {
diff --git
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java
index a4e6be355a..b0b7545677 100644
---
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java
+++
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java
@@ -219,7 +219,7 @@ public final class RelToPlanNodeConverter {
case BIGINT:
return isArray ? DataSchema.ColumnDataType.LONG_ARRAY :
DataSchema.ColumnDataType.LONG;
case DECIMAL:
- return resolveDecimal(relDataType);
+ return resolveDecimal(relDataType, isArray);
case FLOAT:
case REAL:
return isArray ? DataSchema.ColumnDataType.FLOAT_ARRAY :
DataSchema.ColumnDataType.FLOAT;
@@ -259,31 +259,32 @@ public final class RelToPlanNodeConverter {
}
/**
- * Calcite uses DEMICAL type to infer data type hoisting and infer
arithmetic result types. down casting this
- * back to the proper primitive type for Pinot.
+ * Calcite uses DECIMAL type to infer data type hoisting and infer
arithmetic result types. This method down-casts it back to
+ * the proper primitive type for Pinot.
*
* @param relDataType the DECIMAL rel data type.
+ * @param isArray whether to return the array variant of the resolved column type (e.g. INT_ARRAY instead of INT).
* @return proper {@link DataSchema.ColumnDataType}.
* @see {@link org.apache.calcite.rel.type.RelDataTypeFactoryImpl#decimalOf}.
*/
- private static DataSchema.ColumnDataType resolveDecimal(RelDataType
relDataType) {
+ private static DataSchema.ColumnDataType resolveDecimal(RelDataType
relDataType, boolean isArray) {
int precision = relDataType.getPrecision();
int scale = relDataType.getScale();
if (scale == 0) {
if (precision <= 10) {
- return DataSchema.ColumnDataType.INT;
+ return isArray ? DataSchema.ColumnDataType.INT_ARRAY :
DataSchema.ColumnDataType.INT;
} else if (precision <= 38) {
- return DataSchema.ColumnDataType.LONG;
+ return isArray ? DataSchema.ColumnDataType.LONG_ARRAY :
DataSchema.ColumnDataType.LONG;
} else {
- return DataSchema.ColumnDataType.BIG_DECIMAL;
+ return isArray ? DataSchema.ColumnDataType.DOUBLE_ARRAY :
DataSchema.ColumnDataType.BIG_DECIMAL;
}
} else {
if (precision <= 14) {
- return DataSchema.ColumnDataType.FLOAT;
+ return isArray ? DataSchema.ColumnDataType.FLOAT_ARRAY :
DataSchema.ColumnDataType.FLOAT;
} else if (precision <= 30) {
- return DataSchema.ColumnDataType.DOUBLE;
+ return isArray ? DataSchema.ColumnDataType.DOUBLE_ARRAY :
DataSchema.ColumnDataType.DOUBLE;
} else {
- return DataSchema.ColumnDataType.BIG_DECIMAL;
+ return isArray ? DataSchema.ColumnDataType.DOUBLE_ARRAY :
DataSchema.ColumnDataType.BIG_DECIMAL;
}
}
}
diff --git
a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/function/InbuiltFunctionEvaluator.java
b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/function/InbuiltFunctionEvaluator.java
index df896d2c00..823dd23b88 100644
---
a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/function/InbuiltFunctionEvaluator.java
+++
b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/function/InbuiltFunctionEvaluator.java
@@ -20,6 +20,7 @@ package org.apache.pinot.segment.local.function;
import com.google.common.base.Preconditions;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.List;
import org.apache.commons.lang3.StringUtils;
import org.apache.pinot.common.function.FunctionInfo;
@@ -78,6 +79,13 @@ public class InbuiltFunctionEvaluator implements
FunctionEvaluator {
case "not":
Preconditions.checkState(numArguments == 1, "NOT function expects
1 argument, got: %s", numArguments);
return new NotExecutionNode(childNodes[0]);
+ case "arrayvalueconstructor":
+ Object[] values = new Object[numArguments];
+ int i = 0;
+ for (ExpressionContext literal : arguments) {
+ values[i++] = literal.getLiteral().getValue();
+ }
+ return new ArrayConstantExecutionNode(values);
default:
FunctionInfo functionInfo =
FunctionRegistry.getFunctionInfo(functionName, numArguments);
if (functionInfo == null) {
@@ -145,7 +153,7 @@ public class InbuiltFunctionEvaluator implements
FunctionEvaluator {
@Override
public Object execute(GenericRow row) {
- for (ExecutableNode executableNode :_argumentNodes) {
+ for (ExecutableNode executableNode : _argumentNodes) {
Boolean res = (Boolean) executableNode.execute(row);
if (res) {
return true;
@@ -156,7 +164,7 @@ public class InbuiltFunctionEvaluator implements
FunctionEvaluator {
@Override
public Object execute(Object[] values) {
- for (ExecutableNode executableNode :_argumentNodes) {
+ for (ExecutableNode executableNode : _argumentNodes) {
Boolean res = (Boolean) executableNode.execute(values);
if (res) {
return true;
@@ -175,7 +183,7 @@ public class InbuiltFunctionEvaluator implements
FunctionEvaluator {
@Override
public Object execute(GenericRow row) {
- for (ExecutableNode executableNode :_argumentNodes) {
+ for (ExecutableNode executableNode : _argumentNodes) {
Boolean res = (Boolean) executableNode.execute(row);
if (!res) {
return false;
@@ -186,7 +194,7 @@ public class InbuiltFunctionEvaluator implements
FunctionEvaluator {
@Override
public Object execute(Object[] values) {
- for (ExecutableNode executableNode :_argumentNodes) {
+ for (ExecutableNode executableNode : _argumentNodes) {
Boolean res = (Boolean) executableNode.execute(values);
if (!res) {
return false;
@@ -284,6 +292,29 @@ public class InbuiltFunctionEvaluator implements
FunctionEvaluator {
}
}
+ private static class ArrayConstantExecutionNode implements ExecutableNode {
+ final Object[] _value;
+
+ ArrayConstantExecutionNode(Object[] value) {
+ _value = value;
+ }
+
+ @Override
+ public Object[] execute(GenericRow row) {
+ return _value;
+ }
+
+ @Override
+ public Object[] execute(Object[] values) {
+ return _value;
+ }
+
+ @Override
+ public String toString() {
+ return String.format("'%s'", Arrays.toString(_value));
+ }
+ }
+
private static class ColumnExecutionNode implements ExecutableNode {
final String _column;
final int _id;
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]