This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 5d9a794073 Support arbitrary number of WHEN THEN clauses in the scalar CASE function (#14125)
5d9a794073 is described below

commit 5d9a79407345b48cbad96d4d1f65cd0d099a9e8a
Author: Yash Mayya <[email protected]>
AuthorDate: Mon Oct 7 23:25:06 2024 +0530

    Support arbitrary number of WHEN THEN clauses in the scalar CASE function (#14125)
---
 .../common/function/scalar/ObjectFunctions.java    | 143 +--------------------
 .../pinot/sql/parsers/CalciteSqlCompilerTest.java  |  13 +-
 .../postaggregation/PostAggregationFunction.java   |  59 +++++----
 .../PostAggregationFunctionTest.java               |  28 +++-
 .../tests/MultiStageEngineIntegrationTest.java     |  47 +++++++
 5 files changed, 119 insertions(+), 171 deletions(-)

diff --git 
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ObjectFunctions.java
 
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ObjectFunctions.java
index 7b8fa82d11..6c3df77f96 100644
--- 
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ObjectFunctions.java
+++ 
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ObjectFunctions.java
@@ -93,147 +93,8 @@ public class ObjectFunctions {
     return null;
   }
 
-  @Nullable
-  @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"})
-  public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, 
@Nullable Object oe) {
-    return caseWhenVar(c1, o1, oe);
-  }
-
-  @Nullable
-  @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"})
-  public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, 
@Nullable Boolean c2, @Nullable Object o2,
-      @Nullable Object oe) {
-    return caseWhenVar(c1, o1, c2, o2, oe);
-  }
-
-  @Nullable
-  @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"})
-  public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, 
@Nullable Boolean c2, @Nullable Object o2,
-      @Nullable Boolean c3, @Nullable Object o3, @Nullable Object oe) {
-    return caseWhenVar(c1, o1, c2, o2, c3, o3, oe);
-  }
-
-  @Nullable
-  @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"})
-  public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, 
@Nullable Boolean c2, @Nullable Object o2,
-      @Nullable Boolean c3, @Nullable Object o3, @Nullable Boolean c4, 
@Nullable Object o4, @Nullable Object oe) {
-    return caseWhenVar(c1, o1, c2, o2, c3, o3, c4, o4, oe);
-  }
-
-  @Nullable
-  @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"})
-  public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, 
@Nullable Boolean c2, @Nullable Object o2,
-      @Nullable Boolean c3, @Nullable Object o3, @Nullable Boolean c4, 
@Nullable Object o4, @Nullable Boolean c5,
-      @Nullable Object o5, @Nullable Object oe) {
-    return caseWhenVar(c1, o1, c2, o2, c3, o3, c4, o4, c5, o5, oe);
-  }
-
-  @Nullable
-  @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"})
-  public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, 
@Nullable Boolean c2, @Nullable Object o2,
-      @Nullable Boolean c3, @Nullable Object o3, @Nullable Boolean c4, 
@Nullable Object o4, @Nullable Boolean c5,
-      @Nullable Object o5, @Nullable Boolean c6, @Nullable Object o6, 
@Nullable Object oe) {
-    return caseWhenVar(c1, o1, c2, o2, c3, o3, c4, o4, c5, o5, c6, o6, oe);
-  }
-
-  @Nullable
-  @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"})
-  public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, 
@Nullable Boolean c2, @Nullable Object o2,
-      @Nullable Boolean c3, @Nullable Object o3, @Nullable Boolean c4, 
@Nullable Object o4, @Nullable Boolean c5,
-      @Nullable Object o5, @Nullable Boolean c6, @Nullable Object o6, 
@Nullable Boolean c7, @Nullable Object o7,
-      @Nullable Object oe) {
-    return caseWhenVar(c1, o1, c2, o2, c3, o3, c4, o4, c5, o5, c6, o6, c7, o7, 
oe);
-  }
-
-  @Nullable
-  @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"})
-  public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, 
@Nullable Boolean c2, @Nullable Object o2,
-      @Nullable Boolean c3, @Nullable Object o3, @Nullable Boolean c4, 
@Nullable Object o4, @Nullable Boolean c5,
-      @Nullable Object o5, @Nullable Boolean c6, @Nullable Object o6, 
@Nullable Boolean c7, @Nullable Object o7,
-      @Nullable Boolean c8, @Nullable Object o8, @Nullable Object oe) {
-    return caseWhenVar(c1, o1, c2, o2, c3, o3, c4, o4, c5, o5, c6, o6, c7, o7, 
c8, o8, oe);
-  }
-
-  @Nullable
-  @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"})
-  public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, 
@Nullable Boolean c2, @Nullable Object o2,
-      @Nullable Boolean c3, @Nullable Object o3, @Nullable Boolean c4, 
@Nullable Object o4, @Nullable Boolean c5,
-      @Nullable Object o5, @Nullable Boolean c6, @Nullable Object o6, 
@Nullable Boolean c7, @Nullable Object o7,
-      @Nullable Boolean c8, @Nullable Object o8, @Nullable Boolean c9, 
@Nullable Object o9, @Nullable Object oe) {
-    return caseWhenVar(c1, o1, c2, o2, c3, o3, c4, o4, c5, o5, c6, o6, c7, o7, 
c8, o8, c9, o9, oe);
-  }
-
-  @Nullable
-  @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"})
-  public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, 
@Nullable Boolean c2, @Nullable Object o2,
-      @Nullable Boolean c3, @Nullable Object o3, @Nullable Boolean c4, 
@Nullable Object o4, @Nullable Boolean c5,
-      @Nullable Object o5, @Nullable Boolean c6, @Nullable Object o6, 
@Nullable Boolean c7, @Nullable Object o7,
-      @Nullable Boolean c8, @Nullable Object o8, @Nullable Boolean c9, 
@Nullable Object o9, @Nullable Boolean c10,
-      @Nullable Object o10, @Nullable Object oe) {
-    return caseWhenVar(c1, o1, c2, o2, c3, o3, c4, o4, c5, o5, c6, o6, c7, o7, 
c8, o8, c9, o9, c10, o10, oe);
-  }
-
-  @Nullable
-  @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"})
-  public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, 
@Nullable Boolean c2, @Nullable Object o2,
-      @Nullable Boolean c3, @Nullable Object o3, @Nullable Boolean c4, 
@Nullable Object o4, @Nullable Boolean c5,
-      @Nullable Object o5, @Nullable Boolean c6, @Nullable Object o6, 
@Nullable Boolean c7, @Nullable Object o7,
-      @Nullable Boolean c8, @Nullable Object o8, @Nullable Boolean c9, 
@Nullable Object o9, @Nullable Boolean c10,
-      @Nullable Object o10, @Nullable Boolean c11, @Nullable Object o11, 
@Nullable Object oe) {
-    return caseWhenVar(c1, o1, c2, o2, c3, o3, c4, o4, c5, o5, c6, o6, c7, o7, 
c8, o8, c9, o9, c10, o10, c11, o11, oe);
-  }
-
-  @Nullable
-  @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"})
-  public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, 
@Nullable Boolean c2, @Nullable Object o2,
-      @Nullable Boolean c3, @Nullable Object o3, @Nullable Boolean c4, 
@Nullable Object o4, @Nullable Boolean c5,
-      @Nullable Object o5, @Nullable Boolean c6, @Nullable Object o6, 
@Nullable Boolean c7, @Nullable Object o7,
-      @Nullable Boolean c8, @Nullable Object o8, @Nullable Boolean c9, 
@Nullable Object o9, @Nullable Boolean c10,
-      @Nullable Object o10, @Nullable Boolean c11, @Nullable Object o11, 
@Nullable Boolean c12, @Nullable Object o12,
-      @Nullable Object oe) {
-    return caseWhenVar(c1, o1, c2, o2, c3, o3, c4, o4, c5, o5, c6, o6, c7, o7, 
c8, o8, c9, o9, c10, o10, c11, o11, c12,
-        o12, oe);
-  }
-
-  @Nullable
-  @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"})
-  public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, 
@Nullable Boolean c2, @Nullable Object o2,
-      @Nullable Boolean c3, @Nullable Object o3, @Nullable Boolean c4, 
@Nullable Object o4, @Nullable Boolean c5,
-      @Nullable Object o5, @Nullable Boolean c6, @Nullable Object o6, 
@Nullable Boolean c7, @Nullable Object o7,
-      @Nullable Boolean c8, @Nullable Object o8, @Nullable Boolean c9, 
@Nullable Object o9, @Nullable Boolean c10,
-      @Nullable Object o10, @Nullable Boolean c11, @Nullable Object o11, 
@Nullable Boolean c12, @Nullable Object o12,
-      @Nullable Boolean c13, @Nullable Object o13, @Nullable Object oe) {
-    return caseWhenVar(c1, o1, c2, o2, c3, o3, c4, o4, c5, o5, c6, o6, c7, o7, 
c8, o8, c9, o9, c10, o10, c11, o11, c12,
-        o12, c13, o13, oe);
-  }
-
-  @Nullable
-  @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"})
-  public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, 
@Nullable Boolean c2, @Nullable Object o2,
-      @Nullable Boolean c3, @Nullable Object o3, @Nullable Boolean c4, 
@Nullable Object o4, @Nullable Boolean c5,
-      @Nullable Object o5, @Nullable Boolean c6, @Nullable Object o6, 
@Nullable Boolean c7, @Nullable Object o7,
-      @Nullable Boolean c8, @Nullable Object o8, @Nullable Boolean c9, 
@Nullable Object o9, @Nullable Boolean c10,
-      @Nullable Object o10, @Nullable Boolean c11, @Nullable Object o11, 
@Nullable Boolean c12, @Nullable Object o12,
-      @Nullable Boolean c13, @Nullable Object o13, @Nullable Boolean c14, 
@Nullable Object o14, @Nullable Object oe) {
-    return caseWhenVar(c1, o1, c2, o2, c3, o3, c4, o4, c5, o5, c6, o6, c7, o7, 
c8, o8, c9, o9, c10, o10, c11, o11, c12,
-        o12, c13, o13, c14, o14, oe);
-  }
-
-  @Nullable
-  @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"})
-  public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, 
@Nullable Boolean c2, @Nullable Object o2,
-      @Nullable Boolean c3, @Nullable Object o3, @Nullable Boolean c4, 
@Nullable Object o4, @Nullable Boolean c5,
-      @Nullable Object o5, @Nullable Boolean c6, @Nullable Object o6, 
@Nullable Boolean c7, @Nullable Object o7,
-      @Nullable Boolean c8, @Nullable Object o8, @Nullable Boolean c9, 
@Nullable Object o9, @Nullable Boolean c10,
-      @Nullable Object o10, @Nullable Boolean c11, @Nullable Object o11, 
@Nullable Boolean c12, @Nullable Object o12,
-      @Nullable Boolean c13, @Nullable Object o13, @Nullable Boolean c14, 
@Nullable Object o14, @Nullable Boolean c15,
-      @Nullable Object o15, @Nullable Object oe) {
-    return caseWhenVar(c1, o1, c2, o2, c3, o3, c4, o4, c5, o5, c6, o6, c7, o7, 
c8, o8, c9, o9, c10, o10, c11, o11, c12,
-        o12, c13, o13, c14, o14, c15, o15, oe);
-  }
-
-  @Nullable
-  private static Object caseWhenVar(Object... objs) {
+  @ScalarFunction(names = {"case", "caseWhen"}, nullableParameters = true, 
isVarArg = true)
+  public static Object caseWhen(Object... objs) {
     for (int i = 0; i < objs.length - 1; i += 2) {
       if (Boolean.TRUE.equals(objs[i])) {
         return objs[i + 1];
diff --git 
a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
 
b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
index 67940a64c0..561fc755f1 100644
--- 
a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
+++ 
b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
@@ -117,7 +117,7 @@ public class CalciteSqlCompilerTest {
   }
 
   @Test
-  public void testCaseWhenStatements() {
+  public void testCaseWhenTransformStatements() {
     //@formatter:off
     PinotQuery pinotQuery = compileToPinotQuery(
         "SELECT OrderID, Quantity,\n"
@@ -237,6 +237,17 @@ public class CalciteSqlCompilerTest {
             .getName(), "Quantity");
   }
 
+  @Test
+  public void testCaseWhenScalar() {
+    PinotQuery pinotQuery = compileToPinotQuery("SELECT CASE WHEN NOW() > 0 
THEN 1 ELSE -1 END FROM myTable");
+    Assert.assertEquals(pinotQuery.getSelectList().size(), 1);
+    Assert.assertTrue(pinotQuery.getSelectList().get(0).isSetLiteral());
+    
Assert.assertEquals(pinotQuery.getSelectList().get(0).getLiteral().getIntValue(),
 1);
+
+    Assert.assertThrows(SqlCompilationException.class,
+        () -> compileToPinotQuery("SELECT CASE WHEN 1 > 0 END FROM myTable"));
+  }
+
   @Test
   public void testQuotedStrings() {
     PinotQuery pinotQuery = compileToPinotQuery("select * from vegetables 
where origin = 'Martha''s Vineyard'");
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunction.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunction.java
index 370865227c..1ad0fdfaf4 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunction.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunction.java
@@ -20,6 +20,7 @@ package org.apache.pinot.core.query.postaggregation;
 
 import com.google.common.base.Preconditions;
 import java.util.Arrays;
+import javax.annotation.Nullable;
 import org.apache.pinot.common.function.FunctionInfo;
 import org.apache.pinot.common.function.FunctionInvoker;
 import org.apache.pinot.common.function.FunctionRegistry;
@@ -33,8 +34,9 @@ import org.apache.pinot.common.utils.PinotDataType;
  */
 public class PostAggregationFunction {
   private final FunctionInvoker _functionInvoker;
-  private final PinotDataType[] _argumentTypes;
   private final ColumnDataType _resultType;
+  @Nullable
+  private PinotDataType[] _argumentTypes;
 
   public PostAggregationFunction(String functionName, ColumnDataType[] 
argumentTypes) {
     String canonicalName = FunctionRegistry.canonicalize(functionName);
@@ -49,24 +51,27 @@ public class PostAggregationFunction {
       }
     }
     _functionInvoker = new FunctionInvoker(functionInfo);
-    Class<?>[] parameterClasses = _functionInvoker.getParameterClasses();
-    PinotDataType[] parameterTypes = _functionInvoker.getParameterTypes();
-    int numArguments = argumentTypes.length;
-    int numParameters = parameterClasses.length;
-    Preconditions.checkArgument(numArguments == numParameters,
-        "Wrong number of arguments for method: %s, expected: %s, actual: %s", 
functionInfo.getMethod(), numParameters,
-        numArguments);
-    for (int i = 0; i < numParameters; i++) {
-      Preconditions.checkArgument(parameterTypes[i] != null, "Unsupported 
parameter class: %s for method: %s",
-          parameterClasses[i], functionInfo.getMethod());
-    }
-    _argumentTypes = new PinotDataType[numArguments];
-    for (int i = 0; i < numArguments; i++) {
-      _argumentTypes[i] = 
PinotDataType.getPinotDataTypeForExecution(argumentTypes[i]);
-    }
     ColumnDataType resultType = 
FunctionUtils.getColumnDataType(_functionInvoker.getResultClass());
     // Handle unrecognized result class with STRING
     _resultType = resultType != null ? resultType : ColumnDataType.STRING;
+
+    if (!_functionInvoker.getMethod().isVarArgs()) {
+      Class<?>[] parameterClasses = _functionInvoker.getParameterClasses();
+      PinotDataType[] parameterTypes = _functionInvoker.getParameterTypes();
+      int numArguments = argumentTypes.length;
+      int numParameters = parameterClasses.length;
+      Preconditions.checkArgument(numArguments == numParameters,
+          "Wrong number of arguments for method: %s, expected: %s, actual: 
%s", functionInfo.getMethod(), numParameters,
+          numArguments);
+      for (int i = 0; i < numParameters; i++) {
+        Preconditions.checkArgument(parameterTypes[i] != null, "Unsupported 
parameter class: %s for method: %s",
+            parameterClasses[i], functionInfo.getMethod());
+      }
+      _argumentTypes = new PinotDataType[numArguments];
+      for (int i = 0; i < numArguments; i++) {
+        _argumentTypes[i] = 
PinotDataType.getPinotDataTypeForExecution(argumentTypes[i]);
+      }
+    }
   }
 
   /**
@@ -81,16 +86,22 @@ public class PostAggregationFunction {
    * NOTE: The passed in arguments could be modified during the type 
conversion.
    */
   public Object invoke(Object[] arguments) {
-    int numArguments = arguments.length;
-    PinotDataType[] parameterTypes = _functionInvoker.getParameterTypes();
-    for (int i = 0; i < numArguments; i++) {
-      PinotDataType parameterType = parameterTypes[i];
-      PinotDataType argumentType = _argumentTypes[i];
-      if (parameterType != argumentType) {
-        arguments[i] = parameterType.convert(arguments[i], argumentType);
+    Object result;
+    if (_functionInvoker.getMethod().isVarArgs()) {
+      result = _functionInvoker.invoke(new Object[]{arguments});
+    } else {
+      int numArguments = arguments.length;
+      PinotDataType[] parameterTypes = _functionInvoker.getParameterTypes();
+      for (int i = 0; i < numArguments; i++) {
+        PinotDataType parameterType = parameterTypes[i];
+        PinotDataType argumentType = _argumentTypes[i];
+        if (parameterType != argumentType) {
+          arguments[i] = parameterType.convert(arguments[i], argumentType);
+        }
       }
+      result = _functionInvoker.invoke(arguments);
     }
-    Object result = _functionInvoker.invoke(arguments);
+
     return _resultType == ColumnDataType.STRING ? result.toString() : result;
   }
 }
diff --git 
a/pinot-core/src/test/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunctionTest.java
 
b/pinot-core/src/test/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunctionTest.java
index 6f4cd02a29..01f88b1d91 100644
--- 
a/pinot-core/src/test/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunctionTest.java
+++ 
b/pinot-core/src/test/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunctionTest.java
@@ -58,9 +58,9 @@ public class PostAggregationFunctionTest {
     // ST_AsText
     function = new PostAggregationFunction("ST_AsText", new 
ColumnDataType[]{ColumnDataType.BYTES});
     assertEquals(function.getResultType(), ColumnDataType.STRING);
-    assertEquals(function.invoke(
-        new 
Object[]{GeometrySerializer.serialize(GeometryUtils.GEOMETRY_FACTORY.createPoint(new
 Coordinate(10, 20)))}),
-        "POINT (10 20)");
+    assertEquals(function.invoke(new Object[]{
+        
GeometrySerializer.serialize(GeometryUtils.GEOMETRY_FACTORY.createPoint(new 
Coordinate(10, 20)))
+    }), "POINT (10 20)");
 
     // Cast
     function = new PostAggregationFunction("cast", new 
ColumnDataType[]{ColumnDataType.INT, ColumnDataType.STRING});
@@ -94,12 +94,30 @@ public class PostAggregationFunctionTest {
     assertFalse((Boolean) function.invoke(new Object[]{"a", "b"}));
 
     // Coalesce
-    function = new PostAggregationFunction("coalesce", new 
ColumnDataType[]{ColumnDataType.INT, ColumnDataType.STRING,
-    ColumnDataType.BOOLEAN});
+    function = new PostAggregationFunction("coalesce", new ColumnDataType[]{
+        ColumnDataType.INT, ColumnDataType.STRING, ColumnDataType.BOOLEAN
+    });
     assertEquals(function.getResultType(), ColumnDataType.OBJECT);
     assertNull(function.invoke(new Object[]{null, null, null}));
     assertEquals(function.invoke(new Object[]{null, "1", null}), "1");
     assertEquals(function.invoke(new Object[]{1, "2", false}), 1);
     assertEquals(function.invoke(new Object[]{null, null, true}), true);
+
+    // Case with a large number of when then clauses
+    function = new PostAggregationFunction("case", new ColumnDataType[]{
+        ColumnDataType.BOOLEAN, ColumnDataType.INT, ColumnDataType.BOOLEAN, 
ColumnDataType.INT,
+        ColumnDataType.BOOLEAN, ColumnDataType.INT, ColumnDataType.BOOLEAN, 
ColumnDataType.INT,
+        ColumnDataType.BOOLEAN, ColumnDataType.INT, ColumnDataType.BOOLEAN, 
ColumnDataType.INT,
+        ColumnDataType.BOOLEAN, ColumnDataType.INT, ColumnDataType.BOOLEAN, 
ColumnDataType.INT,
+        ColumnDataType.BOOLEAN, ColumnDataType.INT, ColumnDataType.BOOLEAN, 
ColumnDataType.INT,
+        ColumnDataType.BOOLEAN, ColumnDataType.INT, ColumnDataType.BOOLEAN, 
ColumnDataType.INT,
+        ColumnDataType.BOOLEAN, ColumnDataType.INT, ColumnDataType.BOOLEAN, 
ColumnDataType.INT,
+        ColumnDataType.BOOLEAN, ColumnDataType.INT, ColumnDataType.BOOLEAN, 
ColumnDataType.INT, ColumnDataType.INT
+    });
+    assertEquals(function.getResultType(), ColumnDataType.OBJECT);
+    assertEquals(function.invoke(new Object[]{
+        false, 1, false, 2, false, 3, false, 4, false, 5, false, 6, false, 7, 
false, 8, false, 9, false, 10, false,
+        11, false, 12, false, 13, false, 14, false, 15, false, 16, false, 17, 
false, 18, false, 19, 20
+    }), 20);
   }
 }
diff --git 
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java
 
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java
index 7721198ecb..b58ae2dee2 100644
--- 
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java
+++ 
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java
@@ -963,6 +963,53 @@ public class MultiStageEngineIntegrationTest extends 
BaseClusterIntegrationTestS
     Assert.assertFalse(plan.contains("<="));
   }
 
+  @Test
+  public void testCaseWhenWithLargeNumberOfWhenThenClauses()
+      throws Exception {
+    // This test is to verify that the case when function with a large number 
of when then clauses works correctly.
+    // The test verifies both the scalar and transform function variants.
+
+    // Write the query in a way that the case when will be executed in the 
intermediate stage and hence will have
+    // to use the scalar function variant instead of the transform function 
variant.
+    String sqlQuery =
+        "SELECT CASE WHEN CRSArrTime > 2000 THEN 20 WHEN CRSArrTime > 1900 
THEN 19 WHEN CRSArrTime > 1800 THEN 18 "
+            + "WHEN CRSArrTime > 1700 THEN 17 WHEN CRSArrTime > 1600 THEN 16 
WHEN CRSArrTime > 1500 THEN 15 WHEN "
+            + "CRSArrTime > 1400 THEN 14 WHEN CRSArrTime > 1300 THEN 13 WHEN 
CRSArrTime > 1200 THEN 12 WHEN "
+            + "CRSArrTime > 1100 THEN 11 WHEN CRSArrTime > 1000 THEN 10 WHEN 
CRSArrTime > 900 THEN 9 WHEN "
+            + "CRSArrTime > 800 THEN 8 WHEN CRSArrTime > 700 THEN 7 WHEN 
CRSArrTime > 600 THEN 6 WHEN "
+            + "CRSArrTime > 500 THEN 50 WHEN CRSArrTime > 400 THEN 4 WHEN 
CRSArrTime > 300 THEN 3 WHEN "
+            + "CRSArrTime > 200 THEN 2 WHEN CRSArrTime > 100 THEN 1 ELSE 0 END 
FROM (SELECT * FROM mytable ORDER BY "
+            + "CRSArrTime LIMIT 10)";
+    JsonNode jsonNode = postQuery(sqlQuery);
+    assertNoError(jsonNode);
+    
Assert.assertEquals(jsonNode.get("resultTable").get("dataSchema").get("columnDataTypes").size(),
 1);
+    
Assert.assertEquals(jsonNode.get("resultTable").get("dataSchema").get("columnDataTypes").get(0).asText(),
 "INT");
+    JsonNode rowsScalar = jsonNode.get("resultTable").get("rows");
+    assertEquals(rowsScalar.size(), 10);
+
+    // Rewrite the query in a way that the case when will be executed in the 
leaf stage projection and hence will use
+    // the transform function variant
+    sqlQuery =
+        "SELECT CASE WHEN CRSArrTime > 2000 THEN 20 WHEN CRSArrTime > 1900 
THEN 19 WHEN CRSArrTime > 1800 THEN 18 "
+            + "WHEN CRSArrTime > 1700 THEN 17 WHEN CRSArrTime > 1600 THEN 16 
WHEN CRSArrTime > 1500 THEN 15 WHEN "
+            + "CRSArrTime > 1400 THEN 14 WHEN CRSArrTime > 1300 THEN 13 WHEN 
CRSArrTime > 1200 THEN 12 WHEN "
+            + "CRSArrTime > 1100 THEN 11 WHEN CRSArrTime > 1000 THEN 10 WHEN 
CRSArrTime > 900 THEN 9 WHEN "
+            + "CRSArrTime > 800 THEN 8 WHEN CRSArrTime > 700 THEN 7 WHEN 
CRSArrTime > 600 THEN 6 WHEN "
+            + "CRSArrTime > 500 THEN 50 WHEN CRSArrTime > 400 THEN 4 WHEN 
CRSArrTime > 300 THEN 3 WHEN "
+            + "CRSArrTime > 200 THEN 2 WHEN CRSArrTime > 100 THEN 1 ELSE 0 END 
FROM mytable ORDER BY "
+            + "CRSArrTime LIMIT 10";
+    jsonNode = postQuery(sqlQuery);
+    assertNoError(jsonNode);
+    
Assert.assertEquals(jsonNode.get("resultTable").get("dataSchema").get("columnDataTypes").size(),
 1);
+    
Assert.assertEquals(jsonNode.get("resultTable").get("dataSchema").get("columnDataTypes").get(0).asText(),
 "INT");
+    JsonNode rowsTransform = jsonNode.get("resultTable").get("rows");
+    assertEquals(rowsTransform.size(), 10);
+
+    for (int i = 0; i < 10; i++) {
+      assertEquals(rowsScalar.get(i).get(0).asInt(), 
rowsTransform.get(i).get(0).asInt());
+    }
+  }
+
   @Test
   public void testMVNumericCastInFilter() throws Exception {
     String sqlQuery = "SELECT COUNT(*) FROM mytable WHERE 
ARRAY_TO_MV(CAST(DivAirportIDs AS BIGINT ARRAY)) > 0";


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to