This is an automated email from the ASF dual-hosted git repository.

kgyrtkirk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new 73846cceb1 [CALCITE-6435] SqlToRel conversion of IN expressions may 
lead to incorrect simplifications
73846cceb1 is described below

commit 73846cceb1841a3c7f7ab5cbe3c40611db07c148
Author: Zoltan Haindrich <[email protected]>
AuthorDate: Wed Jun 26 08:40:09 2024 +0000

    [CALCITE-6435] SqlToRel conversion of IN expressions may lead to incorrect 
simplifications
    
    Conversion path for comparisons generated from IN expressions was handling 
types differently.
    This may have led to some over-simplification in some cases.
    
    Altered the conversion to do the full SqlToRex conversion steps for these 
generated nodes as well.
    Added an extra safeguard check to RexSimplify to prevent the bug from being 
triggered.
---
 .../java/org/apache/calcite/rex/RexSimplify.java   |  7 ++-
 .../apache/calcite/sql2rel/SqlToRelConverter.java  | 53 +++++++++-------------
 .../org/apache/calcite/test/RelOptRulesTest.java   |  9 ++++
 .../org/apache/calcite/test/RelOptRulesTest.xml    | 13 ++++++
 .../apache/calcite/test/SqlToRelConverterTest.xml  |  8 ++--
 core/src/test/resources/sql/sub-query.iq           | 33 ++++++++++++--
 .../org/apache/calcite/test/DruidAdapter2IT.java   |  4 +-
 .../org/apache/calcite/test/DruidAdapterIT.java    |  4 +-
 8 files changed, 87 insertions(+), 44 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/rex/RexSimplify.java 
