This is an automated email from the ASF dual-hosted git repository.

mbudiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new e348a068a7 [CALCITE-6733] Type inferred by coercion for comparisons with decimal is too narrow
e348a068a7 is described below

commit e348a068a753c655d5c7d5c4952f30ab97b6dc07
Author: Mihai Budiu <[email protected]>
AuthorDate: Mon Dec 16 14:17:30 2024 -0800

    [CALCITE-6733] Type inferred by coercion for comparisons with decimal is too narrow
    
    Signed-off-by: Mihai Budiu <[email protected]>
---
 .../sql/validate/implicit/AbstractTypeCoercion.java  | 17 +++++++----------
 .../org/apache/calcite/test/TypeCoercionTest.java    |  4 +++-
 .../org/apache/calcite/test/RelOptRulesTest.xml      |  4 ++--
 core/src/test/resources/sql/agg.iq                   |  2 +-
 core/src/test/resources/sql/misc.iq                  | 11 +++++++++++
 core/src/test/resources/sql/sub-query.iq             | 20 ++++++++++----------
 6 files changed, 34 insertions(+), 24 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/sql/validate/implicit/AbstractTypeCoercion.java b/core/src/main/java/org/apache/calcite/sql/validate/implicit/AbstractTypeCoercion.java
index db8679285c..87106eebb8 100644
--- a/core/src/main/java/org/apache/calcite/sql/validate/implicit/AbstractTypeCoercion.java
+++ b/core/src/main/java/org/apache/calcite/sql/validate/implicit/AbstractTypeCoercion.java
@@ -612,17 +612,14 @@ public abstract class AbstractTypeCoercion implements 
TypeCoercion {
     }
 
     if (SqlTypeUtil.isExactNumeric(type1) && 
SqlTypeUtil.isExactNumeric(type2)) {
-      if (SqlTypeUtil.isDecimal(type1)) {
-        // Use max precision
+      if (SqlTypeUtil.isDecimal(type1) || SqlTypeUtil.isDecimal(type2)) {
+        // Precision used must be large enough to fit either of the types
+        int maxScale = Math.max(type1.getScale(), type2.getScale());
         RelDataType result =
-            factory.createSqlType(type1.getSqlTypeName(),
-                Math.max(type1.getPrecision(), type2.getPrecision()), 
type1.getScale());
-        return factory.createTypeWithNullability(result, type1.isNullable() || 
type2.isNullable());
-      } else if (SqlTypeUtil.isDecimal(type2)) {
-        // Use max precision
-        RelDataType result =
-            factory.createSqlType(type2.getSqlTypeName(),
-                Math.max(type1.getPrecision(), type2.getPrecision()), 
type2.getScale());
+            factory.createSqlType(SqlTypeName.DECIMAL,
+                Math.max(type1.getPrecision() - type1.getScale(),
+                         type2.getPrecision() - type2.getScale()) + maxScale,
+                maxScale);
         return factory.createTypeWithNullability(result, type1.isNullable() || 
type2.isNullable());
       }
       if (type1.getPrecision() > type2.getPrecision()) {
diff --git a/core/src/test/java/org/apache/calcite/test/TypeCoercionTest.java b/core/src/test/java/org/apache/calcite/test/TypeCoercionTest.java
index 5c62932999..e5375686b3 100644
--- a/core/src/test/java/org/apache/calcite/test/TypeCoercionTest.java
+++ b/core/src/test/java/org/apache/calcite/test/TypeCoercionTest.java
@@ -384,9 +384,11 @@ class TypeCoercionTest {
         f.typeFactory.createSqlType(SqlTypeName.DECIMAL, 7, 1);
     RelDataType decimal104 =
         f.typeFactory.createSqlType(SqlTypeName.DECIMAL, 10, 4);
+    RelDataType decimal144 =
+        f.typeFactory.createSqlType(SqlTypeName.DECIMAL, 14, 4);
     f.comparisonCommonType(decimal54, decimal71, decimal104);
     f.comparisonCommonType(decimal54, f.doubleType, f.doubleType);
-    f.comparisonCommonType(decimal54, f.intType, decimal104);
+    f.comparisonCommonType(decimal54, f.intType, decimal144);
     // CHAR/VARCHAR
     f.comparisonCommonType(f.charType, f.varcharType, f.varcharType);
     f.comparisonCommonType(f.intType, f.charType, f.intType);
diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index 01b40a44e0..5f08355a97 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -13299,14 +13299,14 @@ LogicalProject($0=[$3], $1=[$4])
     <Resource name="planBefore">
       <![CDATA[
 LogicalProject(ENAME=[$1])
-  LogicalFilter(condition=[>(CAST($5):DECIMAL(10, 1) NOT NULL, 100.0)])
+  LogicalFilter(condition=[>(CAST($5):DECIMAL(11, 1) NOT NULL, 100.0)])
     LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
     </Resource>
     <Resource name="planAfter">
       <![CDATA[
 LogicalProject(ENAME=[$1])
-  LogicalCalc(expr#0..8=[{inputs}], expr#9=[10:BIGINT], expr#10=[*($t5, $t9)], 
expr#11=[true], expr#12=[Reinterpret($t10, $t11)], expr#13=[Reinterpret($t12)], 
expr#14=[100.0:DECIMAL(10, 1)], expr#15=[Reinterpret($t14)], expr#16=[>($t13, 
$t15)], proj#0..8=[{exprs}], $condition=[$t16])
+  LogicalCalc(expr#0..8=[{inputs}], expr#9=[10:BIGINT], expr#10=[*($t5, $t9)], 
expr#11=[Reinterpret($t10)], expr#12=[Reinterpret($t11)], 
expr#13=[100.0:DECIMAL(11, 1)], expr#14=[Reinterpret($t13)], expr#15=[>($t12, 
$t14)], proj#0..8=[{exprs}], $condition=[$t15])
     LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
     </Resource>
diff --git a/core/src/test/resources/sql/agg.iq b/core/src/test/resources/sql/agg.iq
index 3c5230b583..dfd56acd35 100644
--- a/core/src/test/resources/sql/agg.iq
+++ b/core/src/test/resources/sql/agg.iq
@@ -3120,7 +3120,7 @@ group by dept.deptno;
 
 !ok
 EnumerableAggregate(group=[{0}], S=[COLLECT($1) WITHIN GROUP ([1 DESC])], 
S1=[COLLECT($1) WITHIN GROUP ([2])], S2=[COLLECT($1) WITHIN GROUP ([1]) FILTER 
$3])
-  EnumerableCalc(expr#0..3=[{inputs}], expr#4=[1], 
expr#5=[CAST($t2):DECIMAL(10, 2)], expr#6=[2000.00:DECIMAL(10, 2)], 
expr#7=[>($t5, $t6)], expr#8=[IS TRUE($t7)], DEPTNO=[$t0], SAL=[$t2], 
$f2=[$t4], $f3=[$t8])
+  EnumerableCalc(expr#0..3=[{inputs}], expr#4=[1], 
expr#5=[CAST($t2):DECIMAL(12, 2)], expr#6=[2000.00:DECIMAL(12, 2)], 
expr#7=[>($t5, $t6)], expr#8=[IS TRUE($t7)], DEPTNO=[$t0], SAL=[$t2], 
$f2=[$t4], $f3=[$t8])
     EnumerableHashJoin(condition=[=($0, $3)], joinType=[inner])
       EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0])
         EnumerableTableScan(table=[[scott, DEPT]])
diff --git a/core/src/test/resources/sql/misc.iq b/core/src/test/resources/sql/misc.iq
index 4fa8080b35..23a602bcbb 100644
--- a/core/src/test/resources/sql/misc.iq
+++ b/core/src/test/resources/sql/misc.iq
@@ -18,6 +18,17 @@
 !use post
 !set outputformat mysql
 
+# [CALCITE-6733] Type inferred by coercion for comparisons with decimal is too narrow
+SELECT ASCII('8') >= ABS(1.1806236821);
++--------+
+| EXPR$0 |
++--------+
+| true   |
++--------+
+(1 row)
+
+!ok
+
 # [CALCITE-356] Allow column references of the form schema.table.column
 select "hr"."emps"."empid"
 from "hr"."emps";
diff --git a/core/src/test/resources/sql/sub-query.iq b/core/src/test/resources/sql/sub-query.iq
index 45edb3e4c9..248ec165ba 100644
--- a/core/src/test/resources/sql/sub-query.iq
+++ b/core/src/test/resources/sql/sub-query.iq
@@ -2908,7 +2908,7 @@ EnumerableCalc(expr#0..1=[{inputs}], expr#2=[IS 
NULL($t1)], DEPTNO=[$t0], $condi
     EnumerableAggregate(group=[{0}])
       EnumerableCalc(expr#0..1=[{inputs}], expr#2=[true], expr#3=[1], 
expr#4=[>($t1, $t3)], i=[$t2], $condition=[$t4])
         EnumerableAggregate(group=[{7}], c=[COUNT()])
-          EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t5):DECIMAL(10, 
2)], expr#9=[3000.00:DECIMAL(10, 2)], expr#10=[=($t8, $t9)], expr#11=[IS NOT 
NULL($t7)], expr#12=[AND($t10, $t11)], proj#0..7=[{exprs}], $condition=[$t12])
+          EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t5):DECIMAL(12, 
2)], expr#9=[3000.00:DECIMAL(12, 2)], expr#10=[=($t8, $t9)], expr#11=[IS NOT 
NULL($t7)], expr#12=[AND($t10, $t11)], proj#0..7=[{exprs}], $condition=[$t12])
             EnumerableTableScan(table=[[scott, EMP]])
 !plan
 
