This is an automated email from the ASF dual-hosted git repository.

mbudiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git

commit a3c3205871f0e03c4b0203977fd5fd3e53034d3a
Author: Mihai Budiu <[email protected]>
AuthorDate: Tue Aug 6 18:27:38 2024 -0700

    [CALCITE-6464] Type inference for DECIMAL division seems incorrect
    
    Signed-off-by: Mihai Budiu <[email protected]>
---
 .../apache/calcite/rel/type/RelDataTypeSystem.java | 28 ++++++------
 .../calcite/sql/type/RelDataTypeSystemTest.java    | 50 ++++++++++++++++++++++
 .../org/apache/calcite/test/SqlValidatorTest.java  | 14 +++---
 .../org/apache/calcite/test/TypeCoercionTest.java  |  6 +--
 core/src/test/resources/sql/measure-paper.iq       | 14 +++---
 .../org/apache/calcite/test/SqlOperatorTest.java   | 10 ++---
 6 files changed, 86 insertions(+), 36 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeSystem.java b/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeSystem.java
index 0c89381d8b..0028cbbfd1 100644
--- a/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeSystem.java
+++ b/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeSystem.java
@@ -259,7 +259,7 @@ public interface RelDataTypeSystem {
    * <li>Then the result type is a decimal with:
    *   <ul>
    *   <li>d = p1 - s1 + s2</li>
-   *   <li>s &lt; max(6, s1 + p2 + 1)</li>
+   *   <li>s = max(6, s1 + p2 + 1)</li>
    *   <li>p = d + s</li>
    *   </ul>
    * </li>
@@ -294,21 +294,21 @@ public interface RelDataTypeSystem {
         int s1 = type1.getScale();
         int s2 = type2.getScale();
 
-        final int maxNumericPrecision = getMaxNumericPrecision();
-        int dout =
-            Math.min(
-                p1 - s1 + s2,
-                maxNumericPrecision);
-
+        int d = p1 - s1 + s2;
         int scale = Math.max(6, s1 + p2 + 1);
-        scale =
-            Math.min(
-                scale,
-                maxNumericPrecision - dout);
-        scale = Math.min(scale, getMaxNumericScale());
+        int precision = d + scale;
+
+  // Rules from
+  // https://learn.microsoft.com/en-us/sql/t-sql/data-types/precision-scale-and-length-transact-sql
+        int bound = getMaxNumericPrecision() - 6;  // 32 in the MS documentation
+        if (precision <= bound) {
+          scale = Math.min(scale, getMaxNumericScale() - (precision - scale));
+        } else {
+          // precision > bound
+          scale = Math.min(6, scale);
+        }
 
-        int precision = dout + scale;
-        assert precision <= maxNumericPrecision;
+        precision = Math.min(precision, getMaxNumericPrecision());
         assert precision > 0;
 
         RelDataType ret;
diff --git a/core/src/test/java/org/apache/calcite/sql/type/RelDataTypeSystemTest.java b/core/src/test/java/org/apache/calcite/sql/type/RelDataTypeSystemTest.java
index f363c933a5..6332f69dcf 100644
--- a/core/src/test/java/org/apache/calcite/sql/type/RelDataTypeSystemTest.java
+++ b/core/src/test/java/org/apache/calcite/sql/type/RelDataTypeSystemTest.java
@@ -149,12 +149,62 @@ class RelDataTypeSystemTest {
     RelDataType operand2 = f.createSqlType(SqlTypeName.DECIMAL, 10, 2);
 
     RelDataType dataType =
+        SqlStdOperatorTable.PLUS.inferReturnType(f,
+            Lists.newArrayList(operand1, operand2));
+    assertEquals(12, dataType.getPrecision());
+    assertEquals(2, dataType.getScale());
+
+    dataType =
         SqlStdOperatorTable.MINUS.inferReturnType(f,
             Lists.newArrayList(operand1, operand2));
     assertEquals(12, dataType.getPrecision());
     assertEquals(2, dataType.getScale());
   }
 
+  @Test void testDecimalDivideReturnTypeInference() {
+    final SqlTypeFactoryImpl f = new Fixture().typeFactory;
+    RelDataType operand1 = f.createSqlType(SqlTypeName.DECIMAL, 6, 2);
+    RelDataType operand2 = f.createSqlType(SqlTypeName.DECIMAL, 6, 2);
+
+    RelDataType dataType =
+        SqlStdOperatorTable.DIVIDE.inferReturnType(f,
+            Lists.newArrayList(operand1, operand2));
+    assertEquals(15, dataType.getPrecision());
+    assertEquals(6, dataType.getScale());
+  }
+
+  /**
+   * Tests that the return type inference for a division with a custom type system
+   * (max precision=28, max scale=10) works correctly.
+   * <a href="https://issues.apache.org/jira/browse/CALCITE-6464">[CALCITE-6464]
+   * Type inference for DECIMAL division seems incorrect</a>
+   */
+  @Test void testCustomMaxPrecisionCustomMaxScaleDecimalDivideReturnTypeInference() {
+    /**
+     * Custom type system class that overrides the default max precision and max scale.
+     */
+    final class CustomTypeSystem extends RelDataTypeSystemImpl {
+      @Override public int getMaxNumericPrecision() {
+        return 28;
+      }
+
+      @Override public int getMaxNumericScale() {
+        return 10;
+      }
+    }
+
+    final SqlTypeFactoryImpl f = new SqlTypeFactoryImpl(new CustomTypeSystem());
+
+    RelDataType operand1 = f.createSqlType(SqlTypeName.DECIMAL, 28, 10);
+    RelDataType operand2 = f.createSqlType(SqlTypeName.DECIMAL, 28, 10);
+
+    RelDataType dataType = SqlStdOperatorTable.DIVIDE.inferReturnType(f, Lists
+        .newArrayList(operand1, operand2));
+    assertEquals(SqlTypeName.DECIMAL, dataType.getSqlTypeName());
+    assertEquals(28, dataType.getPrecision());
+    assertEquals(6, dataType.getScale());
+  }
+
   @Test void testDecimalModReturnTypeInference() {
     final SqlTypeFactoryImpl f = new Fixture().typeFactory;
     RelDataType operand1 = f.createSqlType(SqlTypeName.DECIMAL, 10, 1);
diff --git a/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java b/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java
index fffb7591d0..de415956a1 100644
--- a/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java
+++ b/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java
@@ -2645,14 +2645,14 @@ public class SqlValidatorTest extends SqlValidatorTestCase {
     expr("cast(null as REAL) / cast(5 as DOUBLE)")
         .columnType("DOUBLE");
     expr("cast(1 as DECIMAL(7, 3)) / 1.654")
-        .columnType("DECIMAL(15, 8) NOT NULL");
+        .columnType("DECIMAL(15, 6) NOT NULL");
     expr("cast(null as DECIMAL(7, 3)) / cast (1.654 as DOUBLE)")
         .columnType("DOUBLE");
 
     expr("cast(null as DECIMAL(5, 2)) / cast(1 as BIGINT)")
-        .columnType("DECIMAL(19, 16)");
+        .columnType("DECIMAL(19, 6)");
     expr("cast(1 as DECIMAL(5, 2)) / cast(1 as INTEGER)")
-        .columnType("DECIMAL(16, 13) NOT NULL");
+        .columnType("DECIMAL(16, 6) NOT NULL");
     expr("cast(1 as DECIMAL(5, 2)) / cast(1 as SMALLINT)")
         .columnType("DECIMAL(11, 8) NOT NULL");
     expr("cast(1 as DECIMAL(5, 2)) / cast(1 as TINYINT)")
@@ -2661,15 +2661,15 @@ public class SqlValidatorTest extends SqlValidatorTestCase {
     expr("cast(1 as DECIMAL(5, 2)) / cast(1 as DECIMAL(5, 2))")
         .columnType("DECIMAL(13, 8) NOT NULL");
     expr("cast(1 as DECIMAL(5, 2)) / cast(1 as DECIMAL(6, 2))")
-        .columnType("DECIMAL(14, 9) NOT NULL");
+        .columnType("DECIMAL(14, 6) NOT NULL");
     expr("cast(1 as DECIMAL(4, 2)) / cast(1 as DECIMAL(6, 4))")
-        .columnType("DECIMAL(15, 9) NOT NULL");
+        .columnType("DECIMAL(15, 6) NOT NULL");
     expr("cast(null as DECIMAL(4, 2)) / cast(1 as DECIMAL(6, 4))")
-        .columnType("DECIMAL(15, 9)");
+        .columnType("DECIMAL(15, 6)");
     expr("cast(1 as DECIMAL(4, 10)) / cast(null as DECIMAL(6, 19))")
         .columnType("DECIMAL(19, 6)");
     expr("cast(1 as DECIMAL(19, 2)) / cast(1 as DECIMAL(19, 2))")
-        .columnType("DECIMAL(19, 0) NOT NULL");
+        .columnType("DECIMAL(19, 6) NOT NULL");
     expr("4/3")
         .columnType("INTEGER NOT NULL");
     expr("-4.0/3")
diff --git a/core/src/test/java/org/apache/calcite/test/TypeCoercionTest.java b/core/src/test/java/org/apache/calcite/test/TypeCoercionTest.java
index 425db3cfb6..b20c32ebc7 100644
--- a/core/src/test/java/org/apache/calcite/test/TypeCoercionTest.java
+++ b/core/src/test/java/org/apache/calcite/test/TypeCoercionTest.java
@@ -309,9 +309,9 @@ class TypeCoercionTest {
     expr("'12.3'/cast(5 as double)")
         .columnType("DOUBLE NOT NULL");
     expr("'12.3'/5.1")
-        .columnType("DECIMAL(19, 8) NOT NULL");
+        .columnType("DECIMAL(19, 6) NOT NULL");
     expr("12.3/'5.1'")
-        .columnType("DECIMAL(19, 8) NOT NULL");
+        .columnType("DECIMAL(19, 6) NOT NULL");
     // test binary arithmetic with two strings.
     expr("'12.3' + '5'")
         .columnType("DECIMAL(19, 9) NOT NULL");
@@ -320,7 +320,7 @@ class TypeCoercionTest {
     expr("'12.3' * '5'")
         .columnType("DECIMAL(19, 18) NOT NULL");
     expr("'12.3' / '5'")
-        .columnType("DECIMAL(19, 0) NOT NULL");
+        .columnType("DECIMAL(19, 6) NOT NULL");
   }
 
   /** Test cases for binary comparison expressions. */
diff --git a/core/src/test/resources/sql/measure-paper.iq b/core/src/test/resources/sql/measure-paper.iq
index 59b03be9e5..b7a0565b3b 100644
--- a/core/src/test/resources/sql/measure-paper.iq
+++ b/core/src/test/resources/sql/measure-paper.iq
@@ -79,13 +79,13 @@ CREATE VIEW "SummarizedOrders" AS
 SELECT "prodName", AVG("profitMargin") AS "m"
 FROM "SummarizedOrders"
 GROUP BY "prodName";
-+----------+-----------------+
-| prodName | m               |
-+----------+-----------------+
-| Acme     | 0.6000000000000 |
-| Happy    | 0.5039682539682 |
-| Whizz    | 0.6666666666666 |
-+----------+-----------------+
++----------+----------+
+| prodName | m        |
++----------+----------+
+| Acme     | 0.600000 |
+| Happy    | 0.503968 |
+| Whizz    | 0.666666 |
++----------+----------+
 (3 rows)
 
 !ok
diff --git a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
index ce80c28372..c14b6d0003 100644
--- a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
+++ b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
@@ -2646,12 +2646,12 @@ public class SqlOperatorTest {
         isExactly("0.6"));
     f.checkScalarExact("10.0 / 5.0", "DECIMAL(9, 6) NOT NULL", "2");
     f.checkScalarExact("1.0 / 3.0", "DECIMAL(8, 6) NOT NULL", "0.3333333333333333");
-    f.checkScalarExact("100.1 / 0.0001", "DECIMAL(14, 7) NOT NULL",
+    f.checkScalarExact("100.1 / 0.0001", "DECIMAL(14, 6) NOT NULL",
         "1.001E+6");
-    f.checkScalarExact("100.1 / 0.00000001", "DECIMAL(19, 8) NOT NULL",
+    f.checkScalarExact("100.1 / 0.00000001", "DECIMAL(19, 6) NOT NULL",
         "1.001E+10");
     f.checkNull("1e1 / cast(null as float)");
-    f.checkScalarExact("100.1 / 0.00000000000000001", "DECIMAL(19, 0) NOT NULL",
+    f.checkScalarExact("100.1 / 0.00000000000000001", "DECIMAL(19, 6) NOT NULL",
         "1.001E+19");
   }
 
@@ -9787,9 +9787,9 @@ public class SqlOperatorTest {
     f.checkScalar("safe_divide(cast(2 as bigint), cast(4 as bigint))",
         "0.5", "DOUBLE");
     f.checkScalar("safe_divide(cast(15 as bigint), cast(1.2 as decimal(2,1)))",
-        "12.5", "DECIMAL(19, 0)");
+        "12.5", "DECIMAL(19, 6)");
     f.checkScalar("safe_divide(cast(4.5 as decimal(2,1)), cast(3 as bigint))",
-        "1.5", "DECIMAL(19, 18)");
+        "1.5", "DECIMAL(19, 6)");
     f.checkScalar("safe_divide(cast(4.5 as decimal(2,1)), "
         + "cast(1.5 as decimal(2, 1)))", "3", "DECIMAL(8, 6)");
     f.checkScalar("safe_divide(cast(3 as double), cast(3 as bigint))",

Reply via email to