This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 8334add327 Fix incorrect semantics for BETWEEN on MV columns in the multi-stage query engine (#14135)
8334add327 is described below

commit 8334add327fc497689d0cc2c9d4371ecef4c17e5
Author: Yash Mayya <[email protected]>
AuthorDate: Sat Oct 5 12:04:44 2024 +0530

    Fix incorrect semantics for BETWEEN on MV columns in the multi-stage query engine (#14135)
---
 .../tests/MultiStageEngineIntegrationTest.java     | 77 ++++++++++++++++++++--
 .../calcite/sql2rel/PinotConvertletTable.java      | 52 +++++++++++++++
 2 files changed, 124 insertions(+), 5 deletions(-)

diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java
index ca9a20288b..7721198ecb 100644
--- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java
+++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java
@@ -664,8 +664,8 @@ public class MultiStageEngineIntegrationTest extends BaseClusterIntegrationTestS
   @Test
   public void testMultiValueColumnGroupBy()
       throws Exception {
-    String pinotQuery = "SELECT count(*), arrayToMV(RandomAirports) FROM 
mytable "
-        + "GROUP BY arrayToMV(RandomAirports)";
+    String pinotQuery = "SELECT count(*), ARRAY_TO_MV(RandomAirports) FROM 
mytable "
+        + "GROUP BY ARRAY_TO_MV(RandomAirports)";
     JsonNode jsonNode = postQuery(pinotQuery);
     Assert.assertEquals(jsonNode.get("resultTable").get("rows").size(), 154);
   }
@@ -800,8 +800,8 @@ public class MultiStageEngineIntegrationTest extends BaseClusterIntegrationTestS
   public void testMultiValueColumnGroupByOrderBy()
       throws Exception {
     String pinotQuery =
-        "SELECT count(*), arrayToMV(RandomAirports) FROM mytable " + "GROUP BY 
arrayToMV(RandomAirports) "
-            + "ORDER BY arrayToMV(RandomAirports) DESC";
+        "SELECT count(*), ARRAY_TO_MV(RandomAirports) FROM mytable GROUP BY 
ARRAY_TO_MV(RandomAirports) "
+            + "ORDER BY ARRAY_TO_MV(RandomAirports) DESC";
     JsonNode jsonNode = postQuery(pinotQuery);
     Assert.assertEquals(jsonNode.get("resultTable").get("rows").size(), 154);
   }
@@ -896,9 +896,76 @@ public class MultiStageEngineIntegrationTest extends BaseClusterIntegrationTestS
     assertNoError(jsonNode);
   }
 
+  @Test
+  public void testBetween()
+      throws Exception {
+    String sqlQuery = "SELECT COUNT(*) FROM mytable WHERE ArrDelay BETWEEN 10 
AND 50";
+    JsonNode jsonNode = postQuery(sqlQuery);
+    assertNoError(jsonNode);
+    
assertEquals(jsonNode.get("resultTable").get("rows").get(0).get(0).asInt(), 
18572);
+
+    String explainQuery = "EXPLAIN PLAN FOR " + sqlQuery;
+    jsonNode = postQuery(explainQuery);
+    assertNoError(jsonNode);
+    String plan = 
jsonNode.get("resultTable").get("rows").get(0).get(1).asText();
+    // Ensure that the BETWEEN filter predicate was converted to >= and <=
+    Assert.assertFalse(plan.contains("BETWEEN"));
+    Assert.assertTrue(plan.contains(">="));
+    Assert.assertTrue(plan.contains("<="));
+
+    // No rows should be returned since lower bound is greater than upper bound
+    sqlQuery = "SELECT COUNT(*) FROM mytable WHERE ARRAY_TO_MV(RandomAirports) 
BETWEEN 'SUN' AND 'GTR'";
+    jsonNode = postQuery(sqlQuery);
+    assertNoError(jsonNode);
+    
assertEquals(jsonNode.get("resultTable").get("rows").get(0).get(0).asInt(), 0);
+
+    explainQuery = "EXPLAIN PLAN FOR " + sqlQuery;
+    jsonNode = postQuery(explainQuery);
+    assertNoError(jsonNode);
+    plan = jsonNode.get("resultTable").get("rows").get(0).get(1).asText();
+    // Ensure that the BETWEEN filter predicate was not converted to >= and <=
+    Assert.assertTrue(plan.contains("BETWEEN"));
+    Assert.assertFalse(plan.contains(">="));
+    Assert.assertFalse(plan.contains("<="));
+
+    // Expect a non-zero result this time since we're using BETWEEN SYMMETRIC
+    sqlQuery = "SELECT COUNT(*) FROM mytable WHERE ARRAY_TO_MV(RandomAirports) 
BETWEEN SYMMETRIC 'SUN' AND 'GTR'";
+    jsonNode = postQuery(sqlQuery);
+    assertNoError(jsonNode);
+    
assertEquals(jsonNode.get("resultTable").get("rows").get(0).get(0).asInt(), 
57007);
+
+    explainQuery = "EXPLAIN PLAN FOR " + sqlQuery;
+    jsonNode = postQuery(explainQuery);
+    assertNoError(jsonNode);
+    plan = jsonNode.get("resultTable").get("rows").get(0).get(1).asText();
+    // Ensure that the BETWEEN filter predicate was not converted to >= and <=
+    Assert.assertTrue(plan.contains("BETWEEN"));
+    Assert.assertFalse(plan.contains(">="));
+    Assert.assertFalse(plan.contains("<="));
+
+    // Test NOT BETWEEN
+    sqlQuery = "SELECT COUNT(*) FROM mytable WHERE ARRAY_TO_MV(RandomAirports) 
NOT BETWEEN 'GTR' AND 'SUN'";
+    jsonNode = postQuery(sqlQuery);
+    assertNoError(jsonNode);
+    
assertEquals(jsonNode.get("resultTable").get("rows").get(0).get(0).asInt(), 
58538);
+
+    explainQuery =
+        "SET " + 
CommonConstants.Broker.Request.QueryOptionKey.EXPLAIN_ASKING_SERVERS + "=true; 
EXPLAIN PLAN FOR "
+            + sqlQuery;
+    jsonNode = postQuery(explainQuery);
+    assertNoError(jsonNode);
+    plan = jsonNode.get("resultTable").get("rows").get(0).get(1).asText();
+    // Ensure that the BETWEEN filter predicate was not converted to >= and 
<=. Also ensure that the NOT filter is
+    // added.
+    Assert.assertTrue(plan.contains("BETWEEN"));
+    Assert.assertTrue(plan.contains("FilterNot"));
+    Assert.assertFalse(plan.contains(">="));
+    Assert.assertFalse(plan.contains("<="));
+  }
+
   @Test
   public void testMVNumericCastInFilter() throws Exception {
-    String sqlQuery = "SELECT COUNT(*) FROM mytable WHERE 
arrayToMV(CAST(DivAirportIDs AS BIGINT ARRAY)) > 0";
+    String sqlQuery = "SELECT COUNT(*) FROM mytable WHERE 
ARRAY_TO_MV(CAST(DivAirportIDs AS BIGINT ARRAY)) > 0";
     JsonNode jsonNode = postQuery(sqlQuery);
     assertNoError(jsonNode);
     
assertEquals(jsonNode.get("resultTable").get("rows").get(0).get(0).asInt(), 
15482);
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql2rel/PinotConvertletTable.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql2rel/PinotConvertletTable.java
index f8e8545530..392e0a3c46 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql2rel/PinotConvertletTable.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql2rel/PinotConvertletTable.java
@@ -23,11 +23,13 @@ import javax.annotation.Nullable;
 import org.apache.calcite.rex.RexBuilder;
 import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.sql.SqlCall;
+import org.apache.calcite.sql.fun.SqlBetweenOperator;
 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
 import org.apache.calcite.sql2rel.SqlRexContext;
 import org.apache.calcite.sql2rel.SqlRexConvertlet;
 import org.apache.calcite.sql2rel.SqlRexConvertletTable;
 import org.apache.calcite.sql2rel.StandardConvertletTable;
+import org.apache.calcite.util.Litmus;
 
 
 /**
@@ -37,6 +39,13 @@ import org.apache.calcite.sql2rel.StandardConvertletTable;
 public class PinotConvertletTable implements SqlRexConvertletTable {
 
   public static final PinotConvertletTable INSTANCE = new 
PinotConvertletTable();
+  private static final SqlBetweenOperator PINOT_BETWEEN =
+      new SqlBetweenOperator(SqlBetweenOperator.Flag.ASYMMETRIC, false) {
+        @Override
+        public boolean validRexOperands(int count, Litmus litmus) {
+          return litmus.succeed();
+        }
+      };
 
   private PinotConvertletTable() {
   }
@@ -49,6 +58,8 @@ public class PinotConvertletTable implements SqlRexConvertletTable {
         return TimestampAddConvertlet.INSTANCE;
       case TIMESTAMP_DIFF:
         return TimestampDiffConvertlet.INSTANCE;
+      case BETWEEN:
+        return BetweenConvertlet.INSTANCE;
       default:
         return StandardConvertletTable.INSTANCE.get(call);
     }
@@ -85,4 +96,45 @@ public class PinotConvertletTable implements SqlRexConvertletTable {
               cx.convertExpression(call.operand(2))));
     }
   }
+
+  /**
+   * Override the standard convertlet for BETWEEN to avoid the rewrite to >= 
AND <= for MV columns since that breaks
+   * the filter predicate's semantics.
+   */
+  private static class BetweenConvertlet implements SqlRexConvertlet {
+    private static final BetweenConvertlet INSTANCE = new BetweenConvertlet();
+
+    @Override
+    public RexNode convertCall(SqlRexContext cx, SqlCall call) {
+      if (call.operand(0) instanceof SqlCall && ((SqlCall) 
call.operand(0)).getOperator().getName()
+          .equals("ARRAY_TO_MV")) {
+        RexBuilder rexBuilder = cx.getRexBuilder();
+
+        SqlBetweenOperator betweenOperator = (SqlBetweenOperator) 
call.getOperator();
+
+        RexNode rexNode = 
rexBuilder.makeCall(cx.getValidator().getValidatedNodeType(call), PINOT_BETWEEN,
+            List.of(cx.convertExpression(call.operand(0)), 
cx.convertExpression(call.operand(1)),
+                cx.convertExpression(call.operand(2))));
+
+        // Since Pinot only has support for ASYMMETRIC BETWEEN, we need to 
rewrite SYMMETRIC BETWEEN, ASYMMETRIC NOT
+        // BETWEEN, and SYMMETRIC NOT BETWEEN to the equivalent BETWEEN 
expressions.
+
+        // (val BETWEEN SYMMETRIC x AND y) is equivalent to (val BETWEEN x AND 
y OR val BETWEEN y AND x)
+        if (betweenOperator.flag == SqlBetweenOperator.Flag.SYMMETRIC) {
+          RexNode flipped = 
rexBuilder.makeCall(cx.getValidator().getValidatedNodeType(call), PINOT_BETWEEN,
+              List.of(cx.convertExpression(call.operand(0)), 
cx.convertExpression(call.operand(2)),
+                  cx.convertExpression(call.operand(1))));
+          rexNode = rexBuilder.makeCall(SqlStdOperatorTable.OR, rexNode, 
flipped);
+        }
+
+        if (betweenOperator.isNegated()) {
+          rexNode = rexBuilder.makeCall(SqlStdOperatorTable.NOT, rexNode);
+        }
+
+        return rexNode;
+      } else {
+        return StandardConvertletTable.INSTANCE.convertBetween(cx, 
(SqlBetweenOperator) call.getOperator(), call);
+      }
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to