This is an automated email from the ASF dual-hosted git repository.
xiangfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new 47eff886cd Add more array scalar functions (#11555)
47eff886cd is described below
commit 47eff886cd0a130cbff9566175f7ec36ea7792e0
Author: Xuanyi Li <[email protected]>
AuthorDate: Tue Sep 19 17:04:28 2023 -0700
Add more array scalar functions (#11555)
* scalar func
* fix unit test
* fix silly bug in intersectIndices
* add indexOfAll for long, float and double, including unit test
---
.../common/function/scalar/ArrayFunctions.java | 81 ++++++++++++++
.../function/BaseTransformFunctionTest.java | 37 +++++++
.../ScalarTransformFunctionWrapperTest.java | 116 +++++++++++++++++++++
3 files changed, 234 insertions(+)
diff --git
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java
index a15cc931b4..a9a6d39e72 100644
---
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java
+++
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java
@@ -23,7 +23,9 @@ import it.unimi.dsi.fastutil.ints.IntLinkedOpenHashSet;
import it.unimi.dsi.fastutil.ints.IntSet;
import it.unimi.dsi.fastutil.objects.ObjectLinkedOpenHashSet;
import it.unimi.dsi.fastutil.objects.ObjectSet;
+import java.util.ArrayList;
import java.util.Arrays;
+import java.util.List;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.pinot.spi.annotations.ScalarFunction;
import org.apache.pinot.spi.utils.CommonConstants.NullValuePlaceHolder;
@@ -74,6 +76,85 @@ public class ArrayFunctions {
return ArrayUtils.indexOf(values, valueToFind);
}
+  /**
+   * Returns the positions of all elements of {@code value} that are equal to {@code valueToFind}.
+   *
+   * @param value array to scan
+   * @param valueToFind value to look for
+   * @return matching 0-based positions in ascending order; empty when nothing matches
+   */
+  @ScalarFunction
+  public static int[] arrayIndexOfAllInt(int[] value, int valueToFind) {
+    // Fill a worst-case sized primitive buffer and trim it, so no index is boxed
+    int[] matches = new int[value.length];
+    int numMatches = 0;
+    for (int pos = 0; pos < value.length; pos++) {
+      if (value[pos] == valueToFind) {
+        matches[numMatches++] = pos;
+      }
+    }
+    return Arrays.copyOf(matches, numMatches);
+  }
+
+  /**
+   * Returns the positions of all elements of {@code value} that are equal to {@code valueToFind}.
+   *
+   * @param value array to scan
+   * @param valueToFind value to look for
+   * @return matching 0-based positions in ascending order; empty when nothing matches
+   */
+  @ScalarFunction
+  public static int[] arrayIndexOfAllLong(long[] value, long valueToFind) {
+    // Fill a worst-case sized primitive buffer and trim it, so no index is boxed
+    int[] matches = new int[value.length];
+    int numMatches = 0;
+    for (int pos = 0; pos < value.length; pos++) {
+      if (value[pos] == valueToFind) {
+        matches[numMatches++] = pos;
+      }
+    }
+    return Arrays.copyOf(matches, numMatches);
+  }
+
+  /**
+   * Returns the positions of all elements of {@code value} that are equal to {@code valueToFind}.
+   * <p>Elements are compared with {@code ==}, so {@code NaN} never matches and {@code 0.0f} matches {@code -0.0f}.
+   *
+   * @param value array to scan
+   * @param valueToFind value to look for
+   * @return matching 0-based positions in ascending order; empty when nothing matches
+   */
+  @ScalarFunction
+  public static int[] arrayIndexOfAllFloat(float[] value, float valueToFind) {
+    // Fill a worst-case sized primitive buffer and trim it, so no index is boxed
+    int[] matches = new int[value.length];
+    int numMatches = 0;
+    for (int pos = 0; pos < value.length; pos++) {
+      if (value[pos] == valueToFind) {
+        matches[numMatches++] = pos;
+      }
+    }
+    return Arrays.copyOf(matches, numMatches);
+  }
+
+  /**
+   * Returns the positions of all elements of {@code value} that are equal to {@code valueToFind}.
+   * <p>Elements are compared with {@code ==}, so {@code NaN} never matches and {@code 0.0} matches {@code -0.0}.
+   *
+   * @param value array to scan
+   * @param valueToFind value to look for
+   * @return matching 0-based positions in ascending order; empty when nothing matches
+   */
+  @ScalarFunction
+  public static int[] arrayIndexOfAllDouble(double[] value, double valueToFind) {
+    // Fill a worst-case sized primitive buffer and trim it, so no index is boxed
+    int[] matches = new int[value.length];
+    int numMatches = 0;
+    for (int pos = 0; pos < value.length; pos++) {
+      if (value[pos] == valueToFind) {
+        matches[numMatches++] = pos;
+      }
+    }
+    return Arrays.copyOf(matches, numMatches);
+  }
+
+  /**
+   * Returns the positions of all elements of {@code value} that are equal to {@code valueToFind}.
+   * <p>Each element is tested with {@code valueToFind.equals(element)}, so {@code null} elements never match and a
+   * {@code null} {@code valueToFind} fails with a {@link NullPointerException} on a non-empty array.
+   *
+   * @param value array to scan
+   * @param valueToFind value to look for
+   * @return matching 0-based positions in ascending order; empty when nothing matches
+   */
+  @ScalarFunction
+  public static int[] arrayIndexOfAllString(String[] value, String valueToFind) {
+    // Fill a worst-case sized primitive buffer and trim it, so no index is boxed
+    int[] matches = new int[value.length];
+    int numMatches = 0;
+    for (int pos = 0; pos < value.length; pos++) {
+      if (valueToFind.equals(value[pos])) {
+        matches[numMatches++] = pos;
+      }
+    }
+    return Arrays.copyOf(matches, numMatches);
+  }
+
+  /**
+   * Returns the values that occur in both {@code values1} and {@code values2}.
+   * <p>Both inputs are assumed to be monotonically increasing indices of MV columns.
+   * Here is the common usage:
+   *   col1: ["a", "b", "a", "b"]
+   *   col2: ["c", "d", "d", "c"]
+   * The user wants to get the first index, called idx, such that {@code col1[idx] == "b" && col2[idx] == "d"}:
+   *   arrayElementAtInt(0, intersectIndices(arrayIndexOfAllString(col1, "b"), arrayIndexOfAllString(col2, "d")))
+   *
+   * @param values1 first increasing index array
+   * @param values2 second increasing index array
+   * @return the common values in increasing order; empty when the inputs share none
+   */
+  @ScalarFunction
+  public static int[] intersectIndices(int[] values1, int[] values2) {
+    // TODO: if values1.length << values2.length, using binary search can speed up the query
+    int i = 0;
+    int j = 0;
+    List<Integer> indices = new ArrayList<>();
+    while (i < values1.length && j < values2.length) {
+      if (values1[i] == values2[j]) {
+        indices.add(values1[i]);
+        i++;
+        j++;
+      } else if (values1[i] < values2[j]) {
+        i++;
+      } else {
+        // values2[j] is behind: it must be skipped as well, otherwise every later match is lost
+        j++;
+      }
+    }
+    return indices.stream().mapToInt(Integer::intValue).toArray();
+  }
+
@ScalarFunction
public static boolean arrayContainsInt(int[] values, int valueToFind) {
return ArrayUtils.contains(values, valueToFind);
diff --git
a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/BaseTransformFunctionTest.java
b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/BaseTransformFunctionTest.java
index ed2a5b4b7b..129c67ad73 100644
---
a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/BaseTransformFunctionTest.java
+++
b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/BaseTransformFunctionTest.java
@@ -102,9 +102,16 @@ public abstract class BaseTransformFunctionTest {
protected static final String STRING_MV_COLUMN = "stringMV";
protected static final String STRING_ALPHANUM_MV_COLUMN = "stringAlphaNumMV";
protected static final String STRING_LONG_MV_COLUMN = "stringLongMV";
+ // deterministic MV is useful for testing IndexOf and IndexOfAll
+ protected static final String STRING_ALPHANUM_MV_COLUMN_2 =
"stringAlphaNumMV2";
protected static final String TIME_COLUMN = "timeColumn";
protected static final String TIMESTAMP_COLUMN = "timestampColumn";
protected static final String TIMESTAMP_COLUMN_NULL = "timestampColumnNull";
+ protected static final String INT_MONO_INCREASING_MV_1 =
"intMonoIncreasingMV1";
+ protected static final String INT_MONO_INCREASING_MV_2 =
"intMonoIncreasingMV2";
+ protected static final String LONG_MV_COLUMN_2 = "longMV2";
+ protected static final String FLOAT_MV_COLUMN_2 = "floatMV2";
+ protected static final String DOUBLE_MV_COLUMN_2 = "doubleMV2";
protected static final String JSON_COLUMN = "json";
protected static final String DEFAULT_JSON_COLUMN = "defaultJson";
@@ -122,11 +129,17 @@ public abstract class BaseTransformFunctionTest {
protected final double[][] _doubleMVValues = new double[NUM_ROWS][];
protected final String[][] _stringMVValues = new String[NUM_ROWS][];
protected final String[][] _stringAlphaNumericMVValues = new
String[NUM_ROWS][];
+ protected final String[][] _stringAlphaNumericMV2Values = new
String[NUM_ROWS][];
protected final String[][] _stringLongFormatMVValues = new
String[NUM_ROWS][];
protected final long[] _timeValues = new long[NUM_ROWS];
protected final String[] _jsonValues = new String[NUM_ROWS];
protected final float[][] _vector1Values = new float[NUM_ROWS][];
protected final float[][] _vector2Values = new float[NUM_ROWS][];
+ protected final int[][] _intMonoIncreasingMV1Values = new int[NUM_ROWS][];
+ protected final int[][] _intMonoIncreasingMV2Values = new int[NUM_ROWS][];
+ protected final long[][] _longMV2Values = new long[NUM_ROWS][];
+ protected final float[][] _floatMV2Values = new float[NUM_ROWS][];
+ protected final double[][] _doubleMV2Values = new double[NUM_ROWS][];
protected Map<String, DataSource> _dataSourceMap;
protected ProjectionBlock _projectionBlock;
@@ -155,9 +168,15 @@ public abstract class BaseTransformFunctionTest {
_doubleMVValues[i] = new double[numValues];
_stringMVValues[i] = new String[numValues];
_stringAlphaNumericMVValues[i] = new String[numValues];
+ _stringAlphaNumericMV2Values[i] = new String[numValues];
_stringLongFormatMVValues[i] = new String[numValues];
_vector1Values[i] = new float[VECTOR_DIM_SIZE];
_vector2Values[i] = new float[VECTOR_DIM_SIZE];
+ _intMonoIncreasingMV1Values[i] = new int[numValues];
+ _intMonoIncreasingMV2Values[i] = new int[numValues];
+ _longMV2Values[i] = new long[numValues];
+ _floatMV2Values[i] = new float[numValues];
+ _doubleMV2Values[i] = new double[numValues];
for (int j = 0; j < numValues; j++) {
_intMVValues[i][j] = 1 + RANDOM.nextInt(MAX_MULTI_VALUE);
@@ -166,7 +185,13 @@ public abstract class BaseTransformFunctionTest {
_doubleMVValues[i][j] = 1 + RANDOM.nextDouble();
_stringMVValues[i][j] = df.format(_intSVValues[i] *
RANDOM.nextDouble());
_stringAlphaNumericMVValues[i][j] =
RandomStringUtils.randomAlphanumeric(26);
+ _stringAlphaNumericMV2Values[i][j] = "a";
_stringLongFormatMVValues[i][j] = df.format(_intSVValues[i] *
RANDOM.nextLong());
+ _intMonoIncreasingMV1Values[i][j] = j;
+ _intMonoIncreasingMV2Values[i][j] = j + 1;
+ _longMV2Values[i][j] = 1L;
+ _floatMV2Values[i][j] = 1.0f;
+ _doubleMV2Values[i][j] = 1.0;
}
for (int j = 0; j < VECTOR_DIM_SIZE; j++) {
@@ -219,6 +244,7 @@ public abstract class BaseTransformFunctionTest {
map.put(DOUBLE_MV_COLUMN, ArrayUtils.toObject(_doubleMVValues[i]));
map.put(STRING_MV_COLUMN, _stringMVValues[i]);
map.put(STRING_ALPHANUM_MV_COLUMN, _stringAlphaNumericMVValues[i]);
+ map.put(STRING_ALPHANUM_MV_COLUMN_2, _stringAlphaNumericMV2Values[i]);
map.put(STRING_LONG_MV_COLUMN, _stringLongFormatMVValues[i]);
map.put(TIMESTAMP_COLUMN, _timeValues[i]);
if (isNullRow(i)) {
@@ -229,6 +255,11 @@ public abstract class BaseTransformFunctionTest {
map.put(TIME_COLUMN, _timeValues[i]);
_jsonValues[i] = JsonUtils.objectToJsonNode(map).toString();
map.put(JSON_COLUMN, _jsonValues[i]);
+ map.put(INT_MONO_INCREASING_MV_1,
ArrayUtils.toObject(_intMonoIncreasingMV1Values[i]));
+ map.put(INT_MONO_INCREASING_MV_2,
ArrayUtils.toObject(_intMonoIncreasingMV2Values[i]));
+ map.put(LONG_MV_COLUMN_2, ArrayUtils.toObject(_longMV2Values[i]));
+ map.put(FLOAT_MV_COLUMN_2, ArrayUtils.toObject(_floatMV2Values[i]));
+ map.put(DOUBLE_MV_COLUMN_2, ArrayUtils.toObject(_doubleMV2Values[i]));
GenericRow row = new GenericRow();
row.init(map);
rows.add(row);
@@ -254,10 +285,16 @@ public abstract class BaseTransformFunctionTest {
.addMultiValueDimension(DOUBLE_MV_COLUMN, FieldSpec.DataType.DOUBLE)
.addMultiValueDimension(STRING_MV_COLUMN, FieldSpec.DataType.STRING)
.addMultiValueDimension(STRING_ALPHANUM_MV_COLUMN,
FieldSpec.DataType.STRING)
+ .addMultiValueDimension(STRING_ALPHANUM_MV_COLUMN_2,
FieldSpec.DataType.STRING)
.addMultiValueDimension(STRING_LONG_MV_COLUMN,
FieldSpec.DataType.STRING)
.addMultiValueDimension(VECTOR_1_COLUMN, FieldSpec.DataType.FLOAT)
.addMultiValueDimension(VECTOR_2_COLUMN, FieldSpec.DataType.FLOAT)
.addMultiValueDimension(ZERO_VECTOR_COLUMN, FieldSpec.DataType.FLOAT)
+ .addMultiValueDimension(INT_MONO_INCREASING_MV_1,
FieldSpec.DataType.INT)
+ .addMultiValueDimension(INT_MONO_INCREASING_MV_2,
FieldSpec.DataType.INT)
+ .addMultiValueDimension(LONG_MV_COLUMN_2, FieldSpec.DataType.LONG)
+ .addMultiValueDimension(FLOAT_MV_COLUMN_2, FieldSpec.DataType.FLOAT)
+ .addMultiValueDimension(DOUBLE_MV_COLUMN_2, FieldSpec.DataType.DOUBLE)
.addDateTime(TIMESTAMP_COLUMN, FieldSpec.DataType.TIMESTAMP,
"1:MILLISECONDS:EPOCH", "1:MILLISECONDS")
.addDateTime(TIMESTAMP_COLUMN_NULL, FieldSpec.DataType.TIMESTAMP,
"1:MILLISECONDS:EPOCH", "1:MILLISECONDS")
.addTime(new TimeGranularitySpec(FieldSpec.DataType.LONG,
TimeUnit.MILLISECONDS, TIME_COLUMN), null).build();
diff --git
a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapperTest.java
b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapperTest.java
index c16f0e9c23..5befeccc07 100644
---
a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapperTest.java
+++
b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapperTest.java
@@ -953,6 +953,122 @@ public class ScalarTransformFunctionWrapperTest extends
BaseTransformFunctionTes
testTransformFunction(transformFunction, expectedValues);
}
+  /** Looks up 0 in INT_MONO_INCREASING_MV_1 and expects the single position 0 for every row. */
+  @Test
+  public void testArrayIndexOfAllInt() {
+    String query = String.format("array_index_of_all_int(%s, 0)", INT_MONO_INCREASING_MV_1);
+    TransformFunction function =
+        TransformFunctionFactory.get(RequestContextUtils.getExpression(query), _dataSourceMap);
+    assertTrue(function instanceof ScalarTransformFunctionWrapper);
+    assertEquals(function.getResultMetadata().getDataType(), DataType.INT);
+    assertFalse(function.getResultMetadata().isSingleValue());
+
+    int[][] expected = new int[NUM_ROWS][];
+    for (int row = 0; row < NUM_ROWS; row++) {
+      expected[row] = new int[]{0};
+    }
+    testTransformFunctionMV(function, expected);
+  }
+
+  /** Looks up 1 in LONG_MV_COLUMN_2 and expects every position of each row to be returned. */
+  @Test
+  public void testArrayIndexOfAllLong() {
+    String query = String.format("array_index_of_all_long(%s, 1)", LONG_MV_COLUMN_2);
+    TransformFunction function =
+        TransformFunctionFactory.get(RequestContextUtils.getExpression(query), _dataSourceMap);
+    assertTrue(function instanceof ScalarTransformFunctionWrapper);
+    assertEquals(function.getResultMetadata().getDataType(), DataType.INT);
+    assertFalse(function.getResultMetadata().isSingleValue());
+
+    int[][] expected = new int[NUM_ROWS][];
+    for (int row = 0; row < NUM_ROWS; row++) {
+      int[] allPositions = new int[_longMV2Values[row].length];
+      for (int pos = 0; pos < allPositions.length; pos++) {
+        allPositions[pos] = pos;
+      }
+      expected[row] = allPositions;
+    }
+    testTransformFunctionMV(function, expected);
+  }
+
+  /** Looks up 1.0 in FLOAT_MV_COLUMN_2 and expects every position of each row to be returned. */
+  @Test
+  public void testArrayIndexOfAllFloat() {
+    String query = String.format("array_index_of_all_float(%s, 1.0)", FLOAT_MV_COLUMN_2);
+    TransformFunction function =
+        TransformFunctionFactory.get(RequestContextUtils.getExpression(query), _dataSourceMap);
+    assertTrue(function instanceof ScalarTransformFunctionWrapper);
+    assertEquals(function.getResultMetadata().getDataType(), DataType.INT);
+    assertFalse(function.getResultMetadata().isSingleValue());
+
+    int[][] expected = new int[NUM_ROWS][];
+    for (int row = 0; row < NUM_ROWS; row++) {
+      int[] allPositions = new int[_floatMV2Values[row].length];
+      for (int pos = 0; pos < allPositions.length; pos++) {
+        allPositions[pos] = pos;
+      }
+      expected[row] = allPositions;
+    }
+    testTransformFunctionMV(function, expected);
+  }
+
+  /** Looks up 1.0 in DOUBLE_MV_COLUMN_2 and expects every position of each row to be returned. */
+  @Test
+  public void testArrayIndexOfAllDouble() {
+    String query = String.format("array_index_of_all_double(%s, 1.0)", DOUBLE_MV_COLUMN_2);
+    TransformFunction function =
+        TransformFunctionFactory.get(RequestContextUtils.getExpression(query), _dataSourceMap);
+    assertTrue(function instanceof ScalarTransformFunctionWrapper);
+    assertEquals(function.getResultMetadata().getDataType(), DataType.INT);
+    assertFalse(function.getResultMetadata().isSingleValue());
+
+    int[][] expected = new int[NUM_ROWS][];
+    for (int row = 0; row < NUM_ROWS; row++) {
+      int[] allPositions = new int[_doubleMV2Values[row].length];
+      for (int pos = 0; pos < allPositions.length; pos++) {
+        allPositions[pos] = pos;
+      }
+      expected[row] = allPositions;
+    }
+    testTransformFunctionMV(function, expected);
+  }
+
+  /** Looks up 'a' in STRING_ALPHANUM_MV_COLUMN_2 and expects every position of each row to be returned. */
+  @Test
+  public void testArrayIndexOfAllString() {
+    String query = String.format("array_index_of_all_string(%s, 'a')", STRING_ALPHANUM_MV_COLUMN_2);
+    TransformFunction function =
+        TransformFunctionFactory.get(RequestContextUtils.getExpression(query), _dataSourceMap);
+    assertTrue(function instanceof ScalarTransformFunctionWrapper);
+    assertEquals(function.getResultMetadata().getDataType(), DataType.INT);
+    assertFalse(function.getResultMetadata().isSingleValue());
+
+    int[][] expected = new int[NUM_ROWS][];
+    for (int row = 0; row < NUM_ROWS; row++) {
+      int[] allPositions = new int[_stringAlphaNumericMV2Values[row].length];
+      for (int pos = 0; pos < allPositions.length; pos++) {
+        allPositions[pos] = pos;
+      }
+      expected[row] = allPositions;
+    }
+    testTransformFunctionMV(function, expected);
+  }
+
+  /**
+   * Intersects INT_MONO_INCREASING_MV_1 with INT_MONO_INCREASING_MV_2 and expects, for a row holding {@code len}
+   * values, the result 1, 2, ..., len - 1.
+   */
+  @Test
+  public void testIntersectIndices() {
+    String query =
+        String.format("intersect_indices(%s, %s)", INT_MONO_INCREASING_MV_1, INT_MONO_INCREASING_MV_2);
+    TransformFunction function =
+        TransformFunctionFactory.get(RequestContextUtils.getExpression(query), _dataSourceMap);
+    assertTrue(function instanceof ScalarTransformFunctionWrapper);
+    assertEquals(function.getResultMetadata().getDataType(), DataType.INT);
+    assertFalse(function.getResultMetadata().isSingleValue());
+
+    int[][] expected = new int[NUM_ROWS][];
+    for (int row = 0; row < NUM_ROWS; row++) {
+      int[] shared = new int[_intMonoIncreasingMV1Values[row].length - 1];
+      for (int k = 0; k < shared.length; k++) {
+        shared[k] = k + 1;
+      }
+      expected[row] = shared;
+    }
+    testTransformFunctionMV(function, expected);
+  }
+
@Test
public void testBase64TransformFunction() {
ExpressionContext expression =
RequestContextUtils.getExpression(String.format("toBase64(%s)",
BYTES_SV_COLUMN));
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]