@@ -2934,7 +2934,7 @@ EnumerableCalc(expr#0..1=[{inputs}], expr#2=[IS 
NULL($t1)], DEPTNO=[$t0], U=[$t2
     EnumerableAggregate(group=[{0}])
       EnumerableCalc(expr#0..1=[{inputs}], expr#2=[true], expr#3=[1], 
expr#4=[>($t1, $t3)], i=[$t2], $condition=[$t4])
         EnumerableAggregate(group=[{7}], c=[COUNT()])
-          EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t5):DECIMAL(10, 
2)], expr#9=[3000.00:DECIMAL(10, 2)], expr#10=[=($t8, $t9)], expr#11=[IS NOT 
NULL($t7)], expr#12=[AND($t10, $t11)], proj#0..7=[{exprs}], $condition=[$t12])
+          EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t5):DECIMAL(12, 
2)], expr#9=[3000.00:DECIMAL(12, 2)], expr#10=[=($t8, $t9)], expr#11=[IS NOT 
NULL($t7)], expr#12=[AND($t10, $t11)], proj#0..7=[{exprs}], $condition=[$t12])
             EnumerableTableScan(table=[[scott, EMP]])
 !plan
 
@@ -2960,7 +2960,7 @@ EnumerableCalc(expr#0..1=[{inputs}], expr#2=[IS NOT 
NULL($t1)], DEPTNO=[$t0], U=
     EnumerableAggregate(group=[{0}])
       EnumerableCalc(expr#0..1=[{inputs}], expr#2=[true], expr#3=[1], 
expr#4=[>($t1, $t3)], i=[$t2], $condition=[$t4])
         EnumerableAggregate(group=[{7}], c=[COUNT()])
-          EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t5):DECIMAL(10, 
2)], expr#9=[3000.00:DECIMAL(10, 2)], expr#10=[=($t8, $t9)], expr#11=[IS NOT 
NULL($t7)], expr#12=[AND($t10, $t11)], proj#0..7=[{exprs}], $condition=[$t12])
+          EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t5):DECIMAL(12, 
2)], expr#9=[3000.00:DECIMAL(12, 2)], expr#10=[=($t8, $t9)], expr#11=[IS NOT 
NULL($t7)], expr#12=[AND($t10, $t11)], proj#0..7=[{exprs}], $condition=[$t12])
             EnumerableTableScan(table=[[scott, EMP]])
 !plan
 
