This is an automated email from the ASF dual-hosted git repository.
jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new 5d9a794073 Support arbitrary number of WHEN THEN clauses in the scalar
CASE function (#14125)
5d9a794073 is described below
commit 5d9a79407345b48cbad96d4d1f65cd0d099a9e8a
Author: Yash Mayya <[email protected]>
AuthorDate: Mon Oct 7 23:25:06 2024 +0530
Support arbitrary number of WHEN THEN clauses in the scalar CASE function
(#14125)
---
.../common/function/scalar/ObjectFunctions.java | 143 +--------------------
.../pinot/sql/parsers/CalciteSqlCompilerTest.java | 13 +-
.../postaggregation/PostAggregationFunction.java | 59 +++++----
.../PostAggregationFunctionTest.java | 28 +++-
.../tests/MultiStageEngineIntegrationTest.java | 47 +++++++
5 files changed, 119 insertions(+), 171 deletions(-)
diff --git
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ObjectFunctions.java
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ObjectFunctions.java
index 7b8fa82d11..6c3df77f96 100644
---
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ObjectFunctions.java
+++
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ObjectFunctions.java
@@ -93,147 +93,8 @@ public class ObjectFunctions {
return null;
}
- @Nullable
- @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"})
- public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1,
@Nullable Object oe) {
- return caseWhenVar(c1, o1, oe);
- }
-
- @Nullable
- @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"})
- public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1,
@Nullable Boolean c2, @Nullable Object o2,
- @Nullable Object oe) {
- return caseWhenVar(c1, o1, c2, o2, oe);
- }
-
- @Nullable
- @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"})
- public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1,
@Nullable Boolean c2, @Nullable Object o2,
- @Nullable Boolean c3, @Nullable Object o3, @Nullable Object oe) {
- return caseWhenVar(c1, o1, c2, o2, c3, o3, oe);
- }
-
- @Nullable
- @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"})
- public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1,
@Nullable Boolean c2, @Nullable Object o2,
- @Nullable Boolean c3, @Nullable Object o3, @Nullable Boolean c4,
@Nullable Object o4, @Nullable Object oe) {
- return caseWhenVar(c1, o1, c2, o2, c3, o3, c4, o4, oe);
- }
-
- @Nullable
- @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"})
- public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1,
@Nullable Boolean c2, @Nullable Object o2,
- @Nullable Boolean c3, @Nullable Object o3, @Nullable Boolean c4,
@Nullable Object o4, @Nullable Boolean c5,
- @Nullable Object o5, @Nullable Object oe) {
- return caseWhenVar(c1, o1, c2, o2, c3, o3, c4, o4, c5, o5, oe);
- }
-
- @Nullable
- @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"})
- public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1,
@Nullable Boolean c2, @Nullable Object o2,
- @Nullable Boolean c3, @Nullable Object o3, @Nullable Boolean c4,
@Nullable Object o4, @Nullable Boolean c5,
- @Nullable Object o5, @Nullable Boolean c6, @Nullable Object o6,
@Nullable Object oe) {
- return caseWhenVar(c1, o1, c2, o2, c3, o3, c4, o4, c5, o5, c6, o6, oe);
- }
-
- @Nullable
- @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"})
- public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1,
@Nullable Boolean c2, @Nullable Object o2,
- @Nullable Boolean c3, @Nullable Object o3, @Nullable Boolean c4,
@Nullable Object o4, @Nullable Boolean c5,
- @Nullable Object o5, @Nullable Boolean c6, @Nullable Object o6,
@Nullable Boolean c7, @Nullable Object o7,
- @Nullable Object oe) {
- return caseWhenVar(c1, o1, c2, o2, c3, o3, c4, o4, c5, o5, c6, o6, c7, o7,
oe);
- }
-
- @Nullable
- @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"})
- public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1,
@Nullable Boolean c2, @Nullable Object o2,
- @Nullable Boolean c3, @Nullable Object o3, @Nullable Boolean c4,
@Nullable Object o4, @Nullable Boolean c5,
- @Nullable Object o5, @Nullable Boolean c6, @Nullable Object o6,
@Nullable Boolean c7, @Nullable Object o7,
- @Nullable Boolean c8, @Nullable Object o8, @Nullable Object oe) {
- return caseWhenVar(c1, o1, c2, o2, c3, o3, c4, o4, c5, o5, c6, o6, c7, o7,
c8, o8, oe);
- }
-
- @Nullable
- @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"})
- public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1,
@Nullable Boolean c2, @Nullable Object o2,
- @Nullable Boolean c3, @Nullable Object o3, @Nullable Boolean c4,
@Nullable Object o4, @Nullable Boolean c5,
- @Nullable Object o5, @Nullable Boolean c6, @Nullable Object o6,
@Nullable Boolean c7, @Nullable Object o7,
- @Nullable Boolean c8, @Nullable Object o8, @Nullable Boolean c9,
@Nullable Object o9, @Nullable Object oe) {
- return caseWhenVar(c1, o1, c2, o2, c3, o3, c4, o4, c5, o5, c6, o6, c7, o7,
c8, o8, c9, o9, oe);
- }
-
- @Nullable
- @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"})
- public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1,
@Nullable Boolean c2, @Nullable Object o2,
- @Nullable Boolean c3, @Nullable Object o3, @Nullable Boolean c4,
@Nullable Object o4, @Nullable Boolean c5,
- @Nullable Object o5, @Nullable Boolean c6, @Nullable Object o6,
@Nullable Boolean c7, @Nullable Object o7,
- @Nullable Boolean c8, @Nullable Object o8, @Nullable Boolean c9,
@Nullable Object o9, @Nullable Boolean c10,
- @Nullable Object o10, @Nullable Object oe) {
- return caseWhenVar(c1, o1, c2, o2, c3, o3, c4, o4, c5, o5, c6, o6, c7, o7,
c8, o8, c9, o9, c10, o10, oe);
- }
-
- @Nullable
- @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"})
- public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1,
@Nullable Boolean c2, @Nullable Object o2,
- @Nullable Boolean c3, @Nullable Object o3, @Nullable Boolean c4,
@Nullable Object o4, @Nullable Boolean c5,
- @Nullable Object o5, @Nullable Boolean c6, @Nullable Object o6,
@Nullable Boolean c7, @Nullable Object o7,
- @Nullable Boolean c8, @Nullable Object o8, @Nullable Boolean c9,
@Nullable Object o9, @Nullable Boolean c10,
- @Nullable Object o10, @Nullable Boolean c11, @Nullable Object o11,
@Nullable Object oe) {
- return caseWhenVar(c1, o1, c2, o2, c3, o3, c4, o4, c5, o5, c6, o6, c7, o7,
c8, o8, c9, o9, c10, o10, c11, o11, oe);
- }
-
- @Nullable
- @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"})
- public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1,
@Nullable Boolean c2, @Nullable Object o2,
- @Nullable Boolean c3, @Nullable Object o3, @Nullable Boolean c4,
@Nullable Object o4, @Nullable Boolean c5,
- @Nullable Object o5, @Nullable Boolean c6, @Nullable Object o6,
@Nullable Boolean c7, @Nullable Object o7,
- @Nullable Boolean c8, @Nullable Object o8, @Nullable Boolean c9,
@Nullable Object o9, @Nullable Boolean c10,
- @Nullable Object o10, @Nullable Boolean c11, @Nullable Object o11,
@Nullable Boolean c12, @Nullable Object o12,
- @Nullable Object oe) {
- return caseWhenVar(c1, o1, c2, o2, c3, o3, c4, o4, c5, o5, c6, o6, c7, o7,
c8, o8, c9, o9, c10, o10, c11, o11, c12,
- o12, oe);
- }
-
- @Nullable
- @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"})
- public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1,
@Nullable Boolean c2, @Nullable Object o2,
- @Nullable Boolean c3, @Nullable Object o3, @Nullable Boolean c4,
@Nullable Object o4, @Nullable Boolean c5,
- @Nullable Object o5, @Nullable Boolean c6, @Nullable Object o6,
@Nullable Boolean c7, @Nullable Object o7,
- @Nullable Boolean c8, @Nullable Object o8, @Nullable Boolean c9,
@Nullable Object o9, @Nullable Boolean c10,
- @Nullable Object o10, @Nullable Boolean c11, @Nullable Object o11,
@Nullable Boolean c12, @Nullable Object o12,
- @Nullable Boolean c13, @Nullable Object o13, @Nullable Object oe) {
- return caseWhenVar(c1, o1, c2, o2, c3, o3, c4, o4, c5, o5, c6, o6, c7, o7,
c8, o8, c9, o9, c10, o10, c11, o11, c12,
- o12, c13, o13, oe);
- }
-
- @Nullable
- @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"})
- public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1,
@Nullable Boolean c2, @Nullable Object o2,
- @Nullable Boolean c3, @Nullable Object o3, @Nullable Boolean c4,
@Nullable Object o4, @Nullable Boolean c5,
- @Nullable Object o5, @Nullable Boolean c6, @Nullable Object o6,
@Nullable Boolean c7, @Nullable Object o7,
- @Nullable Boolean c8, @Nullable Object o8, @Nullable Boolean c9,
@Nullable Object o9, @Nullable Boolean c10,
- @Nullable Object o10, @Nullable Boolean c11, @Nullable Object o11,
@Nullable Boolean c12, @Nullable Object o12,
- @Nullable Boolean c13, @Nullable Object o13, @Nullable Boolean c14,
@Nullable Object o14, @Nullable Object oe) {
- return caseWhenVar(c1, o1, c2, o2, c3, o3, c4, o4, c5, o5, c6, o6, c7, o7,
c8, o8, c9, o9, c10, o10, c11, o11, c12,
- o12, c13, o13, c14, o14, oe);
- }
-
- @Nullable
- @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"})
- public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1,
@Nullable Boolean c2, @Nullable Object o2,
- @Nullable Boolean c3, @Nullable Object o3, @Nullable Boolean c4,
@Nullable Object o4, @Nullable Boolean c5,
- @Nullable Object o5, @Nullable Boolean c6, @Nullable Object o6,
@Nullable Boolean c7, @Nullable Object o7,
- @Nullable Boolean c8, @Nullable Object o8, @Nullable Boolean c9,
@Nullable Object o9, @Nullable Boolean c10,
- @Nullable Object o10, @Nullable Boolean c11, @Nullable Object o11,
@Nullable Boolean c12, @Nullable Object o12,
- @Nullable Boolean c13, @Nullable Object o13, @Nullable Boolean c14,
@Nullable Object o14, @Nullable Boolean c15,
- @Nullable Object o15, @Nullable Object oe) {
- return caseWhenVar(c1, o1, c2, o2, c3, o3, c4, o4, c5, o5, c6, o6, c7, o7,
c8, o8, c9, o9, c10, o10, c11, o11, c12,
- o12, c13, o13, c14, o14, c15, o15, oe);
- }
-
- @Nullable
- private static Object caseWhenVar(Object... objs) {
+ @ScalarFunction(names = {"case", "caseWhen"}, nullableParameters = true,
isVarArg = true)
+ public static Object caseWhen(Object... objs) {
for (int i = 0; i < objs.length - 1; i += 2) {
if (Boolean.TRUE.equals(objs[i])) {
return objs[i + 1];
diff --git
a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
index 67940a64c0..561fc755f1 100644
---
a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
+++
b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
@@ -117,7 +117,7 @@ public class CalciteSqlCompilerTest {
}
@Test
- public void testCaseWhenStatements() {
+ public void testCaseWhenTransformStatements() {
//@formatter:off
PinotQuery pinotQuery = compileToPinotQuery(
"SELECT OrderID, Quantity,\n"
@@ -237,6 +237,17 @@ public class CalciteSqlCompilerTest {
.getName(), "Quantity");
}
+ @Test
+ public void testCaseWhenScalar() {
+ PinotQuery pinotQuery = compileToPinotQuery("SELECT CASE WHEN NOW() > 0
THEN 1 ELSE -1 END FROM myTable");
+ Assert.assertEquals(pinotQuery.getSelectList().size(), 1);
+ Assert.assertTrue(pinotQuery.getSelectList().get(0).isSetLiteral());
+
Assert.assertEquals(pinotQuery.getSelectList().get(0).getLiteral().getIntValue(),
1);
+
+ Assert.assertThrows(SqlCompilationException.class,
+ () -> compileToPinotQuery("SELECT CASE WHEN 1 > 0 END FROM myTable"));
+ }
+
@Test
public void testQuotedStrings() {
PinotQuery pinotQuery = compileToPinotQuery("select * from vegetables
where origin = 'Martha''s Vineyard'");
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunction.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunction.java
index 370865227c..1ad0fdfaf4 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunction.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunction.java
@@ -20,6 +20,7 @@ package org.apache.pinot.core.query.postaggregation;
import com.google.common.base.Preconditions;
import java.util.Arrays;
+import javax.annotation.Nullable;
import org.apache.pinot.common.function.FunctionInfo;
import org.apache.pinot.common.function.FunctionInvoker;
import org.apache.pinot.common.function.FunctionRegistry;
@@ -33,8 +34,9 @@ import org.apache.pinot.common.utils.PinotDataType;
*/
public class PostAggregationFunction {
private final FunctionInvoker _functionInvoker;
- private final PinotDataType[] _argumentTypes;
private final ColumnDataType _resultType;
+ @Nullable
+ private PinotDataType[] _argumentTypes;
public PostAggregationFunction(String functionName, ColumnDataType[]
argumentTypes) {
String canonicalName = FunctionRegistry.canonicalize(functionName);
@@ -49,24 +51,27 @@ public class PostAggregationFunction {
}
}
_functionInvoker = new FunctionInvoker(functionInfo);
- Class<?>[] parameterClasses = _functionInvoker.getParameterClasses();
- PinotDataType[] parameterTypes = _functionInvoker.getParameterTypes();
- int numArguments = argumentTypes.length;
- int numParameters = parameterClasses.length;
- Preconditions.checkArgument(numArguments == numParameters,
- "Wrong number of arguments for method: %s, expected: %s, actual: %s",
functionInfo.getMethod(), numParameters,
- numArguments);
- for (int i = 0; i < numParameters; i++) {
- Preconditions.checkArgument(parameterTypes[i] != null, "Unsupported
parameter class: %s for method: %s",
- parameterClasses[i], functionInfo.getMethod());
- }
- _argumentTypes = new PinotDataType[numArguments];
- for (int i = 0; i < numArguments; i++) {
- _argumentTypes[i] =
PinotDataType.getPinotDataTypeForExecution(argumentTypes[i]);
- }
ColumnDataType resultType =
FunctionUtils.getColumnDataType(_functionInvoker.getResultClass());
// Handle unrecognized result class with STRING
_resultType = resultType != null ? resultType : ColumnDataType.STRING;
+
+ if (!_functionInvoker.getMethod().isVarArgs()) {
+ Class<?>[] parameterClasses = _functionInvoker.getParameterClasses();
+ PinotDataType[] parameterTypes = _functionInvoker.getParameterTypes();
+ int numArguments = argumentTypes.length;
+ int numParameters = parameterClasses.length;
+ Preconditions.checkArgument(numArguments == numParameters,
+ "Wrong number of arguments for method: %s, expected: %s, actual:
%s", functionInfo.getMethod(), numParameters,
+ numArguments);
+ for (int i = 0; i < numParameters; i++) {
+ Preconditions.checkArgument(parameterTypes[i] != null, "Unsupported
parameter class: %s for method: %s",
+ parameterClasses[i], functionInfo.getMethod());
+ }
+ _argumentTypes = new PinotDataType[numArguments];
+ for (int i = 0; i < numArguments; i++) {
+ _argumentTypes[i] =
PinotDataType.getPinotDataTypeForExecution(argumentTypes[i]);
+ }
+ }
}
/**
@@ -81,16 +86,22 @@ public class PostAggregationFunction {
* NOTE: The passed in arguments could be modified during the type
conversion.
*/
public Object invoke(Object[] arguments) {
- int numArguments = arguments.length;
- PinotDataType[] parameterTypes = _functionInvoker.getParameterTypes();
- for (int i = 0; i < numArguments; i++) {
- PinotDataType parameterType = parameterTypes[i];
- PinotDataType argumentType = _argumentTypes[i];
- if (parameterType != argumentType) {
- arguments[i] = parameterType.convert(arguments[i], argumentType);
+ Object result;
+ if (_functionInvoker.getMethod().isVarArgs()) {
+ result = _functionInvoker.invoke(new Object[]{arguments});
+ } else {
+ int numArguments = arguments.length;
+ PinotDataType[] parameterTypes = _functionInvoker.getParameterTypes();
+ for (int i = 0; i < numArguments; i++) {
+ PinotDataType parameterType = parameterTypes[i];
+ PinotDataType argumentType = _argumentTypes[i];
+ if (parameterType != argumentType) {
+ arguments[i] = parameterType.convert(arguments[i], argumentType);
+ }
}
+ result = _functionInvoker.invoke(arguments);
}
- Object result = _functionInvoker.invoke(arguments);
+
return _resultType == ColumnDataType.STRING ? result.toString() : result;
}
}
diff --git
a/pinot-core/src/test/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunctionTest.java
b/pinot-core/src/test/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunctionTest.java
index 6f4cd02a29..01f88b1d91 100644
---
a/pinot-core/src/test/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunctionTest.java
+++
b/pinot-core/src/test/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunctionTest.java
@@ -58,9 +58,9 @@ public class PostAggregationFunctionTest {
// ST_AsText
function = new PostAggregationFunction("ST_AsText", new
ColumnDataType[]{ColumnDataType.BYTES});
assertEquals(function.getResultType(), ColumnDataType.STRING);
- assertEquals(function.invoke(
- new
Object[]{GeometrySerializer.serialize(GeometryUtils.GEOMETRY_FACTORY.createPoint(new
Coordinate(10, 20)))}),
- "POINT (10 20)");
+ assertEquals(function.invoke(new Object[]{
+
GeometrySerializer.serialize(GeometryUtils.GEOMETRY_FACTORY.createPoint(new
Coordinate(10, 20)))
+ }), "POINT (10 20)");
// Cast
function = new PostAggregationFunction("cast", new
ColumnDataType[]{ColumnDataType.INT, ColumnDataType.STRING});
@@ -94,12 +94,30 @@ public class PostAggregationFunctionTest {
assertFalse((Boolean) function.invoke(new Object[]{"a", "b"}));
// Coalesce
- function = new PostAggregationFunction("coalesce", new
ColumnDataType[]{ColumnDataType.INT, ColumnDataType.STRING,
- ColumnDataType.BOOLEAN});
+ function = new PostAggregationFunction("coalesce", new ColumnDataType[]{
+ ColumnDataType.INT, ColumnDataType.STRING, ColumnDataType.BOOLEAN
+ });
assertEquals(function.getResultType(), ColumnDataType.OBJECT);
assertNull(function.invoke(new Object[]{null, null, null}));
assertEquals(function.invoke(new Object[]{null, "1", null}), "1");
assertEquals(function.invoke(new Object[]{1, "2", false}), 1);
assertEquals(function.invoke(new Object[]{null, null, true}), true);
+
+ // Case with a large number of when then clauses
+ function = new PostAggregationFunction("case", new ColumnDataType[]{
+ ColumnDataType.BOOLEAN, ColumnDataType.INT, ColumnDataType.BOOLEAN,
ColumnDataType.INT,
+ ColumnDataType.BOOLEAN, ColumnDataType.INT, ColumnDataType.BOOLEAN,
ColumnDataType.INT,
+ ColumnDataType.BOOLEAN, ColumnDataType.INT, ColumnDataType.BOOLEAN,
ColumnDataType.INT,
+ ColumnDataType.BOOLEAN, ColumnDataType.INT, ColumnDataType.BOOLEAN,
ColumnDataType.INT,
+ ColumnDataType.BOOLEAN, ColumnDataType.INT, ColumnDataType.BOOLEAN,
ColumnDataType.INT,
+ ColumnDataType.BOOLEAN, ColumnDataType.INT, ColumnDataType.BOOLEAN,
ColumnDataType.INT,
+ ColumnDataType.BOOLEAN, ColumnDataType.INT, ColumnDataType.BOOLEAN,
ColumnDataType.INT,
+ ColumnDataType.BOOLEAN, ColumnDataType.INT, ColumnDataType.BOOLEAN,
ColumnDataType.INT, ColumnDataType.INT
+ });
+ assertEquals(function.getResultType(), ColumnDataType.OBJECT);
+ assertEquals(function.invoke(new Object[]{
+ false, 1, false, 2, false, 3, false, 4, false, 5, false, 6, false, 7,
false, 8, false, 9, false, 10, false,
+ 11, false, 12, false, 13, false, 14, false, 15, false, 16, false, 17,
false, 18, false, 19, 20
+ }), 20);
}
}
diff --git
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java
index 7721198ecb..b58ae2dee2 100644
---
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java
+++
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java
@@ -963,6 +963,53 @@ public class MultiStageEngineIntegrationTest extends
BaseClusterIntegrationTestS
Assert.assertFalse(plan.contains("<="));
}
+ @Test
+ public void testCaseWhenWithLargeNumberOfWhenThenClauses()
+ throws Exception {
+ // This test is to verify that the case when function with a large number
of when then clauses works correctly.
+ // The test verifies both the scalar and transform function variants.
+
+ // Write the query in a way that the case when will be executed in the
intermediate stage and hence will have
+ // to use the scalar function variant instead of the transform function
variant.
+ String sqlQuery =
+ "SELECT CASE WHEN CRSArrTime > 2000 THEN 20 WHEN CRSArrTime > 1900
THEN 19 WHEN CRSArrTime > 1800 THEN 18 "
+ + "WHEN CRSArrTime > 1700 THEN 17 WHEN CRSArrTime > 1600 THEN 16
WHEN CRSArrTime > 1500 THEN 15 WHEN "
+ + "CRSArrTime > 1400 THEN 14 WHEN CRSArrTime > 1300 THEN 13 WHEN
CRSArrTime > 1200 THEN 12 WHEN "
+ + "CRSArrTime > 1100 THEN 11 WHEN CRSArrTime > 1000 THEN 10 WHEN
CRSArrTime > 900 THEN 9 WHEN "
+ + "CRSArrTime > 800 THEN 8 WHEN CRSArrTime > 700 THEN 7 WHEN
CRSArrTime > 600 THEN 6 WHEN "
+ + "CRSArrTime > 500 THEN 50 WHEN CRSArrTime > 400 THEN 4 WHEN
CRSArrTime > 300 THEN 3 WHEN "
+ + "CRSArrTime > 200 THEN 2 WHEN CRSArrTime > 100 THEN 1 ELSE 0 END
FROM (SELECT * FROM mytable ORDER BY "
+ + "CRSArrTime LIMIT 10)";
+ JsonNode jsonNode = postQuery(sqlQuery);
+ assertNoError(jsonNode);
+
Assert.assertEquals(jsonNode.get("resultTable").get("dataSchema").get("columnDataTypes").size(),
1);
+
Assert.assertEquals(jsonNode.get("resultTable").get("dataSchema").get("columnDataTypes").get(0).asText(),
"INT");
+ JsonNode rowsScalar = jsonNode.get("resultTable").get("rows");
+ assertEquals(rowsScalar.size(), 10);
+
+ // Rewrite the query in a way that the case when will be executed in the
leaf stage projection and hence will use
+ // the transform function variant
+ sqlQuery =
+ "SELECT CASE WHEN CRSArrTime > 2000 THEN 20 WHEN CRSArrTime > 1900
THEN 19 WHEN CRSArrTime > 1800 THEN 18 "
+ + "WHEN CRSArrTime > 1700 THEN 17 WHEN CRSArrTime > 1600 THEN 16
WHEN CRSArrTime > 1500 THEN 15 WHEN "
+ + "CRSArrTime > 1400 THEN 14 WHEN CRSArrTime > 1300 THEN 13 WHEN
CRSArrTime > 1200 THEN 12 WHEN "
+ + "CRSArrTime > 1100 THEN 11 WHEN CRSArrTime > 1000 THEN 10 WHEN
CRSArrTime > 900 THEN 9 WHEN "
+ + "CRSArrTime > 800 THEN 8 WHEN CRSArrTime > 700 THEN 7 WHEN
CRSArrTime > 600 THEN 6 WHEN "
+ + "CRSArrTime > 500 THEN 50 WHEN CRSArrTime > 400 THEN 4 WHEN
CRSArrTime > 300 THEN 3 WHEN "
+ + "CRSArrTime > 200 THEN 2 WHEN CRSArrTime > 100 THEN 1 ELSE 0 END
FROM mytable ORDER BY "
+ + "CRSArrTime LIMIT 10";
+ jsonNode = postQuery(sqlQuery);
+ assertNoError(jsonNode);
+
Assert.assertEquals(jsonNode.get("resultTable").get("dataSchema").get("columnDataTypes").size(),
1);
+
Assert.assertEquals(jsonNode.get("resultTable").get("dataSchema").get("columnDataTypes").get(0).asText(),
"INT");
+ JsonNode rowsTransform = jsonNode.get("resultTable").get("rows");
+ assertEquals(rowsTransform.size(), 10);
+
+ for (int i = 0; i < 10; i++) {
+ assertEquals(rowsScalar.get(i).get(0).asInt(),
rowsTransform.get(i).get(0).asInt());
+ }
+ }
+
@Test
public void testMVNumericCastInFilter() throws Exception {
String sqlQuery = "SELECT COUNT(*) FROM mytable WHERE
ARRAY_TO_MV(CAST(DivAirportIDs AS BIGINT ARRAY)) > 0";
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]