This is an automated email from the ASF dual-hosted git repository.
jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new 6c01ce2 Add IN function (#7542)
6c01ce2 is described below
commit 6c01ce2bbf551768ae27b394c44b91fe667eb75c
Author: Yupeng Fu <[email protected]>
AuthorDate: Mon Oct 11 10:51:28 2021 -0700
Add IN function (#7542)
Similar to the IN predicate in a WHERE clause, the IN transform function checks
whether the value of an expression is in the given set of values, and returns a boolean value.
---
.../common/function/TransformFunctionType.java | 1 +
.../transform/function/InTransformFunction.java | 381 +++++++++++++++++++++
.../function/TransformFunctionFactory.java | 2 +
.../function/InTransformFunctionTest.java | 194 +++++++++++
4 files changed, 578 insertions(+)
diff --git
a/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java
b/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java
index 8b2a8af..5eda773 100644
---
a/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java
+++
b/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java
@@ -44,6 +44,7 @@ public enum TransformFunctionType {
GREATER_THAN_OR_EQUAL("greater_than_or_equal"),
LESS_THAN("less_than"),
LESS_THAN_OR_EQUAL("less_than_or_equal"),
+ IN("in"),
AND("and"),
OR("or"),
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/InTransformFunction.java
b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/InTransformFunction.java
new file mode 100644
index 0000000..6245108
--- /dev/null
+++
b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/InTransformFunction.java
@@ -0,0 +1,381 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.operator.transform.function;
+
+import com.google.common.base.Preconditions;
+import it.unimi.dsi.fastutil.doubles.DoubleOpenHashSet;
+import it.unimi.dsi.fastutil.floats.FloatOpenHashSet;
+import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
+import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
+import it.unimi.dsi.fastutil.objects.ObjectOpenHashSet;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import org.apache.pinot.common.function.TransformFunctionType;
+import org.apache.pinot.core.operator.blocks.ProjectionBlock;
+import org.apache.pinot.core.operator.transform.TransformResultMetadata;
+import org.apache.pinot.core.plan.DocIdSetPlanNode;
+import org.apache.pinot.segment.spi.datasource.DataSource;
+import org.apache.pinot.spi.data.FieldSpec.DataType;
+import org.apache.pinot.spi.utils.ByteArray;
+import org.apache.pinot.spi.utils.BytesUtils;
+
+
+/**
+ * The IN transform function takes one main expression (lhs) and multiple
value expressions.
+ * <p>For each docId, the function returns {@code true} if the set of values
contains the value of the expression,
+ * {@code false} otherwise.
+ * <p>E.g. {@code SELECT col IN ('a','b','c') FROM myTable)}
+ */
+@SuppressWarnings({"rawtypes", "unchecked"})
+public class InTransformFunction extends BaseTransformFunction {
+ private TransformFunction _mainFunction;
+ private TransformFunction[] _valueFunctions;
+ private Set _valueSet;
+
+ @Override
+ public String getName() {
+ return TransformFunctionType.IN.getName();
+ }
+
+ @Override
+ public void init(List<TransformFunction> arguments, Map<String, DataSource>
dataSourceMap) {
+ int numArguments = arguments.size();
+ Preconditions.checkArgument(numArguments >= 2,
+ "At least 2 arguments are required for IN transform function:
expression, values");
+ _mainFunction = arguments.get(0);
+
+ boolean allLiteralValues = true;
+ ObjectOpenHashSet<String> stringValues = new
ObjectOpenHashSet<>(numArguments - 1);
+ for (int i = 1; i < numArguments; i++) {
+ TransformFunction valueFunction = arguments.get(i);
+ if (valueFunction instanceof LiteralTransformFunction) {
+ stringValues.add(((LiteralTransformFunction)
valueFunction).getLiteral());
+ } else {
+ allLiteralValues = false;
+ break;
+ }
+ }
+
+ if (allLiteralValues) {
+ int numValues = stringValues.size();
+ DataType storedType =
_mainFunction.getResultMetadata().getDataType().getStoredType();
+ switch (storedType) {
+ case INT:
+ IntOpenHashSet intValues = new IntOpenHashSet(numValues);
+ for (String stringValue : stringValues) {
+ intValues.add(Integer.parseInt(stringValue));
+ }
+ _valueSet = intValues;
+ break;
+ case LONG:
+ LongOpenHashSet longValues = new LongOpenHashSet(numValues);
+ for (String stringValue : stringValues) {
+ longValues.add(Long.parseLong(stringValue));
+ }
+ _valueSet = longValues;
+ break;
+ case FLOAT:
+ FloatOpenHashSet floatValues = new FloatOpenHashSet(numValues);
+ for (String stringValue : stringValues) {
+ floatValues.add(Float.parseFloat(stringValue));
+ }
+ _valueSet = floatValues;
+ break;
+ case DOUBLE:
+ DoubleOpenHashSet doubleValues = new DoubleOpenHashSet(numValues);
+ for (String stringValue : stringValues) {
+ doubleValues.add(Double.parseDouble(stringValue));
+ }
+ _valueSet = doubleValues;
+ break;
+ case STRING:
+ _valueSet = stringValues;
+ break;
+ case BYTES:
+ ObjectOpenHashSet<ByteArray> bytesValues = new
ObjectOpenHashSet<>(numValues);
+ for (String stringValue : stringValues) {
+ bytesValues.add(BytesUtils.toByteArray(stringValue));
+ }
+ _valueSet = bytesValues;
+ break;
+ default:
+ throw new IllegalStateException();
+ }
+ } else {
+
Preconditions.checkArgument(_mainFunction.getResultMetadata().isSingleValue(),
+ "The first argument for IN transform function must be single-valued
when there are non-literal values");
+ _valueFunctions = new TransformFunction[numArguments - 1];
+ for (int i = 1; i < numArguments; i++) {
+ TransformFunction valueFunction = arguments.get(i);
+
Preconditions.checkArgument(valueFunction.getResultMetadata().isSingleValue(),
+ "The values for IN transform function must be single-valued");
+ _valueFunctions[i - 1] = valueFunction;
+ }
+ }
+ }
+
+ @Override
+ public TransformResultMetadata getResultMetadata() {
+ return BOOLEAN_SV_NO_DICTIONARY_METADATA;
+ }
+
+ @Override
+ public int[] transformToIntValuesSV(ProjectionBlock projectionBlock) {
+ if (_intValuesSV == null) {
+ _intValuesSV = new int[DocIdSetPlanNode.MAX_DOC_PER_CALL];
+ } else {
+ Arrays.fill(_intValuesSV, 0);
+ }
+
+ int length = projectionBlock.getNumDocs();
+ TransformResultMetadata mainFunctionMetadata =
_mainFunction.getResultMetadata();
+ DataType storedType = mainFunctionMetadata.getDataType().getStoredType();
+ if (_valueSet != null) {
+ if (_mainFunction.getResultMetadata().isSingleValue()) {
+ switch (storedType) {
+ case INT:
+ IntOpenHashSet inIntValues = (IntOpenHashSet) _valueSet;
+ int[] intValues =
_mainFunction.transformToIntValuesSV(projectionBlock);
+ for (int i = 0; i < length; i++) {
+ if (inIntValues.contains(intValues[i])) {
+ _intValuesSV[i] = 1;
+ }
+ }
+ break;
+ case LONG:
+ LongOpenHashSet inLongValues = (LongOpenHashSet) _valueSet;
+ long[] longValues =
_mainFunction.transformToLongValuesSV(projectionBlock);
+ for (int i = 0; i < length; i++) {
+ if (inLongValues.contains(longValues[i])) {
+ _intValuesSV[i] = 1;
+ }
+ }
+ break;
+ case FLOAT:
+ FloatOpenHashSet inFloatValues = (FloatOpenHashSet) _valueSet;
+ float[] floatValues =
_mainFunction.transformToFloatValuesSV(projectionBlock);
+ for (int i = 0; i < length; i++) {
+ if (inFloatValues.contains(floatValues[i])) {
+ _intValuesSV[i] = 1;
+ }
+ }
+ break;
+ case DOUBLE:
+ DoubleOpenHashSet inDoubleValues = (DoubleOpenHashSet) _valueSet;
+ double[] doubleValues =
_mainFunction.transformToDoubleValuesSV(projectionBlock);
+ for (int i = 0; i < length; i++) {
+ if (inDoubleValues.contains(doubleValues[i])) {
+ _intValuesSV[i] = 1;
+ }
+ }
+ break;
+ case STRING:
+ ObjectOpenHashSet<String> inStringValues =
(ObjectOpenHashSet<String>) _valueSet;
+ String[] stringValues =
_mainFunction.transformToStringValuesSV(projectionBlock);
+ for (int i = 0; i < length; i++) {
+ if (inStringValues.contains(stringValues[i])) {
+ _intValuesSV[i] = 1;
+ }
+ }
+ break;
+ case BYTES:
+ ObjectOpenHashSet<ByteArray> inBytesValues =
(ObjectOpenHashSet<ByteArray>) _valueSet;
+ byte[][] bytesValues =
_mainFunction.transformToBytesValuesSV(projectionBlock);
+ for (int i = 0; i < length; i++) {
+ if (inBytesValues.contains(new ByteArray(bytesValues[i]))) {
+ _intValuesSV[i] = 1;
+ }
+ }
+ break;
+ default:
+ throw new IllegalStateException();
+ }
+ } else {
+ switch (storedType) {
+ case INT:
+ IntOpenHashSet inIntValues = (IntOpenHashSet) _valueSet;
+ int[][] intValues =
_mainFunction.transformToIntValuesMV(projectionBlock);
+ for (int i = 0; i < length; i++) {
+ for (int intValue : intValues[i]) {
+ if (inIntValues.contains(intValue)) {
+ _intValuesSV[i] = 1;
+ break;
+ }
+ }
+ }
+ break;
+ case LONG:
+ LongOpenHashSet inLongValues = (LongOpenHashSet) _valueSet;
+ long[][] longValues =
_mainFunction.transformToLongValuesMV(projectionBlock);
+ for (int i = 0; i < length; i++) {
+ for (long longValue : longValues[i]) {
+ if (inLongValues.contains(longValue)) {
+ _intValuesSV[i] = 1;
+ break;
+ }
+ }
+ }
+ break;
+ case FLOAT:
+ FloatOpenHashSet inFloatValues = (FloatOpenHashSet) _valueSet;
+ float[][] floatValues =
_mainFunction.transformToFloatValuesMV(projectionBlock);
+ for (int i = 0; i < length; i++) {
+ for (float floatValue : floatValues[i]) {
+ if (inFloatValues.contains(floatValue)) {
+ _intValuesSV[i] = 1;
+ break;
+ }
+ }
+ }
+ break;
+ case DOUBLE:
+ DoubleOpenHashSet inDoubleValues = (DoubleOpenHashSet) _valueSet;
+ double[][] doubleValues =
_mainFunction.transformToDoubleValuesMV(projectionBlock);
+ for (int i = 0; i < length; i++) {
+ for (double doubleValue : doubleValues[i]) {
+ if (inDoubleValues.contains(doubleValue)) {
+ _intValuesSV[i] = 1;
+ break;
+ }
+ }
+ }
+ break;
+ case STRING:
+ ObjectOpenHashSet<String> inStringValues =
(ObjectOpenHashSet<String>) _valueSet;
+ String[][] stringValues =
_mainFunction.transformToStringValuesMV(projectionBlock);
+ for (int i = 0; i < length; i++) {
+ for (String stringValue : stringValues[i]) {
+ if (inStringValues.contains(stringValue)) {
+ _intValuesSV[i] = 1;
+ break;
+ }
+ }
+ }
+ break;
+ default:
+ throw new IllegalStateException();
+ }
+ }
+ } else {
+ int numValues = _valueFunctions.length;
+ switch (storedType) {
+ case INT:
+ int[] intValues =
_mainFunction.transformToIntValuesSV(projectionBlock);
+ int[][] inIntValues = new int[numValues][];
+ for (int i = 0; i < numValues; i++) {
+ inIntValues[i] =
_valueFunctions[i].transformToIntValuesSV(projectionBlock);
+ }
+ for (int i = 0; i < length; i++) {
+ for (int[] inIntValue : inIntValues) {
+ if (intValues[i] == inIntValue[i]) {
+ _intValuesSV[i] = 1;
+ break;
+ }
+ }
+ }
+ break;
+ case LONG:
+ long[] longValues =
_mainFunction.transformToLongValuesSV(projectionBlock);
+ long[][] inLongValues = new long[numValues][];
+ for (int i = 0; i < numValues; i++) {
+ inLongValues[i] =
_valueFunctions[i].transformToLongValuesSV(projectionBlock);
+ }
+ for (int i = 0; i < length; i++) {
+ for (long[] inLongValue : inLongValues) {
+ if (longValues[i] == inLongValue[i]) {
+ _intValuesSV[i] = 1;
+ break;
+ }
+ }
+ }
+ break;
+ case FLOAT:
+ float[] floatValues =
_mainFunction.transformToFloatValuesSV(projectionBlock);
+ float[][] inFloatValues = new float[numValues][];
+ for (int i = 0; i < numValues; i++) {
+ inFloatValues[i] =
_valueFunctions[i].transformToFloatValuesSV(projectionBlock);
+ }
+ for (int i = 0; i < length; i++) {
+ // Check int bits to be aligned with the Set (Float.equals())
behavior
+ int intBits = Float.floatToIntBits(floatValues[i]);
+ for (float[] inFloatValue : inFloatValues) {
+ if (intBits == Float.floatToIntBits(inFloatValue[i])) {
+ _intValuesSV[i] = 1;
+ break;
+ }
+ }
+ }
+ break;
+ case DOUBLE:
+ double[] doubleValues =
_mainFunction.transformToDoubleValuesSV(projectionBlock);
+ double[][] inDoubleValues = new double[numValues][];
+ for (int i = 0; i < numValues; i++) {
+ inDoubleValues[i] =
_valueFunctions[i].transformToDoubleValuesSV(projectionBlock);
+ }
+ for (int i = 0; i < length; i++) {
+ // Check long bits to be aligned with the Set (Double.equals())
behavior
+ long longBits = Double.doubleToLongBits(doubleValues[i]);
+ for (double[] inDoubleValue : inDoubleValues) {
+ if (longBits == Double.doubleToLongBits(inDoubleValue[i])) {
+ _intValuesSV[i] = 1;
+ break;
+ }
+ }
+ }
+ break;
+ case STRING:
+ String[] stringValues =
_mainFunction.transformToStringValuesSV(projectionBlock);
+ String[][] inStringValues = new String[numValues][];
+ for (int i = 0; i < numValues; i++) {
+ inStringValues[i] =
_valueFunctions[i].transformToStringValuesSV(projectionBlock);
+ }
+ for (int i = 0; i < length; i++) {
+ for (String[] inStringValue : inStringValues) {
+ if (stringValues[i].equals(inStringValue[i])) {
+ _intValuesSV[i] = 1;
+ break;
+ }
+ }
+ }
+ break;
+ case BYTES:
+ byte[][] bytesValues =
_mainFunction.transformToBytesValuesSV(projectionBlock);
+ byte[][][] inBytesValues = new byte[numValues][][];
+ for (int i = 0; i < numValues; i++) {
+ inBytesValues[i] =
_valueFunctions[i].transformToBytesValuesSV(projectionBlock);
+ }
+ for (int i = 0; i < length; i++) {
+ for (byte[][] inBytesValue : inBytesValues) {
+ if (Arrays.equals(bytesValues[i], inBytesValue[i])) {
+ _intValuesSV[i] = 1;
+ break;
+ }
+ }
+ }
+ break;
+ default:
+ throw new IllegalStateException();
+ }
+ }
+
+ return _intValuesSV;
+ }
+}
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
index 27b6bb1..b3f6bfa 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
@@ -122,11 +122,13 @@ public class TransformFunctionFactory {
put(canonicalize(TransformFunctionType.LESS_THAN.getName().toLowerCase()),
LessThanTransformFunction.class);
put(canonicalize(TransformFunctionType.LESS_THAN_OR_EQUAL.getName().toLowerCase()),
LessThanOrEqualTransformFunction.class);
+ put(canonicalize(TransformFunctionType.IN.getName().toLowerCase()),
InTransformFunction.class);
// logical functions
put(canonicalize(TransformFunctionType.AND.getName().toLowerCase()),
AndOperatorTransformFunction.class);
put(canonicalize(TransformFunctionType.OR.getName().toLowerCase()),
OrOperatorTransformFunction.class);
+
// geo functions
// geo constructors
put(canonicalize(TransformFunctionType.ST_GEOG_FROM_TEXT.getName().toLowerCase()),
diff --git
a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/InTransformFunctionTest.java
b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/InTransformFunctionTest.java
new file mode 100644
index 0000000..fed5c3e
--- /dev/null
+++
b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/InTransformFunctionTest.java
@@ -0,0 +1,194 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.operator.transform.function;
+
+import com.google.common.collect.Sets;
+import java.util.Set;
+import org.apache.pinot.common.function.TransformFunctionType;
+import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.common.request.context.RequestContextUtils;
+import org.apache.pinot.spi.utils.ByteArray;
+import org.apache.pinot.spi.utils.BytesUtils;
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertTrue;
+
+
+public class InTransformFunctionTest extends BaseTransformFunctionTest {
+
+ @Test
+ public void testIntInTransformFunction() {
+ String expressionStr =
+ String.format("%s IN (%d, %d, %d)", INT_SV_COLUMN, _intSVValues[2],
_intSVValues[5], _intSVValues[9]);
+ ExpressionContext expression =
RequestContextUtils.getExpressionFromSQL(expressionStr);
+ TransformFunction transformFunction =
TransformFunctionFactory.get(expression, _dataSourceMap);
+ assertTrue(transformFunction instanceof InTransformFunction);
+ assertEquals(transformFunction.getName(),
TransformFunctionType.IN.getName());
+
+ Set<Integer> inValues = Sets.newHashSet(_intSVValues[2], _intSVValues[5],
_intSVValues[9]);
+ int[] intValues =
transformFunction.transformToIntValuesSV(_projectionBlock);
+ for (int i = 0; i < NUM_ROWS; i++) {
+ if (i == 2 || i == 5 || i == 9) {
+ assertEquals(intValues[i], 1);
+ }
+ assertEquals(intValues[i], inValues.contains(_intSVValues[i]) ? 1 : 0);
+ }
+ }
+
+ @Test
+ public void testIntMVInTransformFunction() {
+ String expressionStr =
+ String.format("%s IN (%d, %d, %d)", INT_MV_COLUMN, _intMVValues[2][0],
_intMVValues[5][0], _intMVValues[9][0]);
+ ExpressionContext expression =
RequestContextUtils.getExpressionFromSQL(expressionStr);
+ TransformFunction transformFunction =
TransformFunctionFactory.get(expression, _dataSourceMap);
+ assertTrue(transformFunction instanceof InTransformFunction);
+ assertEquals(transformFunction.getName(),
TransformFunctionType.IN.getName());
+
+ Set<Integer> inValues = Sets.newHashSet(_intMVValues[2][0],
_intMVValues[5][0], _intMVValues[9][0]);
+ int[] intValues =
transformFunction.transformToIntValuesSV(_projectionBlock);
+ for (int i = 0; i < NUM_ROWS; i++) {
+ if (i == 2 || i == 5 || i == 9) {
+ assertEquals(intValues[i], 1);
+ }
+ int expected = 0;
+ for (int intValue : _intMVValues[i]) {
+ if (inValues.contains(intValue)) {
+ expected = 1;
+ break;
+ }
+ }
+ assertEquals(intValues[i], expected);
+ }
+ }
+
+ @Test
+ public void testIntInTransformFunctionWithTransformedValues() {
+ String expressionStr = String.format("%s IN (%d, 1+1, 4+5)",
INT_SV_COLUMN, _intSVValues[2]);
+ ExpressionContext expression =
RequestContextUtils.getExpressionFromSQL(expressionStr);
+ TransformFunction transformFunction =
TransformFunctionFactory.get(expression, _dataSourceMap);
+ assertTrue(transformFunction instanceof InTransformFunction);
+ assertEquals(transformFunction.getName(),
TransformFunctionType.IN.getName());
+
+ Set<Integer> inValues = Sets.newHashSet(_intSVValues[2], 2, 9);
+ int[] intValues =
transformFunction.transformToIntValuesSV(_projectionBlock);
+ for (int i = 0; i < NUM_ROWS; i++) {
+ if (i == 2) {
+ assertEquals(intValues[i], 1);
+ }
+ assertEquals(intValues[i], inValues.contains(_intSVValues[i]) ? 1 : 0);
+ }
+ }
+
+ @Test
+ public void testLongInTransformFunction() {
+ String expressionStr =
+ String.format("%s IN (%d, %d, %d)", LONG_SV_COLUMN, _longSVValues[2],
_longSVValues[7], _longSVValues[11]);
+ ExpressionContext expression =
RequestContextUtils.getExpressionFromSQL(expressionStr);
+ TransformFunction transformFunction =
TransformFunctionFactory.get(expression, _dataSourceMap);
+ assertTrue(transformFunction instanceof InTransformFunction);
+ assertEquals(transformFunction.getName(),
TransformFunctionType.IN.getName());
+
+ Set<Long> inValues = Sets.newHashSet(_longSVValues[2], _longSVValues[7],
_longSVValues[11]);
+ int[] intValues =
transformFunction.transformToIntValuesSV(_projectionBlock);
+ for (int i = 0; i < NUM_ROWS; i++) {
+ if (i == 2 || i == 7 || i == 11) {
+ assertEquals(intValues[i], 1);
+ }
+ assertEquals(intValues[i], inValues.contains(_longSVValues[i]) ? 1 : 0);
+ }
+ }
+
+ @Test
+ public void testFloatInTransformFunction() {
+ String expressionStr =
+ String.format("%s IN (%s, %s, %s)", FLOAT_SV_COLUMN,
_floatSVValues[3], _floatSVValues[7], _floatSVValues[9]);
+ ExpressionContext expression =
RequestContextUtils.getExpressionFromSQL(expressionStr);
+ TransformFunction transformFunction =
TransformFunctionFactory.get(expression, _dataSourceMap);
+ assertTrue(transformFunction instanceof InTransformFunction);
+ assertEquals(transformFunction.getName(),
TransformFunctionType.IN.getName());
+
+ Set<Float> inValues = Sets.newHashSet(_floatSVValues[3],
_floatSVValues[7], _floatSVValues[9]);
+ int[] intValues =
transformFunction.transformToIntValuesSV(_projectionBlock);
+ for (int i = 0; i < NUM_ROWS; i++) {
+ if (i == 3 || i == 7 || i == 9) {
+ assertEquals(intValues[i], 1);
+ }
+ assertEquals(intValues[i], inValues.contains(_floatSVValues[i]) ? 1 : 0);
+ }
+ }
+
+ @Test
+ public void testDoubleInTransformFunction() {
+ String expressionStr = String.format("%s IN (%s, %s, %s)",
DOUBLE_SV_COLUMN, _doubleSVValues[3], _doubleSVValues[7],
+ _doubleSVValues[9]);
+ ExpressionContext expression =
RequestContextUtils.getExpressionFromSQL(expressionStr);
+ TransformFunction transformFunction =
TransformFunctionFactory.get(expression, _dataSourceMap);
+ assertTrue(transformFunction instanceof InTransformFunction);
+ assertEquals(transformFunction.getName(),
TransformFunctionType.IN.getName());
+
+ Set<Double> inValues = Sets.newHashSet(_doubleSVValues[3],
_doubleSVValues[7], _doubleSVValues[9]);
+ int[] intValues =
transformFunction.transformToIntValuesSV(_projectionBlock);
+ for (int i = 0; i < NUM_ROWS; i++) {
+ if (i == 3 || i == 7 || i == 9) {
+ assertEquals(intValues[i], 1);
+ }
+ assertEquals(intValues[i], inValues.contains(_doubleSVValues[i]) ? 1 :
0);
+ }
+ }
+
+ @Test
+ public void testStringInTransformFunction() {
+ String expressionStr =
+ String.format("%s IN ('a','b','%s','%s')", STRING_SV_COLUMN,
_stringSVValues[2], _stringSVValues[5]);
+ ExpressionContext expression =
RequestContextUtils.getExpressionFromSQL(expressionStr);
+ TransformFunction transformFunction =
TransformFunctionFactory.get(expression, _dataSourceMap);
+ assertTrue(transformFunction instanceof InTransformFunction);
+ assertEquals(transformFunction.getName(),
TransformFunctionType.IN.getName());
+
+ Set<String> inValues = Sets.newHashSet("a", "b", _stringSVValues[2],
_stringSVValues[5]);
+ int[] intValues =
transformFunction.transformToIntValuesSV(_projectionBlock);
+ for (int i = 0; i < NUM_ROWS; i++) {
+ if (i == 2 || i == 5) {
+ assertEquals(intValues[i], 1);
+ }
+ assertEquals(intValues[i], inValues.contains(_stringSVValues[i]) ? 1 :
0);
+ }
+ }
+
+ @Test
+ public void testBytesInTransformFunction() {
+ String expressionStr =
+ String.format("%s IN ('%s','%s')", BYTES_SV_COLUMN,
BytesUtils.toHexString(_bytesSVValues[2]),
+ BytesUtils.toHexString(_bytesSVValues[5]));
+ ExpressionContext expression =
RequestContextUtils.getExpressionFromSQL(expressionStr);
+ TransformFunction transformFunction =
TransformFunctionFactory.get(expression, _dataSourceMap);
+ assertTrue(transformFunction instanceof InTransformFunction);
+ assertEquals(transformFunction.getName(),
TransformFunctionType.IN.getName());
+
+ Set<ByteArray> inValues = Sets.newHashSet(new
ByteArray(_bytesSVValues[2]), new ByteArray(_bytesSVValues[5]));
+ int[] intValues =
transformFunction.transformToIntValuesSV(_projectionBlock);
+ for (int i = 0; i < NUM_ROWS; i++) {
+ if (i == 2 || i == 5) {
+ assertEquals(intValues[i], 1);
+ }
+ assertEquals(intValues[i], inValues.contains(new
ByteArray(_bytesSVValues[i])) ? 1 : 0);
+ }
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]