@@ -3302,7 +3302,7 @@ select *, (comm <> 300 and comm <> 500 and comm <> null) 
as i from "scott".emp;
 (14 rows)
 
 !ok
-EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t6):DECIMAL(10, 2)], 
expr#9=[Sarg[(-∞..300.00:DECIMAL(10, 2)), (300.00:DECIMAL(10, 
2)..500.00:DECIMAL(10, 2)), (500.00:DECIMAL(10, 2)..+∞)]:DECIMAL(10, 2)], 
expr#10=[SEARCH($t8, $t9)], expr#11=[null:BOOLEAN], expr#12=[AND($t10, $t11)], 
proj#0..7=[{exprs}], I=[$t12])
+EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t6):DECIMAL(12, 2)], 
expr#9=[Sarg[(-∞..300.00:DECIMAL(12, 2)), (300.00:DECIMAL(12, 
2)..500.00:DECIMAL(12, 2)), (500.00:DECIMAL(12, 2)..+∞)]:DECIMAL(12, 2)], 
expr#10=[SEARCH($t8, $t9)], expr#11=[null:BOOLEAN], expr#12=[AND($t10, $t11)], 
proj#0..7=[{exprs}], I=[$t12])
   EnumerableTableScan(table=[[scott, EMP]])
 !plan
 
