This is an automated email from the ASF dual-hosted git repository. mbudiu pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/calcite.git
commit a3c3205871f0e03c4b0203977fd5fd3e53034d3a Author: Mihai Budiu <[email protected]> AuthorDate: Tue Aug 6 18:27:38 2024 -0700 [CALCITE-6464] Type inference for DECIMAL division seems incorrect Signed-off-by: Mihai Budiu <[email protected]> --- .../apache/calcite/rel/type/RelDataTypeSystem.java | 28 ++++++------ .../calcite/sql/type/RelDataTypeSystemTest.java | 50 ++++++++++++++++++++++ .../org/apache/calcite/test/SqlValidatorTest.java | 14 +++--- .../org/apache/calcite/test/TypeCoercionTest.java | 6 +-- core/src/test/resources/sql/measure-paper.iq | 14 +++--- .../org/apache/calcite/test/SqlOperatorTest.java | 10 ++--- 6 files changed, 86 insertions(+), 36 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeSystem.java b/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeSystem.java index 0c89381d8b..0028cbbfd1 100644 --- a/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeSystem.java +++ b/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeSystem.java @@ -259,7 +259,7 @@ public interface RelDataTypeSystem { * <li>Then the result type is a decimal with: * <ul> * <li>d = p1 - s1 + s2</li> - * <li>s < max(6, s1 + p2 + 1)</li> + * <li>s = max(6, s1 + p2 + 1)</li> * <li>p = d + s</li> * </ul> * </li> @@ -294,21 +294,21 @@ public interface RelDataTypeSystem { int s1 = type1.getScale(); int s2 = type2.getScale(); - final int maxNumericPrecision = getMaxNumericPrecision(); - int dout = - Math.min( - p1 - s1 + s2, - maxNumericPrecision); - + int d = p1 - s1 + s2; int scale = Math.max(6, s1 + p2 + 1); - scale = - Math.min( - scale, - maxNumericPrecision - dout); - scale = Math.min(scale, getMaxNumericScale()); + int precision = d + scale; + + // Rules from + // https://learn.microsoft.com/en-us/sql/t-sql/data-types/precision-scale-and-length-transact-sql + int bound = getMaxNumericPrecision() - 6; // 32 in the MS documentation + if (precision <= bound) { + scale = Math.min(scale, getMaxNumericScale() - (precision - scale)); + } else { + // precision > bound + scale = Math.min(6, scale); + } - int precision = dout + scale; - assert precision <= maxNumericPrecision; + precision = Math.min(precision, getMaxNumericPrecision()); assert precision > 0; RelDataType ret; diff --git a/core/src/test/java/org/apache/calcite/sql/type/RelDataTypeSystemTest.java b/core/src/test/java/org/apache/calcite/sql/type/RelDataTypeSystemTest.java index f363c933a5..6332f69dcf 100644 --- a/core/src/test/java/org/apache/calcite/sql/type/RelDataTypeSystemTest.java +++ b/core/src/test/java/org/apache/calcite/sql/type/RelDataTypeSystemTest.java @@ -149,12 +149,62 @@ class RelDataTypeSystemTest { RelDataType operand2 = f.createSqlType(SqlTypeName.DECIMAL, 10, 2); RelDataType dataType = + SqlStdOperatorTable.PLUS.inferReturnType(f, + Lists.newArrayList(operand1, operand2)); + assertEquals(12, dataType.getPrecision()); + assertEquals(2, dataType.getScale()); + + dataType = SqlStdOperatorTable.MINUS.inferReturnType(f, Lists.newArrayList(operand1, operand2)); assertEquals(12, dataType.getPrecision()); assertEquals(2, dataType.getScale()); } + @Test void testDecimalDivideReturnTypeInference() { + final SqlTypeFactoryImpl f = new Fixture().typeFactory; + RelDataType operand1 = f.createSqlType(SqlTypeName.DECIMAL, 6, 2); + RelDataType operand2 = f.createSqlType(SqlTypeName.DECIMAL, 6, 2); + + RelDataType dataType = + SqlStdOperatorTable.DIVIDE.inferReturnType(f, + Lists.newArrayList(operand1, operand2)); + assertEquals(15, dataType.getPrecision()); + assertEquals(6, dataType.getScale()); + } + + /** + * Tests that the return type inference for a division with a custom type system + * (max precision=28, max scale=10) works correctly. + * <a href="https://issues.apache.org/jira/browse/CALCITE-6464">[CALCITE-6464] + * Type inference for DECIMAL division seems incorrect</a> + */ + @Test void testCustomMaxPrecisionCustomMaxScaleDecimalDivideReturnTypeInference() { + /** + * Custom type system class that overrides the default max precision and max scale. + */ + final class CustomTypeSystem extends RelDataTypeSystemImpl { + @Override public int getMaxNumericPrecision() { + return 28; + } + + @Override public int getMaxNumericScale() { + return 10; + } + } + + final SqlTypeFactoryImpl f = new SqlTypeFactoryImpl(new CustomTypeSystem()); + + RelDataType operand1 = f.createSqlType(SqlTypeName.DECIMAL, 28, 10); + RelDataType operand2 = f.createSqlType(SqlTypeName.DECIMAL, 28, 10); + + RelDataType dataType = SqlStdOperatorTable.DIVIDE.inferReturnType(f, Lists + .newArrayList(operand1, operand2)); + assertEquals(SqlTypeName.DECIMAL, dataType.getSqlTypeName()); + assertEquals(28, dataType.getPrecision()); + assertEquals(6, dataType.getScale()); + } + @Test void testDecimalModReturnTypeInference() { final SqlTypeFactoryImpl f = new Fixture().typeFactory; RelDataType operand1 = f.createSqlType(SqlTypeName.DECIMAL, 10, 1); diff --git a/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java b/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java index fffb7591d0..de415956a1 100644 --- a/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java +++ b/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java @@ -2645,14 +2645,14 @@ public class SqlValidatorTest extends SqlValidatorTestCase { expr("cast(null as REAL) / cast(5 as DOUBLE)") .columnType("DOUBLE"); expr("cast(1 as DECIMAL(7, 3)) / 1.654") - .columnType("DECIMAL(15, 8) NOT NULL"); + .columnType("DECIMAL(15, 6) NOT NULL"); expr("cast(null as DECIMAL(7, 3)) / cast (1.654 as DOUBLE)") .columnType("DOUBLE"); expr("cast(null as DECIMAL(5, 2)) / cast(1 as BIGINT)") - .columnType("DECIMAL(19, 16)"); + .columnType("DECIMAL(19, 6)"); expr("cast(1 as DECIMAL(5, 2)) / cast(1 as INTEGER)") - .columnType("DECIMAL(16, 13) NOT NULL"); + .columnType("DECIMAL(16, 6) NOT NULL"); expr("cast(1 as DECIMAL(5, 2)) / cast(1 as SMALLINT)") .columnType("DECIMAL(11, 8) NOT NULL"); expr("cast(1 as DECIMAL(5, 2)) / cast(1 as TINYINT)") @@ -2661,15 +2661,15 @@ public class SqlValidatorTest extends SqlValidatorTestCase { expr("cast(1 as DECIMAL(5, 2)) / cast(1 as DECIMAL(5, 2))") .columnType("DECIMAL(13, 8) NOT NULL"); expr("cast(1 as DECIMAL(5, 2)) / cast(1 as DECIMAL(6, 2))") - .columnType("DECIMAL(14, 9) NOT NULL"); + .columnType("DECIMAL(14, 6) NOT NULL"); expr("cast(1 as DECIMAL(4, 2)) / cast(1 as DECIMAL(6, 4))") - .columnType("DECIMAL(15, 9) NOT NULL"); + .columnType("DECIMAL(15, 6) NOT NULL"); expr("cast(null as DECIMAL(4, 2)) / cast(1 as DECIMAL(6, 4))") - .columnType("DECIMAL(15, 9)"); + .columnType("DECIMAL(15, 6)"); expr("cast(1 as DECIMAL(4, 10)) / cast(null as DECIMAL(6, 19))") .columnType("DECIMAL(19, 6)"); expr("cast(1 as DECIMAL(19, 2)) / cast(1 as DECIMAL(19, 2))") - .columnType("DECIMAL(19, 0) NOT NULL"); + .columnType("DECIMAL(19, 6) NOT NULL"); expr("4/3") .columnType("INTEGER NOT NULL"); expr("-4.0/3") diff --git a/core/src/test/java/org/apache/calcite/test/TypeCoercionTest.java b/core/src/test/java/org/apache/calcite/test/TypeCoercionTest.java index 425db3cfb6..b20c32ebc7 100644 --- a/core/src/test/java/org/apache/calcite/test/TypeCoercionTest.java +++ b/core/src/test/java/org/apache/calcite/test/TypeCoercionTest.java @@ -309,9 +309,9 @@ class TypeCoercionTest { expr("'12.3'/cast(5 as double)") .columnType("DOUBLE NOT NULL"); expr("'12.3'/5.1") - .columnType("DECIMAL(19, 8) NOT NULL"); + .columnType("DECIMAL(19, 6) NOT NULL"); expr("12.3/'5.1'") - .columnType("DECIMAL(19, 8) NOT NULL"); + .columnType("DECIMAL(19, 6) NOT NULL"); // test binary arithmetic with two strings. expr("'12.3' + '5'") .columnType("DECIMAL(19, 9) NOT NULL"); @@ -320,7 +320,7 @@ class TypeCoercionTest { expr("'12.3' * '5'") .columnType("DECIMAL(19, 18) NOT NULL"); expr("'12.3' / '5'") - .columnType("DECIMAL(19, 0) NOT NULL"); + .columnType("DECIMAL(19, 6) NOT NULL"); } /** Test cases for binary comparison expressions. */ diff --git a/core/src/test/resources/sql/measure-paper.iq b/core/src/test/resources/sql/measure-paper.iq index 59b03be9e5..b7a0565b3b 100644 --- a/core/src/test/resources/sql/measure-paper.iq +++ b/core/src/test/resources/sql/measure-paper.iq @@ -79,13 +79,13 @@ CREATE VIEW "SummarizedOrders" AS SELECT "prodName", AVG("profitMargin") AS "m" FROM "SummarizedOrders" GROUP BY "prodName"; -+----------+-----------------+ -| prodName | m | -+----------+-----------------+ -| Acme | 0.6000000000000 | -| Happy | 0.5039682539682 | -| Whizz | 0.6666666666666 | -+----------+-----------------+ ++----------+----------+ +| prodName | m | ++----------+----------+ +| Acme | 0.600000 | +| Happy | 0.503968 | +| Whizz | 0.666666 | ++----------+----------+ (3 rows) !ok diff --git a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java index ce80c28372..c14b6d0003 100644 --- a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java +++ b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java @@ -2646,12 +2646,12 @@ public class SqlOperatorTest { isExactly("0.6")); f.checkScalarExact("10.0 / 5.0", "DECIMAL(9, 6) NOT NULL", "2"); f.checkScalarExact("1.0 / 3.0", "DECIMAL(8, 6) NOT NULL", "0.3333333333333333"); - f.checkScalarExact("100.1 / 0.0001", "DECIMAL(14, 7) NOT NULL", + f.checkScalarExact("100.1 / 0.0001", "DECIMAL(14, 6) NOT NULL", "1.001E+6"); - f.checkScalarExact("100.1 / 0.00000001", "DECIMAL(19, 8) NOT NULL", + f.checkScalarExact("100.1 / 0.00000001", "DECIMAL(19, 6) NOT NULL", "1.001E+10"); f.checkNull("1e1 / cast(null as float)"); - f.checkScalarExact("100.1 / 0.00000000000000001", "DECIMAL(19, 0) NOT NULL", + f.checkScalarExact("100.1 / 0.00000000000000001", "DECIMAL(19, 6) NOT NULL", "1.001E+19"); } @@ -9787,9 +9787,9 @@ public class SqlOperatorTest { f.checkScalar("safe_divide(cast(2 as bigint), cast(4 as bigint))", "0.5", "DOUBLE"); f.checkScalar("safe_divide(cast(15 as bigint), cast(1.2 as decimal(2,1)))", - "12.5", "DECIMAL(19, 0)"); + "12.5", "DECIMAL(19, 6)"); f.checkScalar("safe_divide(cast(4.5 as decimal(2,1)), cast(3 as bigint))", - "1.5", "DECIMAL(19, 18)"); + "1.5", "DECIMAL(19, 6)"); f.checkScalar("safe_divide(cast(4.5 as decimal(2,1)), " + "cast(1.5 as decimal(2, 1)))", "3", "DECIMAL(8, 6)"); f.checkScalar("safe_divide(cast(3 as double), cast(3 as bigint))",
