This is an automated email from the ASF dual-hosted git repository.

mbudiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new 5e14cb2eb0 [CALCITE-6427] Use a higher precision for DECIMAL intermediate results for some aggregate functions like STDDEV
5e14cb2eb0 is described below

commit 5e14cb2eb0a56cf81e84e3459e6fe64d41e9a5c3
Author: Mihai Budiu <[email protected]>
AuthorDate: Tue Jun 4 14:53:51 2024 -0700

    [CALCITE-6427] Use a higher precision for DECIMAL intermediate results for some aggregate functions like STDDEV
    
    Signed-off-by: Mihai Budiu <[email protected]>
---
 .../calcite/sql2rel/StandardConvertletTable.java   | 48 ++++++++++++++++------
 1 file changed, 35 insertions(+), 13 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java b/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java
index 82e1ceeaea..e04f2e7f3c 100644
--- a/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java
+++ b/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java
@@ -1591,12 +1591,13 @@ public class StandardConvertletTable extends ReflectiveConvertletTable {
       //     / (count(x1, x2) - 1)
       final SqlParserPos pos = SqlParserPos.ZERO;
       final SqlLiteral nullLiteral = SqlLiteral.createNull(SqlParserPos.ZERO);
+      final RelDataType highPrecision = 
AvgVarianceConvertlet.highPrecision(cx, varType);
 
       final RexNode arg0Rex = cx.convertExpression(arg0Input);
       final RexNode arg1Rex = cx.convertExpression(arg1Input);
 
-      final SqlNode arg0 = getCastedSqlNode(arg0Input, varType, pos, arg0Rex);
-      final SqlNode arg1 = getCastedSqlNode(arg1Input, varType, pos, arg1Rex);
+      final SqlNode arg0 = getCastedSqlNode(arg0Input, highPrecision, pos, 
arg0Rex);
+      final SqlNode arg1 = getCastedSqlNode(arg1Input, highPrecision, pos, 
arg1Rex);
       final SqlNode argSquared = SqlStdOperatorTable.MULTIPLY.createCall(pos, 
arg0, arg1);
       final SqlNode sumArgSquared;
       final SqlNode sum0;
@@ -1622,7 +1623,7 @@ public class StandardConvertletTable extends ReflectiveConvertletTable {
 
       final SqlNode sumSquared = SqlStdOperatorTable.MULTIPLY.createCall(pos, 
sum0, sum1);
       final SqlNode countCasted =
-          getCastedSqlNode(count, varType, pos, cx.convertExpression(count));
+          getCastedSqlNode(count, highPrecision, pos, 
cx.convertExpression(count));
 
       final SqlNode avgSumSquared =
           SqlStdOperatorTable.DIVIDE.createCall(pos, sumSquared, countCasted);
@@ -1637,7 +1638,7 @@ public class StandardConvertletTable extends ReflectiveConvertletTable {
             new SqlCase(SqlParserPos.ZERO, countCasted,
                 SqlNodeList.of(
                     SqlStdOperatorTable.EQUALS.createCall(pos, countCasted, 
one)),
-                SqlNodeList.of(getCastedSqlNode(nullLiteral, varType, pos, 
null)),
+                SqlNodeList.of(getCastedSqlNode(nullLiteral, highPrecision, 
pos, null)),
                 SqlStdOperatorTable.MINUS.createCall(pos, countCasted, one));
       }
 
@@ -1706,6 +1707,26 @@ public class StandardConvertletTable extends ReflectiveConvertletTable {
           pos, sumCast, count);
     }
 
+    /**
+     * Computes a higher precision version of a type.
+     *
+     * @return If type is a DECIMAL type, a DECIMAL with up to double the precision and scale,
+     * keeping at least the original number of integer digits; otherwise the type unchanged. */
+    private static RelDataType highPrecision(final SqlRexContext cx, final RelDataType type) {
+      if (type.getSqlTypeName() != SqlTypeName.DECIMAL) {
+        return type;
+      }
+      final RelDataTypeFactory typeFactory = cx.getValidator().getTypeFactory();
+      final int maxPrecision = typeFactory.getTypeSystem().getMaxPrecision(SqlTypeName.DECIMAL);
+      final int maxScale = typeFactory.getTypeSystem().getMaxScale(SqlTypeName.DECIMAL);
+      final int precision = Math.min(type.getPrecision() * 2, maxPrecision);
+      // If precision is capped, a doubled scale must not steal integer digits (casts would overflow).
+      final int integerDigits = type.getPrecision() - type.getScale();
+      final int scale =
+          Math.min(Math.min(type.getScale() * 2, maxScale), precision - integerDigits);
+      return typeFactory.createSqlType(SqlTypeName.DECIMAL, precision, scale);
+    }
+
+
     private static SqlNode expandVariance(
         final SqlNode argInput,
         final RelDataType varType,
@@ -1732,28 +1753,29 @@ public class StandardConvertletTable extends ReflectiveConvertletTable {
       //     (sum(x * x) - sum(x) * sum(x) / count(x))
       //     / (count(x) - 1)
       final SqlParserPos pos = SqlParserPos.ZERO;
+      final RelDataType highPrecision = highPrecision(cx, varType);
 
       final SqlNode arg =
-          getCastedSqlNode(argInput, varType, pos,
+          getCastedSqlNode(argInput, highPrecision, pos,
               cx.convertExpression(argInput));
 
       final SqlNode argSquared =
           SqlStdOperatorTable.MULTIPLY.createCall(pos, arg, arg);
       final SqlNode argSquaredCasted =
-          getCastedSqlNode(argSquared, varType, pos,
+          getCastedSqlNode(argSquared, highPrecision, pos,
               cx.convertExpression(argSquared));
       final SqlNode sumArgSquared =
           SqlStdOperatorTable.SUM.createCall(pos, argSquaredCasted);
       final SqlNode sumArgSquaredCasted =
-          getCastedSqlNode(sumArgSquared, varType, pos,
+          getCastedSqlNode(sumArgSquared, highPrecision, pos,
               cx.convertExpression(sumArgSquared));
       final SqlNode sum = SqlStdOperatorTable.SUM.createCall(pos, arg);
       final SqlNode sumCasted =
-          getCastedSqlNode(sum, varType, pos, cx.convertExpression(sum));
+          getCastedSqlNode(sum, highPrecision, pos, cx.convertExpression(sum));
       final SqlNode sumSquared =
           SqlStdOperatorTable.MULTIPLY.createCall(pos, sumCasted, sumCasted);
       final SqlNode sumSquaredCasted =
-          getCastedSqlNode(sumSquared, varType, pos,
+          getCastedSqlNode(sumSquared, highPrecision, pos,
               cx.convertExpression(sumSquared));
       final SqlNode count = SqlStdOperatorTable.COUNT.createCall(pos, arg);
       final SqlNode countCasted =
@@ -1762,13 +1784,13 @@ public class StandardConvertletTable extends ReflectiveConvertletTable {
           SqlStdOperatorTable.DIVIDE.createCall(pos, sumSquaredCasted,
               countCasted);
       final SqlNode avgSumSquaredCasted =
-          getCastedSqlNode(avgSumSquared, varType, pos,
+          getCastedSqlNode(avgSumSquared, highPrecision, pos,
               cx.convertExpression(avgSumSquared));
       final SqlNode diff =
           SqlStdOperatorTable.MINUS.createCall(pos, sumArgSquaredCasted,
               avgSumSquaredCasted);
       final SqlNode diffCasted =
-          getCastedSqlNode(diff, varType, pos, cx.convertExpression(diff));
+          getCastedSqlNode(diff, highPrecision, pos, 
cx.convertExpression(diff));
       final SqlNode denominator;
       if (biased) {
         denominator = countCasted;
@@ -1779,13 +1801,13 @@ public class StandardConvertletTable extends ReflectiveConvertletTable {
             new SqlCase(SqlParserPos.ZERO, count,
                 SqlNodeList.of(
                     SqlStdOperatorTable.EQUALS.createCall(pos, count, one)),
-                SqlNodeList.of(getCastedSqlNode(nullLiteral, varType, pos, 
null)),
+                SqlNodeList.of(getCastedSqlNode(nullLiteral, highPrecision, 
pos, null)),
                 SqlStdOperatorTable.MINUS.createCall(pos, count, one));
       }
       final SqlNode div =
           SqlStdOperatorTable.DIVIDE.createCall(pos, diffCasted, denominator);
       final SqlNode divCasted =
-          getCastedSqlNode(div, varType, pos, cx.convertExpression(div));
+          getCastedSqlNode(div, highPrecision, pos, cx.convertExpression(div));
 
       SqlNode result = div;
       if (sqrt) {

Reply via email to