@@ -3934,10 +3934,10 @@ select comm, comm in (500, 300, 0) from emp;
 
 !ok
 
-EnumerableCalc(expr#0..6=[{inputs}], expr#7=[0], expr#8=[=($t2, $t7)], 
expr#9=[false], expr#10=[CAST($t1):DECIMAL(10, 2)], expr#11=[IS NULL($t10)], 
expr#12=[null:BOOLEAN], expr#13=[IS NOT NULL($t6)], expr#14=[true], 
expr#15=[<($t3, $t2)], expr#16=[CASE($t8, $t9, $t11, $t12, $t13, $t14, $t15, 
$t12, $t9)], COMM=[$t1], EXPR$1=[$t16])
+EnumerableCalc(expr#0..6=[{inputs}], expr#7=[0], expr#8=[=($t2, $t7)], 
expr#9=[false], expr#10=[CAST($t1):DECIMAL(12, 2)], expr#11=[IS NULL($t10)], 
expr#12=[null:BOOLEAN], expr#13=[IS NOT NULL($t6)], expr#14=[true], 
expr#15=[<($t3, $t2)], expr#16=[CASE($t8, $t9, $t11, $t12, $t13, $t14, $t15, 
$t12, $t9)], COMM=[$t1], EXPR$1=[$t16])
   EnumerableMergeJoin(condition=[=($4, $5)], joinType=[left])
     EnumerableSort(sort0=[$4], dir0=[ASC])
-      EnumerableCalc(expr#0..3=[{inputs}], expr#4=[CAST($t1):DECIMAL(10, 2)], 
proj#0..4=[{exprs}])
+      EnumerableCalc(expr#0..3=[{inputs}], expr#4=[CAST($t1):DECIMAL(12, 2)], 
proj#0..4=[{exprs}])
         EnumerableNestedLoopJoin(condition=[true], joinType=[inner])
           EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], COMM=[$t6])
             EnumerableTableScan(table=[[scott, EMP]])
@@ -4047,10 +4047,10 @@ select comm, (comm, comm) in ((500, 500), (300, 300), 
(0, 0)) from emp;
 
 !ok
 
