This is an automated email from the ASF dual-hosted git repository.
mbudiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git
The following commit(s) were added to refs/heads/main by this push:
new 5e14cb2eb0 [CALCITE-6427] Use a higher precision for DECIMAL intermediate results for some aggregate functions like STDDEV
5e14cb2eb0 is described below
commit 5e14cb2eb0a56cf81e84e3459e6fe64d41e9a5c3
Author: Mihai Budiu <[email protected]>
AuthorDate: Tue Jun 4 14:53:51 2024 -0700
[CALCITE-6427] Use a higher precision for DECIMAL intermediate results for some aggregate functions like STDDEV
Signed-off-by: Mihai Budiu <[email protected]>
---
.../calcite/sql2rel/StandardConvertletTable.java | 48 ++++++++++++++++------
1 file changed, 35 insertions(+), 13 deletions(-)
diff --git a/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java b/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java
index 82e1ceeaea..e04f2e7f3c 100644
--- a/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java
+++ b/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java
@@ -1591,12 +1591,13 @@ public class StandardConvertletTable extends
ReflectiveConvertletTable {
// / (count(x1, x2) - 1)
final SqlParserPos pos = SqlParserPos.ZERO;
final SqlLiteral nullLiteral = SqlLiteral.createNull(SqlParserPos.ZERO);
+ final RelDataType highPrecision =
AvgVarianceConvertlet.highPrecision(cx, varType);
final RexNode arg0Rex = cx.convertExpression(arg0Input);
final RexNode arg1Rex = cx.convertExpression(arg1Input);
- final SqlNode arg0 = getCastedSqlNode(arg0Input, varType, pos, arg0Rex);
- final SqlNode arg1 = getCastedSqlNode(arg1Input, varType, pos, arg1Rex);
+ final SqlNode arg0 = getCastedSqlNode(arg0Input, highPrecision, pos,
arg0Rex);
+ final SqlNode arg1 = getCastedSqlNode(arg1Input, highPrecision, pos,
arg1Rex);
final SqlNode argSquared = SqlStdOperatorTable.MULTIPLY.createCall(pos,
arg0, arg1);
final SqlNode sumArgSquared;
final SqlNode sum0;
@@ -1622,7 +1623,7 @@ public class StandardConvertletTable extends
ReflectiveConvertletTable {
final SqlNode sumSquared = SqlStdOperatorTable.MULTIPLY.createCall(pos,
sum0, sum1);
final SqlNode countCasted =
- getCastedSqlNode(count, varType, pos, cx.convertExpression(count));
+ getCastedSqlNode(count, highPrecision, pos,
cx.convertExpression(count));
final SqlNode avgSumSquared =
SqlStdOperatorTable.DIVIDE.createCall(pos, sumSquared, countCasted);
@@ -1637,7 +1638,7 @@ public class StandardConvertletTable extends
ReflectiveConvertletTable {
new SqlCase(SqlParserPos.ZERO, countCasted,
SqlNodeList.of(
SqlStdOperatorTable.EQUALS.createCall(pos, countCasted,
one)),
- SqlNodeList.of(getCastedSqlNode(nullLiteral, varType, pos,
null)),
+ SqlNodeList.of(getCastedSqlNode(nullLiteral, highPrecision,
pos, null)),
SqlStdOperatorTable.MINUS.createCall(pos, countCasted, one));
}
@@ -1706,6 +1707,26 @@ public class StandardConvertletTable extends
ReflectiveConvertletTable {
pos, sumCast, count);
}
+ /**
+ * Compute a higher precision version of a type.
+ *
+ * @return If type is a DECIMAL type, return a type with double the
precision and scale
+ * if possible. Otherwise, return the type unchanged. */
+ private static RelDataType highPrecision(final SqlRexContext cx, final
RelDataType type) {
+ if (type.getSqlTypeName() == SqlTypeName.DECIMAL) {
+ RelDataTypeFactory typeFactory = cx.getValidator().getTypeFactory();
+ return typeFactory.createSqlType(
+ type.getSqlTypeName(),
+ Math.min(
+ type.getPrecision() * 2,
+
typeFactory.getTypeSystem().getMaxPrecision(SqlTypeName.DECIMAL)),
+ Math.min(
+ type.getScale() * 2,
+ typeFactory.getTypeSystem().getMaxScale(SqlTypeName.DECIMAL)));
+ }
+ return type;
+ }
+
private static SqlNode expandVariance(
final SqlNode argInput,
final RelDataType varType,
@@ -1732,28 +1753,29 @@ public class StandardConvertletTable extends
ReflectiveConvertletTable {
// (sum(x * x) - sum(x) * sum(x) / count(x))
// / (count(x) - 1)
final SqlParserPos pos = SqlParserPos.ZERO;
+ final RelDataType highPrecision = highPrecision(cx, varType);
final SqlNode arg =
- getCastedSqlNode(argInput, varType, pos,
+ getCastedSqlNode(argInput, highPrecision, pos,
cx.convertExpression(argInput));
final SqlNode argSquared =
SqlStdOperatorTable.MULTIPLY.createCall(pos, arg, arg);
final SqlNode argSquaredCasted =
- getCastedSqlNode(argSquared, varType, pos,
+ getCastedSqlNode(argSquared, highPrecision, pos,
cx.convertExpression(argSquared));
final SqlNode sumArgSquared =
SqlStdOperatorTable.SUM.createCall(pos, argSquaredCasted);
final SqlNode sumArgSquaredCasted =
- getCastedSqlNode(sumArgSquared, varType, pos,
+ getCastedSqlNode(sumArgSquared, highPrecision, pos,
cx.convertExpression(sumArgSquared));
final SqlNode sum = SqlStdOperatorTable.SUM.createCall(pos, arg);
final SqlNode sumCasted =
- getCastedSqlNode(sum, varType, pos, cx.convertExpression(sum));
+ getCastedSqlNode(sum, highPrecision, pos, cx.convertExpression(sum));
final SqlNode sumSquared =
SqlStdOperatorTable.MULTIPLY.createCall(pos, sumCasted, sumCasted);
final SqlNode sumSquaredCasted =
- getCastedSqlNode(sumSquared, varType, pos,
+ getCastedSqlNode(sumSquared, highPrecision, pos,
cx.convertExpression(sumSquared));
final SqlNode count = SqlStdOperatorTable.COUNT.createCall(pos, arg);
final SqlNode countCasted =
@@ -1762,13 +1784,13 @@ public class StandardConvertletTable extends
ReflectiveConvertletTable {
SqlStdOperatorTable.DIVIDE.createCall(pos, sumSquaredCasted,
countCasted);
final SqlNode avgSumSquaredCasted =
- getCastedSqlNode(avgSumSquared, varType, pos,
+ getCastedSqlNode(avgSumSquared, highPrecision, pos,
cx.convertExpression(avgSumSquared));
final SqlNode diff =
SqlStdOperatorTable.MINUS.createCall(pos, sumArgSquaredCasted,
avgSumSquaredCasted);
final SqlNode diffCasted =
- getCastedSqlNode(diff, varType, pos, cx.convertExpression(diff));
+ getCastedSqlNode(diff, highPrecision, pos,
cx.convertExpression(diff));
final SqlNode denominator;
if (biased) {
denominator = countCasted;
@@ -1779,13 +1801,13 @@ public class StandardConvertletTable extends
ReflectiveConvertletTable {
new SqlCase(SqlParserPos.ZERO, count,
SqlNodeList.of(
SqlStdOperatorTable.EQUALS.createCall(pos, count, one)),
- SqlNodeList.of(getCastedSqlNode(nullLiteral, varType, pos,
null)),
+ SqlNodeList.of(getCastedSqlNode(nullLiteral, highPrecision,
pos, null)),
SqlStdOperatorTable.MINUS.createCall(pos, count, one));
}
final SqlNode div =
SqlStdOperatorTable.DIVIDE.createCall(pos, diffCasted, denominator);
final SqlNode divCasted =
- getCastedSqlNode(div, varType, pos, cx.convertExpression(div));
+ getCastedSqlNode(div, highPrecision, pos, cx.convertExpression(div));
SqlNode result = div;
if (sqrt) {