b/core/src/main/java/org/apache/calcite/rex/RexSimplify.java
index 2584a53d9e..1f6733ba2b 100644
--- a/core/src/main/java/org/apache/calcite/rex/RexSimplify.java
+++ b/core/src/main/java/org/apache/calcite/rex/RexSimplify.java
@@ -1682,7 +1682,10 @@ public class RexSimplify {
             final RexLiteral literal = comparison.literal;
             final RexLiteral prevLiteral =
                 equalityConstantTerms.put(comparison.ref, literal);
-            if (prevLiteral != null && !literal.equals(prevLiteral)) {
+
+            if (prevLiteral != null
+                && literal.getType().equals(prevLiteral.getType())
+                && !literal.equals(prevLiteral)) {
               return rexBuilder.makeLiteral(false);
             }
           } else if (RexUtil.isReferenceOrAccess(left, true)
@@ -1753,7 +1756,7 @@ public class RexSimplify {
         if (literal2 == null) {
           continue;
         }
-        if (!literal1.equals(literal2)) {
+        if (literal1.getType().equals(literal2.getType()) && 
!literal1.equals(literal2)) {
           // If an expression is equal to two different constants,
           // it is not satisfiable
           return rexBuilder.makeLiteral(false);
diff --git 
a/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java 
b/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java
index b29bf3884c..80942fbf07 100644
--- a/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java
+++ b/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java
@@ -176,6 +176,8 @@ import org.apache.calcite.sql.validate.SqlValidatorScope;
 import org.apache.calcite.sql.validate.SqlValidatorTable;
 import org.apache.calcite.sql.validate.SqlValidatorUtil;
 import org.apache.calcite.sql.validate.SqlWithItemTableRef;
+import org.apache.calcite.sql2rel.SqlToRelConverter.Blackboard;
+import org.apache.calcite.sql2rel.SqlToRelConverter.SqlIdentifierFinder;
 import org.apache.calcite.tools.RelBuilder;
 import org.apache.calcite.tools.RelBuilderFactory;
 import org.apache.calcite.util.ImmutableBitSet;
@@ -1185,16 +1187,16 @@ public class SqlToRelConverter {
       }
       final SqlNode leftKeyNode = call.operand(0);
 
-      final List<RexNode> leftKeys;
+      final List<SqlNode> leftSqlKeys;
       switch (leftKeyNode.getKind()) {
       case ROW:
-        leftKeys = new ArrayList<>();
+        leftSqlKeys = new ArrayList<>();
         for (SqlNode sqlExpr : ((SqlBasicCall) leftKeyNode).getOperandList()) {
-          leftKeys.add(bb.convertExpression(sqlExpr));
+          leftSqlKeys.add(sqlExpr);
         }
         break;
       default:
-        leftKeys = ImmutableList.of(bb.convertExpression(leftKeyNode));
+        leftSqlKeys = ImmutableList.of(leftKeyNode);
       }
 
       if (query instanceof SqlNodeList) {
@@ -1205,7 +1207,7 @@ public class SqlToRelConverter {
           subQuery.expr =
               convertInToOr(
                   bb,
-                  leftKeys,
+                  leftSqlKeys,
                   valueList,
                   (SqlInOperator) call.getOperator());
           return;
@@ -1216,6 +1218,10 @@ public class SqlToRelConverter {
         // reference to Q below.
       }
 
+      final List<RexNode> leftKeys = leftSqlKeys.stream()
+          .map(bb::convertExpression)
+          .collect(toImmutableList());
+
       // Project out the search columns from the left side
 
       // Q1:
@@ -1719,12 +1725,11 @@ public class SqlToRelConverter {
    */
   private @Nullable RexNode convertInToOr(
       final Blackboard bb,
-      final List<RexNode> leftKeys,
+      final List<SqlNode> leftKeys,
       SqlNodeList valuesList,
       SqlInOperator op) {
     final List<RexNode> comparisons = new ArrayList<>();
     for (SqlNode rightVals : valuesList) {
-      RexNode rexComparison;
       final SqlOperator comparisonOp;
       if (op instanceof SqlQuantifyOperator) {
         comparisonOp =
@@ -1733,25 +1738,23 @@ public class SqlToRelConverter {
       } else {
         comparisonOp = SqlStdOperatorTable.EQUALS;
       }
+      RexNode rexComparison;
       if (leftKeys.size() == 1) {
-        rexComparison =
-            rexBuilder.makeCall(comparisonOp,
-                leftKeys.get(0),
-                ensureSqlType(leftKeys.get(0).getType(),
-                    bb.convertExpression(rightVals)));
+        SqlCall sqlCall =
+            comparisonOp.createCall(rightVals.getParserPosition(), 
leftKeys.get(0), rightVals);
+        rexComparison = bb.convertExpression(sqlCall);
       } else {
         assert rightVals instanceof SqlCall;
         final SqlBasicCall call = (SqlBasicCall) rightVals;
         assert (call.getOperator() instanceof SqlRowOperator)
             && call.operandCount() == leftKeys.size();
         rexComparison =
-            RexUtil.composeConjunction(rexBuilder,
-                Util.transform(
-                    Pair.zip(leftKeys, call.getOperandList()),
-                    pair -> rexBuilder.makeCall(comparisonOp, pair.left,
-                        // TODO: remove requireNonNull when checkerframework 
issue resolved
-                        ensureSqlType(requireNonNull(pair.left, 
"pair.left").getType(),
-                            bb.convertExpression(pair.right)))));
+            RexUtil.composeConjunction(
+              rexBuilder, Util.transform(
+                  Pair.zip(leftKeys, call.getOperandList()),
+                  pair -> bb.convertExpression(
+                      comparisonOp.createCall(
+                        rightVals.getParserPosition(), pair.left, 
pair.right))));
       }
       comparisons.add(rexComparison);
     }
@@ -1770,18 +1773,6 @@ public class SqlToRelConverter {
     }
   }
 
-  /** Ensures that an expression has a given {@link SqlTypeName}, applying a
-   * cast if necessary. If the expression already has the right type family,
-   * returns the expression unchanged. */
-  private RexNode ensureSqlType(RelDataType type, RexNode node) {
-    if (type.getSqlTypeName() == node.getType().getSqlTypeName()
-        || (type.getSqlTypeName() == SqlTypeName.VARCHAR
-            && node.getType().getSqlTypeName() == SqlTypeName.CHAR)) {
-      return node;
-    }
-    return rexBuilder.ensureType(type, node, true);
-  }
-
   /**
    * Gets the list size threshold under which {@link #convertInToOr} is used.
    * Lists of this size or greater will instead be converted to use a join
diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java 
b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
index 101b8bbde1..bae5c31b69 100644
--- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
@@ -1662,6 +1662,15 @@ class RelOptRulesTest extends RelOptTestBase {
         .check();
   }
 
+  @Test void testIncorrectInType() {
+    final String sql = "select ename from emp "
+        + "  where ename in ( 'Sebastian' ) and ename = 'Sebastian' and deptno 
< 100";
+    sql(sql)
+        .withTrim(true)
+        .withRule()
+        .checkUnchanged();
+  }
+
   @Test void testSemiJoinRule() {
     final String sql = "select dept.* from dept join (\n"
         + "  select distinct deptno from emp\n"
diff --git 
a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml 
b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index 3611a6a08e..8482cc7083 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -4950,6 +4950,19 @@ LogicalUnion(all=[true])
     LogicalTableScan(table=[[CATALOG, SALES, EMP]])
   LogicalProject(EXPR$0=[LOWER($1)])
     LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testIncorrectInType">
+    <Resource name="sql">
+      <![CDATA[select ename from emp   where ename in ( 'Sebastian' ) and 
ename = 'Sebastian' and deptno < 100]]>
+    </Resource>
+    <Resource name="planBefore">
+      <![CDATA[
+LogicalProject(ENAME=[$0])
+  LogicalFilter(condition=[AND(=($0, 'Sebastian'), <($1, 100))])
+    LogicalProject(ENAME=[$1], DEPTNO=[$7])
+      LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
     </Resource>
   </TestCase>
diff --git 
a/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml 
b/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml
index 43f8ae0506..22475fff2e 100644
--- a/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml
@@ -272,7 +272,7 @@ GROUP by deptno, job]]>
     </Resource>
     <Resource name="plan">
       <![CDATA[
-LogicalProject(JOB_NAME=[CASE(SEARCH($1, Sarg['810000', '820000']:CHAR(6)), 
$1, 'error':VARCHAR(10))], EXPR$1=[$2])
+LogicalProject(JOB_NAME=[CASE(SEARCH($1, Sarg['810000':VARCHAR(10), 
'820000':VARCHAR(10)]:VARCHAR(10)), $1, 'error':VARCHAR(10))], EXPR$1=[$2])
   LogicalAggregate(group=[{0, 1}], EXPR$1=[COUNT()])
     LogicalProject(DEPTNO=[$7], JOB=[$2], EMPNO=[$0])
       LogicalFilter(condition=[OR(<>($2, ''), =($2, '810000'), =($2, 
'820000'))])
@@ -561,7 +561,7 @@ GROUP BY GROUPING SETS ((empno, derived_col),(empno))]]>
     <Resource name="plan">
       <![CDATA[
 LogicalAggregate(group=[{0, 1}], groups=[[{0, 1}, {0}]])
-  LogicalProject(EMPNO=[$0], DERIVED_COL=[CASE(SEARCH($1, Sarg['Eric', 
'Fred']:CHAR(4)), 'CEO  ', 'Other')])
+  LogicalProject(EMPNO=[$0], DERIVED_COL=[CASE(SEARCH($1, 
Sarg['Eric':VARCHAR(20), 'Fred':VARCHAR(20)]:VARCHAR(20)), 'CEO  ', 'Other')])
     LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
     </Resource>
@@ -579,7 +579,7 @@ GROUP BY GROUPING SETS (
     <Resource name="plan">
       <![CDATA[
 LogicalAggregate(group=[{0, 1}], groups=[[{0, 1}, {0}]])
-  LogicalProject(EMPNO=[$0], EXPR$1=[CASE(SEARCH($1, Sarg['Eric', 
'Fred']:CHAR(4)), 'Manager', 'Other  ')])
+  LogicalProject(EMPNO=[$0], EXPR$1=[CASE(SEARCH($1, Sarg['Eric':VARCHAR(20), 
'Fred':VARCHAR(20)]:VARCHAR(20)), 'Manager', 'Other  ')])
     LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
     </Resource>
@@ -2240,7 +2240,7 @@ group by case when coalesce(ename, 'a') in ('1', '2') 
then 'CKA' else 'QT' END]]
     <Resource name="plan">
       <![CDATA[
 LogicalAggregate(group=[{0}], EXPR$1=[COUNT(DISTINCT $1)])
-  LogicalProject(EXPR$0=[CASE(SEARCH($1, Sarg['1', '2']:CHAR(1)), 'CKA', 'QT 
')], DEPTNO=[$7])
+  LogicalProject(EXPR$0=[CASE(SEARCH($1, Sarg['1':VARCHAR(20), 
'2':VARCHAR(20)]:VARCHAR(20)), 'CKA', 'QT ')], DEPTNO=[$7])
     LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
     </Resource>
diff --git a/core/src/test/resources/sql/sub-query.iq 
b/core/src/test/resources/sql/sub-query.iq
index 16cf991f13..69b05cad8d 100644
--- a/core/src/test/resources/sql/sub-query.iq
+++ b/core/src/test/resources/sql/sub-query.iq
@@ -3149,7 +3149,7 @@ select * from "scott".emp where comm in (300, 500, null);
 
 !ok
 
-EnumerableCalc(expr#0..7=[{inputs}], expr#8=[Sarg[300:DECIMAL(7, 2), 
500:DECIMAL(7, 2)]:DECIMAL(7, 2)], expr#9=[SEARCH($t6, $t8)], 
proj#0..7=[{exprs}], $condition=[$t9])
+EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t6):DECIMAL(12, 2)], 
expr#9=[Sarg[300:DECIMAL(12, 2), 500:DECIMAL(12, 2)]:DECIMAL(12, 2)], 
expr#10=[SEARCH($t8, $t9)], proj#0..7=[{exprs}], $condition=[$t10])
   EnumerableTableScan(table=[[scott, EMP]])
 !plan
 
@@ -3177,7 +3177,7 @@ select *, comm in (300, 500, null) as i from "scott".emp;
 
 !ok
 
-EnumerableCalc(expr#0..7=[{inputs}], expr#8=[Sarg[300:DECIMAL(7, 2), 
500:DECIMAL(7, 2)]:DECIMAL(7, 2)], expr#9=[SEARCH($t6, $t8)], 
expr#10=[null:BOOLEAN], expr#11=[OR($t9, $t10)], proj#0..7=[{exprs}], I=[$t11])
+EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t6):DECIMAL(12, 2)], 
expr#9=[Sarg[300:DECIMAL(12, 2), 500:DECIMAL(12, 2)]:DECIMAL(12, 2)], 
expr#10=[SEARCH($t8, $t9)], expr#11=[null:BOOLEAN], expr#12=[OR($t10, $t11)], 
proj#0..7=[{exprs}], I=[$t12])
   EnumerableTableScan(table=[[scott, EMP]])
 !plan
 
@@ -3218,7 +3218,34 @@ select *, comm not in (300, 500, null) as i from 
"scott".emp;
 
 !ok
 
-EnumerableCalc(expr#0..7=[{inputs}], expr#8=[Sarg[(-∞..300:DECIMAL(7, 2)), 
(300:DECIMAL(7, 2)..500:DECIMAL(7, 2)), (500:DECIMAL(7, 2)..+∞)]:DECIMAL(7, 
2)], expr#9=[SEARCH($t6, $t8)], expr#10=[null:BOOLEAN], expr#11=[AND($t9, 
$t10)], proj#0..7=[{exprs}], I=[$t11])
+EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t6):DECIMAL(12, 2)], 
expr#9=[Sarg[(-∞..300:DECIMAL(12, 2)), (300:DECIMAL(12, 2)..500:DECIMAL(12, 
2)), (500:DECIMAL(12, 2)..+∞)]:DECIMAL(12, 2)], expr#10=[SEARCH($t8, $t9)], 
expr#11=[null:BOOLEAN], expr#12=[AND($t10, $t11)], proj#0..7=[{exprs}], 
I=[$t12])
+  EnumerableTableScan(table=[[scott, EMP]])
+!plan
+
+# Previous NOT IN expressions in conjunction form
+select *, (comm <> 300 and comm <> 500 and comm <> null) as i from "scott".emp;
++-------+--------+-----------+------+------------+---------+---------+--------+-------+
+| EMPNO | ENAME  | JOB       | MGR  | HIREDATE   | SAL     | COMM    | DEPTNO 
| I     |
++-------+--------+-----------+------+------------+---------+---------+--------+-------+
+|  7369 | SMITH  | CLERK     | 7902 | 1980-12-17 |  800.00 |         |     20 
|       |
+|  7499 | ALLEN  | SALESMAN  | 7698 | 1981-02-20 | 1600.00 |  300.00 |     30 
| false |
+|  7521 | WARD   | SALESMAN  | 7698 | 1981-02-22 | 1250.00 |  500.00 |     30 
| false |
+|  7566 | JONES  | MANAGER   | 7839 | 1981-02-04 | 2975.00 |         |     20 
|       |
+|  7654 | MARTIN | SALESMAN  | 7698 | 1981-09-28 | 1250.00 | 1400.00 |     30 
|       |
+|  7698 | BLAKE  | MANAGER   | 7839 | 1981-01-05 | 2850.00 |         |     30 
|       |
+|  7782 | CLARK  | MANAGER   | 7839 | 1981-06-09 | 2450.00 |         |     10 
|       |
+|  7788 | SCOTT  | ANALYST   | 7566 | 1987-04-19 | 3000.00 |         |     20 
|       |
+|  7839 | KING   | PRESIDENT |      | 1981-11-17 | 5000.00 |         |     10 
|       |
+|  7844 | TURNER | SALESMAN  | 7698 | 1981-09-08 | 1500.00 |    0.00 |     30 
|       |
+|  7876 | ADAMS  | CLERK     | 7788 | 1987-05-23 | 1100.00 |         |     20 
|       |
+|  7900 | JAMES  | CLERK     | 7698 | 1981-12-03 |  950.00 |         |     30 
|       |
+|  7902 | FORD   | ANALYST   | 7566 | 1981-12-03 | 3000.00 |         |     20 
|       |
+|  7934 | MILLER | CLERK     | 7782 | 1982-01-23 | 1300.00 |         |     10 
|       |
++-------+--------+-----------+------+------------+---------+---------+--------+-------+
+(14 rows)
+
+!ok
+EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t6):DECIMAL(12, 2)], 
expr#9=[Sarg[(-∞..300:DECIMAL(12, 2)), (300:DECIMAL(12, 2)..500:DECIMAL(12, 
2)), (500:DECIMAL(12, 2)..+∞)]:DECIMAL(12, 2)], expr#10=[SEARCH($t8, $t9)], 
expr#11=[null:BOOLEAN], expr#12=[AND($t10, $t11)], proj#0..7=[{exprs}], 
I=[$t12])
   EnumerableTableScan(table=[[scott, EMP]])
 !plan
 
diff --git a/druid/src/test/java/org/apache/calcite/test/DruidAdapter2IT.java 
b/druid/src/test/java/org/apache/calcite/test/DruidAdapter2IT.java
index 4beddd31ee..7bf53579db 100644
--- a/druid/src/test/java/org/apache/calcite/test/DruidAdapter2IT.java
+++ b/druid/src/test/java/org/apache/calcite/test/DruidAdapter2IT.java
@@ -1030,7 +1030,7 @@ public class DruidAdapter2IT {
         + "intervals=[[1900-01-09T00:00:00.000Z/2992-01-10T00:00:00.000Z]], "
         + "filter=[AND("
         + "=($3, 'High Top Dried Mushrooms'), "
-        + "SEARCH($87, Sarg['Q2', 'Q3']:CHAR(2)), "
+        + "SEARCH($87, Sarg['Q2':VARCHAR, 'Q3':VARCHAR]:VARCHAR), "
         + "=($30, 'WA'))], "
         + "projects=[[$30, $29, $3]], groups=[{0, 1, 2}], aggs=[[]])\n";
     sql(sql)
@@ -1072,7 +1072,7 @@ public class DruidAdapter2IT {
         + "intervals=[[1900-01-09T00:00:00.000Z/2992-01-10T00:00:00.000Z]], "
         + "filter=[AND("
         + "=($3, 'High Top Dried Mushrooms'), "
-        + "SEARCH($87, Sarg['Q2', 'Q3']:CHAR(2)), "
+        + "SEARCH($87, Sarg['Q2':VARCHAR, 'Q3':VARCHAR]:VARCHAR), "
         + "=($30, 'WA'))], "
         + "projects=[[$30, $29, $3]])\n";
     sql(sql)
diff --git a/druid/src/test/java/org/apache/calcite/test/DruidAdapterIT.java 
b/druid/src/test/java/org/apache/calcite/test/DruidAdapterIT.java
index b03f3e19e8..b3d62632c0 100644
--- a/druid/src/test/java/org/apache/calcite/test/DruidAdapterIT.java
+++ b/druid/src/test/java/org/apache/calcite/test/DruidAdapterIT.java
@@ -1305,7 +1305,7 @@ public class DruidAdapterIT {
         + "intervals=[[1900-01-09T00:00:00.000Z/2992-01-10T00:00:00.000Z]], "
         + "filter=[AND("
         + "=($3, 'High Top Dried Mushrooms'), "
-        + "SEARCH($87, Sarg['Q2', 'Q3']:CHAR(2)), "
+        + "SEARCH($87, Sarg['Q2':VARCHAR, 'Q3':VARCHAR]:VARCHAR), "
         + "=($30, 'WA'))], "
         + "projects=[[$30, $29, $3]], groups=[{0, 1, 2}], aggs=[[]])\n";
     sql(sql)
@@ -1347,7 +1347,7 @@ public class DruidAdapterIT {
         + "intervals=[[1900-01-09T00:00:00.000Z/2992-01-10T00:00:00.000Z]], "
         + "filter=[AND("
         + "=($3, 'High Top Dried Mushrooms'), "
-        + "SEARCH($87, Sarg['Q2', 'Q3']:CHAR(2)), "
+        + "SEARCH($87, Sarg['Q2':VARCHAR, 'Q3':VARCHAR]:VARCHAR), "
         + "=($30, 'WA'))], "
         + "projects=[[$30, $29, $3]])\n";
     sql(sql)

Reply via email to