-EnumerableCalc(expr#0..8=[{inputs}], expr#9=[0], expr#10=[=($t2, $t9)], 
expr#11=[false], expr#12=[CAST($t1):DECIMAL(10, 2)], expr#13=[IS NULL($t12)], 
expr#14=[null:BOOLEAN], expr#15=[IS NOT NULL($t8)], expr#16=[true], 
expr#17=[<($t3, $t2)], expr#18=[CASE($t10, $t11, $t13, $t14, $t15, $t16, $t17, 
$t14, $t11)], COMM=[$t1], EXPR$1=[$t18])
+EnumerableCalc(expr#0..8=[{inputs}], expr#9=[0], expr#10=[=($t2, $t9)], 
expr#11=[false], expr#12=[CAST($t1):DECIMAL(12, 2)], expr#13=[IS NULL($t12)], 
expr#14=[null:BOOLEAN], expr#15=[IS NOT NULL($t8)], expr#16=[true], 
expr#17=[<($t3, $t2)], expr#18=[CASE($t10, $t11, $t13, $t14, $t15, $t16, $t17, 
$t14, $t11)], COMM=[$t1], EXPR$1=[$t18])
   EnumerableMergeJoin(condition=[AND(=($4, $6), =($5, $7))], joinType=[left])
     EnumerableSort(sort0=[$4], sort1=[$5], dir0=[ASC], dir1=[ASC])
-      EnumerableCalc(expr#0..3=[{inputs}], expr#4=[CAST($t1):DECIMAL(10, 2)], 
proj#0..4=[{exprs}], COMM1=[$t4])
+      EnumerableCalc(expr#0..3=[{inputs}], expr#4=[CAST($t1):DECIMAL(12, 2)], 
proj#0..4=[{exprs}], COMM1=[$t4])
         EnumerableNestedLoopJoin(condition=[true], joinType=[inner])
           EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], COMM=[$t6])
             EnumerableTableScan(table=[[scott, EMP]])
@@ -4087,10 +4087,10 @@ select comm, (comm, comm) in ((500, 500), (300, 300), 
(0, 0), (null , null)) fro
 
 !ok
 
-EnumerableCalc(expr#0..8=[{inputs}], expr#9=[0], expr#10=[=($t2, $t9)], 
expr#11=[false], expr#12=[CAST($t1):DECIMAL(10, 2)], expr#13=[IS NULL($t12)], 
expr#14=[null:BOOLEAN], expr#15=[IS NOT NULL($t8)], expr#16=[true], 
expr#17=[<($t3, $t2)], expr#18=[CASE($t10, $t11, $t13, $t14, $t15, $t16, $t17, 
$t14, $t11)], COMM=[$t1], EXPR$1=[$t18])
+EnumerableCalc(expr#0..8=[{inputs}], expr#9=[0], expr#10=[=($t2, $t9)], 
expr#11=[false], expr#12=[CAST($t1):DECIMAL(12, 2)], expr#13=[IS NULL($t12)], 
expr#14=[null:BOOLEAN], expr#15=[IS NOT NULL($t8)], expr#16=[true], 
expr#17=[<($t3, $t2)], expr#18=[CASE($t10, $t11, $t13, $t14, $t15, $t16, $t17, 
$t14, $t11)], COMM=[$t1], EXPR$1=[$t18])
   EnumerableMergeJoin(condition=[AND(=($4, $6), =($5, $7))], joinType=[left])
     EnumerableSort(sort0=[$4], sort1=[$5], dir0=[ASC], dir1=[ASC])
-      EnumerableCalc(expr#0..3=[{inputs}], expr#4=[CAST($t1):DECIMAL(10, 2)], 
proj#0..4=[{exprs}], COMM1=[$t4])
+      EnumerableCalc(expr#0..3=[{inputs}], expr#4=[CAST($t1):DECIMAL(12, 2)], 
proj#0..4=[{exprs}], COMM1=[$t4])
         EnumerableNestedLoopJoin(condition=[true], joinType=[inner])
           EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], COMM=[$t6])
             EnumerableTableScan(table=[[scott, EMP]])

Reply via email to