[CALCITE-2402] Implement regr functions: COVAR_POP, COVAR_SAMP, REGR_COUNT, REGR_SXX, REGR_SYY
Use filters in case of AggregateReduceFunctionsRule expansions. Close apache/calcite#779 Project: http://git-wip-us.apache.org/repos/asf/calcite/repo Commit: http://git-wip-us.apache.org/repos/asf/calcite/commit/ca858dd7 Tree: http://git-wip-us.apache.org/repos/asf/calcite/tree/ca858dd7 Diff: http://git-wip-us.apache.org/repos/asf/calcite/diff/ca858dd7 Branch: refs/heads/master Commit: ca858dd725dea6bf9b4a9059cf1c3ba98bd82f26 Parents: 5574873 Author: snuyanzin <[email protected]> Authored: Fri Jun 29 11:41:24 2018 +0300 Committer: Julian Hyde <[email protected]> Committed: Sun Aug 12 18:04:44 2018 -0700 ---------------------------------------------------------------------- core/src/main/codegen/templates/Parser.jj | 1 + .../calcite/adapter/enumerable/RexImpTable.java | 3 + .../rel/rules/AggregateReduceFunctionsRule.java | 259 ++++++++++++++++++- .../java/org/apache/calcite/sql/SqlKind.java | 23 +- .../calcite/sql/fun/SqlCountAggFunction.java | 10 +- .../calcite/sql/fun/SqlCovarAggFunction.java | 8 +- .../sql/fun/SqlRegrCountAggFunction.java | 37 +++ .../calcite/sql/fun/SqlStdOperatorTable.java | 6 + .../apache/calcite/sql/type/ReturnTypes.java | 2 +- .../sql2rel/StandardConvertletTable.java | 250 ++++++++++++++---- .../apache/calcite/sql/test/SqlAdvisorTest.java | 1 + core/src/test/resources/sql/agg.iq | 97 +++++++ core/src/test/resources/sql/winagg.iq | 133 ++++++++++ site/_docs/reference.md | 2 +- 14 files changed, 761 insertions(+), 71 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/main/codegen/templates/Parser.jj ---------------------------------------------------------------------- diff --git a/core/src/main/codegen/templates/Parser.jj b/core/src/main/codegen/templates/Parser.jj index be05d9c..0dc23eb 100644 --- a/core/src/main/codegen/templates/Parser.jj +++ b/core/src/main/codegen/templates/Parser.jj @@ -5171,6 +5171,7 @@ SqlIdentifier ReservedFunctionName() : | <PERCENT_RANK> | <POWER> | <RANK> + | <REGR_COUNT> | <REGR_SXX> | <REGR_SYY> | <ROW_NUMBER> http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java index 5ba5959..80d5541 100644 --- a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java +++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java @@ -191,6 +191,7 @@ import static org.apache.calcite.sql.fun.SqlStdOperatorTable.RADIANS; import static org.apache.calcite.sql.fun.SqlStdOperatorTable.RAND; import static org.apache.calcite.sql.fun.SqlStdOperatorTable.RAND_INTEGER; import static org.apache.calcite.sql.fun.SqlStdOperatorTable.RANK; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.REGR_COUNT; import static org.apache.calcite.sql.fun.SqlStdOperatorTable.REINTERPRET; import static org.apache.calcite.sql.fun.SqlStdOperatorTable.REPLACE; import static org.apache.calcite.sql.fun.SqlStdOperatorTable.ROUND; @@ -438,6 +439,7 @@ public class RexImpTable { map.put(LOCALTIMESTAMP, systemFunctionImplementor); aggMap.put(COUNT, constructorSupplier(CountImplementor.class)); + aggMap.put(REGR_COUNT, constructorSupplier(CountImplementor.class)); aggMap.put(SUM0, constructorSupplier(SumImplementor.class)); aggMap.put(SUM, constructorSupplier(SumImplementor.class)); Supplier<MinMaxImplementor> minMax = @@ -464,6 +466,7 @@ public class RexImpTable { winAggMap.put(LAG, constructorSupplier(LagImplementor.class)); winAggMap.put(NTILE, constructorSupplier(NtileImplementor.class)); winAggMap.put(COUNT, constructorSupplier(CountWinImplementor.class)); + winAggMap.put(REGR_COUNT, constructorSupplier(CountWinImplementor.class)); } private <T> Supplier<T> constructorSupplier(Class<T> klass) { http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java index 8bdd6c1..68f6b16 100644 --- a/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java +++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java @@ -72,6 +72,17 @@ import java.util.Map; * * <li>VAR_SAMP(x) → (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) * / CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END + * + * <li>COVAR_POP(x, y) → (SUM(x * y) - SUM(x, y) * SUM(y, x) + * / REGR_COUNT(x, y)) / REGR_COUNT(x, y) + * + * <li>COVAR_SAMP(x, y) → (SUM(x * y) - SUM(x, y) * SUM(y, x) / REGR_COUNT(x, y)) + * / CASE REGR_COUNT(x, y) WHEN 1 THEN NULL ELSE REGR_COUNT(x, y) - 1 END + * + * <li>REGR_SXX(x, y) → REGR_COUNT(x, y) * VAR_POP(y) + * + * <li>REGR_SYY(x, y) → REGR_COUNT(x, y) * VAR_POP(x) + * * </ul> * * <p>Since many of these rewrites introduce multiple occurrences of simpler @@ -127,7 +138,8 @@ public class AggregateReduceFunctionsRule extends RelOptRule { * Returns whether the aggregate call is a reducible function */ private boolean isReducible(final SqlKind kind) { - if (SqlKind.AVG_AGG_FUNCTIONS.contains(kind)) { + if (SqlKind.AVG_AGG_FUNCTIONS.contains(kind) + || SqlKind.COVAR_AVG_AGG_FUNCTIONS.contains(kind)) { return true; } switch (kind) { @@ -201,6 +213,8 @@ public class AggregateReduceFunctionsRule extends RelOptRule { List<RexNode> inputExprs) { final SqlKind kind = oldCall.getAggregation().getKind(); if (isReducible(kind)) { + final Integer y; + final Integer x; switch (kind) { case SUM: // replace original SUM(x) with @@ -209,6 +223,37 @@ public class AggregateReduceFunctionsRule extends RelOptRule { case AVG: // replace original AVG(x) with SUM(x) / COUNT(x) return reduceAvg(oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs); + case COVAR_POP: + // replace original COVAR_POP(x, y) with + // (SUM(x * y) - SUM(y) * SUM(y) / COUNT(x)) + // / COUNT(x)) + return reduceCovariance(oldAggRel, oldCall, true, newCalls, + aggCallMapping, inputExprs); + case COVAR_SAMP: + // replace original COVAR_SAMP(x, y) with + // SQRT( + // (SUM(x * y) - SUM(x) * SUM(y) / COUNT(x)) + // / CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END) + return reduceCovariance(oldAggRel, oldCall, false, newCalls, + aggCallMapping, inputExprs); + case REGR_SXX: + // replace original REGR_SXX(x, y) with + // REGR_COUNT(x, y) * VAR_POP(y) + assert oldCall.getArgList().size() == 2 : oldCall.getArgList(); + x = oldCall.getArgList().get(0); + y = oldCall.getArgList().get(1); + //noinspection SuspiciousNameCombination + return reduceRegrSzz(oldAggRel, oldCall, newCalls, aggCallMapping, + inputExprs, y, y, x); + case REGR_SYY: + // replace original REGR_SYY(x, y) with + // REGR_COUNT(x, y) * VAR_POP(x) + assert oldCall.getArgList().size() == 2 : oldCall.getArgList(); + x = oldCall.getArgList().get(0); + y = oldCall.getArgList().get(1); + //noinspection SuspiciousNameCombination + return reduceRegrSzz(oldAggRel, oldCall, newCalls, aggCallMapping, + inputExprs, x, x, y); case STDDEV_POP: // replace original STDDEV_POP(x) with // SQRT( @@ -260,16 +305,17 @@ public class AggregateReduceFunctionsRule extends RelOptRule { RelDataType operandType, Aggregate oldAggRel, AggregateCall oldCall, - int argOrdinal) { + int argOrdinal, + int filter) { final Aggregate.AggCallBinding binding = new Aggregate.AggCallBinding(typeFactory, aggFunction, ImmutableList.of(operandType), oldAggRel.getGroupCount(), - oldCall.filterArg >= 0); + filter >= 0); return AggregateCall.create(aggFunction, oldCall.isDistinct(), oldCall.isApproximate(), ImmutableIntList.of(argOrdinal), - oldCall.filterArg, + filter, aggFunction.inferReturnType(binding), null); } @@ -346,6 +392,7 @@ public class AggregateReduceFunctionsRule extends RelOptRule { getFieldType( oldAggRel.getInput(), arg); + final AggregateCall sumZeroCall = AggregateCall.create(SqlStdOperatorTable.SUM0, oldCall.isDistinct(), oldCall.isApproximate(), oldCall.getArgList(), oldCall.filterArg, @@ -424,7 +471,6 @@ public class AggregateReduceFunctionsRule extends RelOptRule { final RexNode argRef = rexBuilder.ensureType(oldCallType, inputExprs.get(argOrdinal), true); - final int argRefOrdinal = lookupOrAdd(inputExprs, argRef); final RexNode argSquared = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, argRef, argRef); @@ -432,7 +478,7 @@ public class AggregateReduceFunctionsRule extends RelOptRule { final AggregateCall sumArgSquaredAggCall = createAggregateCallWithBinding(typeFactory, SqlStdOperatorTable.SUM, - argSquared.getType(), oldAggRel, oldCall, argSquaredOrdinal); + argSquared.getType(), oldAggRel, oldCall, argSquaredOrdinal, -1); final RexNode sumArgSquared = rexBuilder.addAggCall(sumArgSquaredAggCall, @@ -530,6 +576,207 @@ public class AggregateReduceFunctionsRule extends RelOptRule { oldCall.getType(), result); } + private RexNode getSumAggregatedRexNode(Aggregate oldAggRel, + AggregateCall oldCall, + List<AggregateCall> newCalls, + Map<AggregateCall, RexNode> aggCallMapping, + RexBuilder rexBuilder, + int argOrdinal, + int filterArg) { + final AggregateCall aggregateCall = + AggregateCall.create(SqlStdOperatorTable.SUM, + oldCall.isDistinct(), + oldCall.isApproximate(), + ImmutableIntList.of(argOrdinal), + filterArg, + oldAggRel.getGroupCount(), + oldAggRel.getInput(), + null, + null); + return rexBuilder.addAggCall(aggregateCall, + oldAggRel.getGroupCount(), + oldAggRel.indicator, + newCalls, + aggCallMapping, + ImmutableList.of(aggregateCall.getType())); + } + + private RexNode getSumAggregatedRexNodeWithBinding(Aggregate oldAggRel, + AggregateCall oldCall, + List<AggregateCall> newCalls, + Map<AggregateCall, RexNode> aggCallMapping, + RelDataType operandType, + int argOrdinal, + int filter) { + RelOptCluster cluster = oldAggRel.getCluster(); + final AggregateCall sumArgSquaredAggCall = + createAggregateCallWithBinding(cluster.getTypeFactory(), + SqlStdOperatorTable.SUM, operandType, oldAggRel, oldCall, argOrdinal, filter); + + return cluster.getRexBuilder().addAggCall(sumArgSquaredAggCall, + oldAggRel.getGroupCount(), + oldAggRel.indicator, + newCalls, + aggCallMapping, + ImmutableList.of(sumArgSquaredAggCall.getType())); + } + + private RexNode getRegrCountRexNode(Aggregate oldAggRel, + AggregateCall oldCall, + List<AggregateCall> newCalls, + Map<AggregateCall, RexNode> aggCallMapping, + ImmutableIntList argOrdinals, + ImmutableList<RelDataType> operandTypes, + int filterArg) { + final AggregateCall countArgAggCall = + AggregateCall.create(SqlStdOperatorTable.REGR_COUNT, + oldCall.isDistinct(), + oldCall.isApproximate(), + argOrdinals, + filterArg, + oldAggRel.getGroupCount(), + oldAggRel, + null, + null); + + return oldAggRel.getCluster().getRexBuilder().addAggCall(countArgAggCall, + oldAggRel.getGroupCount(), + oldAggRel.indicator, + newCalls, + aggCallMapping, + operandTypes); + } + + private RexNode reduceRegrSzz( + Aggregate oldAggRel, + AggregateCall oldCall, + List<AggregateCall> newCalls, + Map<AggregateCall, RexNode> aggCallMapping, + List<RexNode> inputExprs, + int xIndex, + int yIndex, + int nullFilterIndex) { + // regr_sxx(x, y) ==> + // sum(y * y, x) - sum(y, x) * sum(y, x) / regr_count(x, y) + // + + final RelOptCluster cluster = oldAggRel.getCluster(); + final RexBuilder rexBuilder = cluster.getRexBuilder(); + final RelDataTypeFactory typeFactory = cluster.getTypeFactory(); + final RelDataType argXType = getFieldType(oldAggRel.getInput(), xIndex); + final RelDataType argYType = + xIndex == yIndex ? argXType : getFieldType(oldAggRel.getInput(), yIndex); + final RelDataType nullFilterIndexType = + nullFilterIndex == yIndex ? argYType : getFieldType(oldAggRel.getInput(), yIndex); + + final RelDataType oldCallType = + typeFactory.createTypeWithNullability(oldCall.getType(), + argXType.isNullable() || argYType.isNullable() || nullFilterIndexType.isNullable()); + + final RexNode argX = + rexBuilder.ensureType(oldCallType, inputExprs.get(xIndex), true); + final RexNode argY = + rexBuilder.ensureType(oldCallType, inputExprs.get(yIndex), true); + final RexNode argNullFilter = + rexBuilder.ensureType(oldCallType, inputExprs.get(nullFilterIndex), true); + + final RexNode argXArgY = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, argX, argY); + final int argSquaredOrdinal = lookupOrAdd(inputExprs, argXArgY); + + final RexNode argXAndYNotNullFilter = rexBuilder.makeCall(SqlStdOperatorTable.AND, + rexBuilder.makeCall(SqlStdOperatorTable.AND, + rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, argX), + rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, argY)), + rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, argNullFilter)); + final int argXAndYNotNullFilterOrdinal = lookupOrAdd(inputExprs, argXAndYNotNullFilter); + final RexNode sumXY = getSumAggregatedRexNodeWithBinding( + oldAggRel, oldCall, newCalls, aggCallMapping, argXArgY.getType(), + argSquaredOrdinal, argXAndYNotNullFilterOrdinal); + final RexNode sumXYCast = rexBuilder.ensureType(oldCallType, sumXY, true); + + final RexNode sumX = getSumAggregatedRexNode(oldAggRel, oldCall, + newCalls, aggCallMapping, rexBuilder, xIndex, argXAndYNotNullFilterOrdinal); + final RexNode sumY = xIndex == yIndex + ? sumX + : getSumAggregatedRexNode(oldAggRel, oldCall, newCalls, + aggCallMapping, rexBuilder, yIndex, argXAndYNotNullFilterOrdinal); + + final RexNode sumXSumY = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumX, sumY); + + final RexNode countArg = getRegrCountRexNode(oldAggRel, oldCall, newCalls, aggCallMapping, + ImmutableIntList.of(xIndex), ImmutableList.of(argXType), argXAndYNotNullFilterOrdinal); + + RexLiteral zero = rexBuilder.makeExactLiteral(BigDecimal.ZERO); + RexNode nul = rexBuilder.constantNull(); + final RexNode avgSumXSumY = rexBuilder.makeCall(SqlStdOperatorTable.CASE, + rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countArg, zero), nul, + rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, sumXSumY, countArg)); + final RexNode avgSumXSumYCast = rexBuilder.ensureType(oldCallType, avgSumXSumY, true); + final RexNode result = + rexBuilder.makeCall(SqlStdOperatorTable.MINUS, sumXYCast, avgSumXSumYCast); + return rexBuilder.makeCast(oldCall.getType(), result); + } + + private RexNode reduceCovariance( + Aggregate oldAggRel, + AggregateCall oldCall, + boolean biased, + List<AggregateCall> newCalls, + Map<AggregateCall, RexNode> aggCallMapping, + List<RexNode> inputExprs) { + // covar_pop(x, y) ==> + // (sum(x * y) - sum(x) * sum(y) / regr_count(x, y)) + // / regr_count(x, y) + // + // covar_samp(x, y) ==> + // (sum(x * y) - sum(x) * sum(y) / regr_count(x, y)) + // / regr_count(count(x, y) - 1, 0) + final RelOptCluster cluster = oldAggRel.getCluster(); + final RexBuilder rexBuilder = cluster.getRexBuilder(); + final RelDataTypeFactory typeFactory = cluster.getTypeFactory(); + assert oldCall.getArgList().size() == 2 : oldCall.getArgList(); + final int argXOrdinal = oldCall.getArgList().get(0); + final int argYOrdinal = oldCall.getArgList().get(1); + final RelDataType argXOrdinalType = getFieldType(oldAggRel.getInput(), argXOrdinal); + final RelDataType argYOrdinalType = getFieldType(oldAggRel.getInput(), argYOrdinal); + final RelDataType oldCallType = typeFactory.createTypeWithNullability(oldCall.getType(), + argXOrdinalType.isNullable() || argYOrdinalType.isNullable()); + final RexNode argX = rexBuilder.ensureType(oldCallType, inputExprs.get(argXOrdinal), true); + final RexNode argY = rexBuilder.ensureType(oldCallType, inputExprs.get(argYOrdinal), true); + final RexNode argXAndYNotNullFilter = rexBuilder.makeCall(SqlStdOperatorTable.AND, + rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, argX), + rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, argY)); + final int argXAndYNotNullFilterOrdinal = lookupOrAdd(inputExprs, argXAndYNotNullFilter); + final RexNode argXY = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, argX, argY); + final int argXYOrdinal = lookupOrAdd(inputExprs, argXY); + final RexNode sumXY = getSumAggregatedRexNodeWithBinding(oldAggRel, oldCall, newCalls, + aggCallMapping, argXY.getType(), argXYOrdinal, argXAndYNotNullFilterOrdinal); + final RexNode sumX = getSumAggregatedRexNode(oldAggRel, oldCall, newCalls, + aggCallMapping, rexBuilder, argXOrdinal, argXAndYNotNullFilterOrdinal); + final RexNode sumY = getSumAggregatedRexNode(oldAggRel, oldCall, newCalls, + aggCallMapping, rexBuilder, argYOrdinal, argXAndYNotNullFilterOrdinal); + final RexNode sumXSumY = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumX, sumY); + final RexNode countArg = getRegrCountRexNode(oldAggRel, oldCall, newCalls, aggCallMapping, + ImmutableIntList.of(argXOrdinal, argYOrdinal), + ImmutableList.of(argXOrdinalType, argYOrdinalType), + argXAndYNotNullFilterOrdinal); + final RexNode avgSumSquaredArg = + rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, sumXSumY, countArg); + final RexNode diff = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, sumXY, avgSumSquaredArg); + final RexNode denominator; + if (biased) { + denominator = countArg; + } else { + final RexLiteral one = rexBuilder.makeExactLiteral(BigDecimal.ONE); + final RexNode nul = rexBuilder.makeCast(countArg.getType(), rexBuilder.constantNull()); + final RexNode countMinusOne = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, countArg, one); + final RexNode countEqOne = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countArg, one); + denominator = rexBuilder.makeCall(SqlStdOperatorTable.CASE, countEqOne, nul, countMinusOne); + } + final RexNode result = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, diff, denominator); + return rexBuilder.makeCast(oldCall.getType(), result); + } + /** * Finds the ordinal of an element in a list, or adds it. * http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/main/java/org/apache/calcite/sql/SqlKind.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/calcite/sql/SqlKind.java b/core/src/main/java/org/apache/calcite/sql/SqlKind.java index 54b390b..cbf201f 100644 --- a/core/src/main/java/org/apache/calcite/sql/SqlKind.java +++ b/core/src/main/java/org/apache/calcite/sql/SqlKind.java @@ -810,14 +810,11 @@ public enum SqlKind { /** The {@code GROUP_ID()} function. */ GROUP_ID, - /** - * the internal permute function in match_recognize cluse - */ + /** The internal "permute" function in a MATCH_RECOGNIZE clause. */ PATTERN_PERMUTE, - /** - * the special patterns to exclude enclosing pattern from output in match_recognize clause - */ + /** The special patterns to exclude enclosing pattern from output in a + * MATCH_RECOGNIZE clause. */ PATTERN_EXCLUDED, // Aggregate functions @@ -858,6 +855,9 @@ public enum SqlKind { /** The {@code COVAR_SAMP} aggregate function. */ COVAR_SAMP, + /** The {@code REGR_COUNT} aggregate function. */ + REGR_COUNT, + /** The {@code REGR_SXX} aggregate function. */ REGR_SXX, @@ -1064,7 +1064,7 @@ public enum SqlKind { */ public static final EnumSet<SqlKind> AGGREGATE = EnumSet.of(COUNT, SUM, SUM0, MIN, MAX, LEAD, LAG, FIRST_VALUE, - LAST_VALUE, COVAR_POP, COVAR_SAMP, REGR_SXX, REGR_SYY, + LAST_VALUE, COVAR_POP, COVAR_SAMP, REGR_COUNT, REGR_SXX, REGR_SYY, AVG, STDDEV_POP, STDDEV_SAMP, VAR_POP, VAR_SAMP, NTILE, COLLECT, FUSION, SINGLE_VALUE, ROW_NUMBER, RANK, PERCENT_RANK, DENSE_RANK, CUME_DIST); @@ -1180,6 +1180,15 @@ public enum SqlKind { EnumSet.of(AVG, STDDEV_POP, STDDEV_SAMP, VAR_POP, VAR_SAMP); /** + * Category of SqlCovarAggFunction. + * + * <p>Consists of {@link #COVAR_POP}, {@link #COVAR_SAMP}, {@link #REGR_SXX}, + * {@link #REGR_SYY}. + */ + public static final Set<SqlKind> COVAR_AVG_AGG_FUNCTIONS = + EnumSet.of(COVAR_POP, COVAR_SAMP, REGR_COUNT, REGR_SXX, REGR_SYY); + + /** * Category of comparison operators. * * <p>Consists of: http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/main/java/org/apache/calcite/sql/fun/SqlCountAggFunction.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlCountAggFunction.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlCountAggFunction.java index e053294..db54102 100644 --- a/core/src/main/java/org/apache/calcite/sql/fun/SqlCountAggFunction.java +++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlCountAggFunction.java @@ -26,6 +26,7 @@ import org.apache.calcite.sql.SqlSplittableAggFunction; import org.apache.calcite.sql.SqlSyntax; import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.sql.type.SqlOperandTypeChecker; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.validate.SqlValidator; import org.apache.calcite.sql.validate.SqlValidatorScope; @@ -45,11 +46,12 @@ public class SqlCountAggFunction extends SqlAggFunction { //~ Constructors ----------------------------------------------------------- public SqlCountAggFunction(String name) { + this(name, SqlValidator.STRICT ? OperandTypes.ANY : OperandTypes.ONE_OR_MORE); + } + + public SqlCountAggFunction(String name, SqlOperandTypeChecker sqlOperandTypeChecker) { super(name, null, SqlKind.COUNT, ReturnTypes.BIGINT, null, - SqlValidator.STRICT - ? OperandTypes.ANY - : OperandTypes.ONE_OR_MORE, - SqlFunctionCategory.NUMERIC, false, false); + sqlOperandTypeChecker, SqlFunctionCategory.NUMERIC, false, false); } //~ Methods ---------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/main/java/org/apache/calcite/sql/fun/SqlCovarAggFunction.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlCovarAggFunction.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlCovarAggFunction.java index 8c62290..8591959 100644 --- a/core/src/main/java/org/apache/calcite/sql/fun/SqlCovarAggFunction.java +++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlCovarAggFunction.java @@ -43,16 +43,14 @@ public class SqlCovarAggFunction extends SqlAggFunction { super(kind.name(), null, kind, - ReturnTypes.COVAR_FUNCTION, + kind == SqlKind.REGR_COUNT ? ReturnTypes.BIGINT : ReturnTypes.COVAR_REGR_FUNCTION, null, OperandTypes.NUMERIC_NUMERIC, SqlFunctionCategory.NUMERIC, false, false); - Preconditions.checkArgument(kind == SqlKind.COVAR_POP - || kind == SqlKind.COVAR_SAMP - || kind == SqlKind.REGR_SXX - || kind == SqlKind.REGR_SYY); + Preconditions.checkArgument(SqlKind.COVAR_AVG_AGG_FUNCTIONS.contains(kind), + "unsupported sql kind: " + kind); } @Deprecated // to be removed before 2.0 http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/main/java/org/apache/calcite/sql/fun/SqlRegrCountAggFunction.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlRegrCountAggFunction.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlRegrCountAggFunction.java new file mode 100644 index 0000000..4408272 --- /dev/null +++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlRegrCountAggFunction.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.calcite.sql.fun; + +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.type.OperandTypes; + +import com.google.common.base.Preconditions; + +/** + * Definition of the SQL <code>REGR_COUNT</code> aggregation function. + * + * <p><code>REGR_COUNT</code> is an aggregator which returns the number of rows which + * have gone into it and both arguments are not <code>null</code>. + */ +public class SqlRegrCountAggFunction extends SqlCountAggFunction { + public SqlRegrCountAggFunction(SqlKind kind) { + super("REGR_COUNT", OperandTypes.NUMERIC_NUMERIC); + Preconditions.checkArgument(SqlKind.REGR_COUNT == kind, "unsupported sql kind: " + kind); + } +} + +// End SqlRegrCountAggFunction.java http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java index 0064cba..ea17ec8 100644 --- a/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java +++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java @@ -915,6 +915,12 @@ public class SqlStdOperatorTable extends ReflectiveSqlOperatorTable { new SqlAvgAggFunction(SqlKind.STDDEV_POP); /** + * <code>REGR_COUNT</code> aggregate function. + */ + public static final SqlAggFunction REGR_COUNT = + new SqlRegrCountAggFunction(SqlKind.REGR_COUNT); + + /** * <code>REGR_SXX</code> aggregate function. */ public static final SqlAggFunction REGR_SXX = http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java b/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java index 8b07b83..fc0022d 100644 --- a/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java +++ b/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java @@ -766,7 +766,7 @@ public abstract class ReturnTypes { } }; - public static final SqlReturnTypeInference COVAR_FUNCTION = opBinding -> { + public static final SqlReturnTypeInference COVAR_REGR_FUNCTION = opBinding -> { final RelDataTypeFactory typeFactory = opBinding.getTypeFactory(); final RelDataType relDataType = typeFactory.getTypeSystem().deriveCovarType(typeFactory, http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java b/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java index e9b7cf6..987e821 100644 --- a/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java +++ b/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java @@ -67,6 +67,7 @@ import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.type.SqlTypeUtil; import org.apache.calcite.sql.validate.SqlValidator; +import org.apache.calcite.sql.validate.SqlValidatorImpl; import org.apache.calcite.util.Pair; import org.apache.calcite.util.Util; @@ -77,6 +78,7 @@ import java.math.BigDecimal; import java.math.RoundingMode; import java.util.ArrayList; import java.util.List; +import java.util.Objects; /** * Standard implementation of {@link SqlRexConvertletTable}. @@ -237,6 +239,14 @@ public class StandardConvertletTable extends ReflectiveConvertletTable { new AvgVarianceConvertlet(SqlKind.VAR_SAMP)); registerOp(SqlStdOperatorTable.VARIANCE, new AvgVarianceConvertlet(SqlKind.VAR_SAMP)); + registerOp(SqlStdOperatorTable.COVAR_POP, + new RegrCovarianceConvertlet(SqlKind.COVAR_POP)); + registerOp(SqlStdOperatorTable.COVAR_SAMP, + new RegrCovarianceConvertlet(SqlKind.COVAR_SAMP)); + registerOp(SqlStdOperatorTable.REGR_SXX, + new RegrCovarianceConvertlet(SqlKind.REGR_SXX)); + registerOp(SqlStdOperatorTable.REGR_SYY, + new RegrCovarianceConvertlet(SqlKind.REGR_SYY)); final SqlRexConvertlet floorCeilConvertlet = new FloorCeilConvertlet(); registerOp(SqlStdOperatorTable.FLOOR, floorCeilConvertlet); @@ -342,14 +352,26 @@ public class StandardConvertletTable extends ReflectiveConvertletTable { SqlNodeList thenList = call.getThenOperands(); assert whenList.size() == thenList.size(); + RexBuilder rexBuilder = cx.getRexBuilder(); final List<RexNode> exprList = new ArrayList<>(); for (int i = 0; i < whenList.size(); i++) { - exprList.add(cx.convertExpression(whenList.get(i))); - exprList.add(cx.convertExpression(thenList.get(i))); + if (SqlUtil.isNullLiteral(whenList.get(i), false)) { + exprList.add(rexBuilder.constantNull()); + } else { + exprList.add(cx.convertExpression(whenList.get(i))); + } + if (SqlUtil.isNullLiteral(thenList.get(i), false)) { + exprList.add(rexBuilder.constantNull()); + } else { + exprList.add(cx.convertExpression(thenList.get(i))); + } + } + if (SqlUtil.isNullLiteral(call.getElseOperand(), false)) { + exprList.add(rexBuilder.constantNull()); + } else { + exprList.add(cx.convertExpression(call.getElseOperand())); } - exprList.add(cx.convertExpression(call.getElseOperand())); - RexBuilder rexBuilder = cx.getRexBuilder(); RelDataType type = rexBuilder.deriveReturnType(call.getOperator(), exprList); for (int i : elseArgs(exprList.size())) { @@ -473,11 +495,13 @@ public class StandardConvertletTable extends ReflectiveConvertletTable { return castToValidatedType(cx, call, cx.convertExpression(left)); } SqlDataTypeSpec dataType = (SqlDataTypeSpec) right; + RelDataType type = dataType.deriveType(typeFactory); if (SqlUtil.isNullLiteral(left, false)) { + final SqlValidatorImpl validator = (SqlValidatorImpl) cx.getValidator(); + validator.setValidatedNodeType(left, type); return cx.convertExpression(left); } RexNode arg = cx.convertExpression(left); - RelDataType type = dataType.deriveType(typeFactory); if (type == null) { type = cx.getValidator().getValidatedNodeType(dataType.getTypeName()); } @@ -1061,6 +1085,133 @@ public class StandardConvertletTable extends ReflectiveConvertletTable { return rexBuilder.makeCast(type, e); } + /** Convertlet that handles {@code COVAR_POP}, {@code COVAR_SAMP}, + * {@code REGR_SXX}, {@code REGR_SYY} windowed aggregate functions. + */ + private static class RegrCovarianceConvertlet implements SqlRexConvertlet { + private final SqlKind kind; + + RegrCovarianceConvertlet(SqlKind kind) { + this.kind = kind; + } + + public RexNode convertCall(SqlRexContext cx, SqlCall call) { + assert call.operandCount() == 2; + final SqlNode arg1 = call.operand(0); + final SqlNode arg2 = call.operand(1); + final SqlNode expr; + final RelDataType type = + cx.getValidator().getValidatedNodeType(call); + switch (kind) { + case COVAR_POP: + expr = expandCovariance(arg1, arg2, null, type, cx, true); + break; + case COVAR_SAMP: + expr = expandCovariance(arg1, arg2, null, type, cx, false); + break; + case REGR_SXX: + expr = expandRegrSzz(arg2, arg1, type, cx, true); + break; + case REGR_SYY: + expr = expandRegrSzz(arg1, arg2, type, cx, true); + break; + default: + throw Util.unexpected(kind); + } + RexNode rex = cx.convertExpression(expr); + return cx.getRexBuilder().ensureType(type, rex, true); + } + + private SqlNode expandRegrSzz( + final SqlNode arg1, final SqlNode arg2, + final RelDataType avgType, final SqlRexContext cx, boolean variance) { + final SqlParserPos pos = SqlParserPos.ZERO; + final SqlNode count = + SqlStdOperatorTable.REGR_COUNT.createCall(pos, arg1, arg2); + final SqlNode varPop = + expandCovariance(arg1, variance ? arg1 : arg2, arg2, avgType, cx, true); + final RexNode varPopRex = cx.convertExpression(varPop); + final SqlNode varPopCast; + varPopCast = getCastedSqlNode(varPop, avgType, pos, varPopRex); + return SqlStdOperatorTable.MULTIPLY.createCall(pos, varPopCast, count); + } + + private SqlNode expandCovariance( + final SqlNode arg0Input, + final SqlNode arg1Input, + final SqlNode dependent, + final RelDataType varType, + final SqlRexContext cx, + boolean biased) { + // covar_pop(x1, x2) ==> + // (sum(x1 * x2) - sum(x2) * sum(x1) / count(x1, x2)) + // / count(x1, x2) + // + // covar_samp(x1, x2) ==> + // (sum(x1 * x2) - sum(x1) * sum(x2) / count(x1, x2)) + // / (count(x1, x2) - 1) + final SqlParserPos pos = SqlParserPos.ZERO; + final SqlLiteral nullLiteral = SqlLiteral.createNull(SqlParserPos.ZERO); + + final RexNode arg0Rex = cx.convertExpression(arg0Input); + final RexNode arg1Rex = cx.convertExpression(arg1Input); + + final SqlNode arg0 = getCastedSqlNode(arg0Input, varType, pos, arg0Rex); + final SqlNode arg1 = getCastedSqlNode(arg1Input, varType, pos, arg1Rex); + final SqlNode argSquared = SqlStdOperatorTable.MULTIPLY.createCall(pos, arg0, arg1); + final SqlNode sumArgSquared; + final SqlNode sum0; + final SqlNode sum1; + final SqlNode count; + if (dependent == null) { + sumArgSquared = SqlStdOperatorTable.SUM.createCall(pos, argSquared); + sum0 = SqlStdOperatorTable.SUM.createCall(pos, arg0, arg1); + sum1 = SqlStdOperatorTable.SUM.createCall(pos, arg1, arg0); + count = SqlStdOperatorTable.REGR_COUNT.createCall(pos, arg0, arg1); + } else { + sumArgSquared = SqlStdOperatorTable.SUM.createCall(pos, argSquared, dependent); + sum0 = SqlStdOperatorTable.SUM.createCall( + pos, arg0, Objects.equals(dependent, arg0Input) ? arg1 : dependent); + sum1 = SqlStdOperatorTable.SUM.createCall( + pos, arg1, Objects.equals(dependent, arg1Input) ? arg0 : dependent); + count = SqlStdOperatorTable.REGR_COUNT.createCall( + pos, arg0, Objects.equals(dependent, arg0Input) ? arg1 : dependent); + } + + final SqlNode sumSquared = SqlStdOperatorTable.MULTIPLY.createCall(pos, sum0, sum1); + final SqlNode countCasted = + getCastedSqlNode(count, varType, pos, cx.convertExpression(count)); + + final SqlNode avgSumSquared = + SqlStdOperatorTable.DIVIDE.createCall(pos, sumSquared, countCasted); + final SqlNode diff = SqlStdOperatorTable.MINUS.createCall(pos, sumArgSquared, avgSumSquared); + SqlNode denominator; + if (biased) { + denominator = countCasted; + } else { + final SqlNumericLiteral one = SqlLiteral.createExactNumeric("1", pos); + denominator = new SqlCase(SqlParserPos.ZERO, countCasted, + SqlNodeList.of(SqlStdOperatorTable.EQUALS.createCall(pos, countCasted, one)), + SqlNodeList.of(getCastedSqlNode(nullLiteral, varType, pos, null)), + SqlStdOperatorTable.MINUS.createCall(pos, countCasted, one)); + } + + return SqlStdOperatorTable.DIVIDE.createCall(pos, diff, denominator); + } + + private SqlNode getCastedSqlNode(SqlNode argInput, RelDataType varType, + SqlParserPos pos, RexNode argRex) { + SqlNode arg; + if (argRex != null && !argRex.getType().equals(varType)) { + arg = SqlStdOperatorTable.CAST.createCall( + pos, argInput, SqlTypeUtil.convertTypeToSpec(varType)); + } else { + arg = argInput; + } + return arg; + } + } + /** Convertlet that handles {@code AVG} and {@code VARIANCE} * windowed aggregate functions. */ private static class AvgVarianceConvertlet implements SqlRexConvertlet { @@ -1106,14 +1257,7 @@ public class StandardConvertletTable extends ReflectiveConvertletTable { SqlStdOperatorTable.SUM.createCall(pos, arg); final RexNode sumRex = cx.convertExpression(sum); final SqlNode sumCast; - if (!sumRex.getType().equals(avgType)) { - sumCast = SqlStdOperatorTable.CAST.createCall(pos, - new SqlDataTypeSpec( - new SqlIdentifier(avgType.getSqlTypeName().getName(), pos), - avgType.getPrecision(), avgType.getScale(), null, null, pos)); - } else { - sumCast = sum; - } + sumCast = getCastedSqlNode(sum, avgType, pos, sumRex); final SqlNode count = SqlStdOperatorTable.COUNT.createCall(pos, arg); return SqlStdOperatorTable.DIVIDE.createCall( @@ -1147,54 +1291,66 @@ public class StandardConvertletTable extends ReflectiveConvertletTable { // / (count(x) - 1) final SqlParserPos pos = SqlParserPos.ZERO; - final RexNode argRex = cx.convertExpression(argInput); - final SqlNode arg; - if (!argRex.getType().equals(varType)) { - arg = SqlStdOperatorTable.CAST.createCall(pos, - new SqlDataTypeSpec(new SqlIdentifier(varType.getSqlTypeName().getName(), pos), - varType.getPrecision(), varType.getScale(), null, null, pos)); - } else { - arg = argInput; - } + final SqlNode arg = getCastedSqlNode(argInput, varType, pos, cx.convertExpression(argInput)); - final SqlNode argSquared = - SqlStdOperatorTable.MULTIPLY.createCall(pos, arg, arg); - final SqlNode sumArgSquared = - SqlStdOperatorTable.SUM.createCall(pos, argSquared); - final SqlNode sum = - SqlStdOperatorTable.SUM.createCall(pos, arg); + final SqlNode argSquared = SqlStdOperatorTable.MULTIPLY.createCall(pos, arg, arg); + final SqlNode argSquaredCasted = + getCastedSqlNode(argSquared, varType, pos, cx.convertExpression(argSquared)); + final SqlNode sumArgSquared = SqlStdOperatorTable.SUM.createCall(pos, argSquaredCasted); + final SqlNode sumArgSquaredCasted = + getCastedSqlNode(sumArgSquared, varType, pos, cx.convertExpression(sumArgSquared)); + final SqlNode sum = SqlStdOperatorTable.SUM.createCall(pos, arg); + final SqlNode sumCasted = getCastedSqlNode(sum, varType, pos, cx.convertExpression(sum)); final SqlNode sumSquared = - SqlStdOperatorTable.MULTIPLY.createCall(pos, sum, sum); - final SqlNode count = - SqlStdOperatorTable.COUNT.createCall(pos, arg); + SqlStdOperatorTable.MULTIPLY.createCall(pos, sumCasted, sumCasted); + final SqlNode sumSquaredCasted = + getCastedSqlNode(sumSquared, varType, pos, cx.convertExpression(sumSquared)); + final SqlNode count = SqlStdOperatorTable.COUNT.createCall(pos, arg); + final SqlNode countCasted = + getCastedSqlNode(count, varType, pos, cx.convertExpression(count)); final SqlNode avgSumSquared = - SqlStdOperatorTable.DIVIDE.createCall( - pos, sumSquared, count); + SqlStdOperatorTable.DIVIDE.createCall(pos, sumSquaredCasted, countCasted); + final SqlNode avgSumSquaredCasted = + getCastedSqlNode(avgSumSquared, varType, pos, cx.convertExpression(avgSumSquared)); final SqlNode diff = - SqlStdOperatorTable.MINUS.createCall( - pos, sumArgSquared, avgSumSquared); + SqlStdOperatorTable.MINUS.createCall(pos, sumArgSquaredCasted, avgSumSquaredCasted); + final SqlNode diffCasted = + getCastedSqlNode(diff, varType, pos, cx.convertExpression(diff)); final SqlNode denominator; if (biased) { - denominator = count; + denominator = countCasted; } else { - final SqlNumericLiteral one = - SqlLiteral.createExactNumeric("1", pos); - denominator = - SqlStdOperatorTable.MINUS.createCall( - pos, count, one); + final SqlNumericLiteral one = SqlLiteral.createExactNumeric("1", pos); + final SqlLiteral nullLiteral = SqlLiteral.createNull(SqlParserPos.ZERO); + denominator = new SqlCase(SqlParserPos.ZERO, + count, + SqlNodeList.of(SqlStdOperatorTable.EQUALS.createCall(pos, count, one)), + SqlNodeList.of(getCastedSqlNode(nullLiteral, varType, pos, null)), + SqlStdOperatorTable.MINUS.createCall(pos, count, one)); } final SqlNode div = - SqlStdOperatorTable.DIVIDE.createCall( - pos, diff, denominator); + SqlStdOperatorTable.DIVIDE.createCall(pos, diffCasted, denominator); + final SqlNode divCasted = getCastedSqlNode(div, varType, pos, cx.convertExpression(div)); + SqlNode result = div; if (sqrt) { - final SqlNumericLiteral half = - SqlLiteral.createExactNumeric("0.5", pos); - result = - SqlStdOperatorTable.POWER.createCall(pos, div, half); + final SqlNumericLiteral half = SqlLiteral.createExactNumeric("0.5", pos); + result = SqlStdOperatorTable.POWER.createCall(pos, divCasted, half); } return result; } + + private SqlNode getCastedSqlNode(SqlNode argInput, RelDataType varType, + SqlParserPos pos, RexNode argRex) { + SqlNode arg; + if (argRex != null && !argRex.getType().equals(varType)) { + arg = SqlStdOperatorTable.CAST.createCall( + pos, argInput, SqlTypeUtil.convertTypeToSpec(varType)); + } else { + arg = argInput; + } + return arg; + } } /** Convertlet that converts {@code LTRIM} and {@code RTRIM} to http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/test/java/org/apache/calcite/sql/test/SqlAdvisorTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/calcite/sql/test/SqlAdvisorTest.java b/core/src/test/java/org/apache/calcite/sql/test/SqlAdvisorTest.java index 0742731..634c5ff 100644 --- a/core/src/test/java/org/apache/calcite/sql/test/SqlAdvisorTest.java +++ b/core/src/test/java/org/apache/calcite/sql/test/SqlAdvisorTest.java @@ -183,6 +183,7 @@ public class SqlAdvisorTest extends SqlValidatorTestCase { "KEYWORD(POWER)", "KEYWORD(PREV)", "KEYWORD(RANK)", + "KEYWORD(REGR_COUNT)", "KEYWORD(REGR_SXX)", "KEYWORD(REGR_SYY)", "KEYWORD(ROW)", http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/test/resources/sql/agg.iq ---------------------------------------------------------------------- diff --git a/core/src/test/resources/sql/agg.iq b/core/src/test/resources/sql/agg.iq index 6c26a89..997ac94 100755 --- a/core/src/test/resources/sql/agg.iq +++ b/core/src/test/resources/sql/agg.iq @@ -2284,4 +2284,101 @@ EnumerableCalc(expr#0..1=[{inputs}], ANYEMPNO=[$t1]) EnumerableTableScan(table=[[scott, EMP]]) !plan +# [CALCITE-1776, CALCITE-2402] REGR_COUNT +SELECT regr_count(COMM, SAL) as "REGR_COUNT(COMM, SAL)", + regr_count(EMPNO, SAL) as "REGR_COUNT(EMPNO, SAL)" +from "scott".emp; ++-----------------------+------------------------+ +| REGR_COUNT(COMM, SAL) | REGR_COUNT(EMPNO, SAL) | ++-----------------------+------------------------+ +| 4 | 14 | ++-----------------------+------------------------+ +(1 row) + +!ok + +EnumerableAggregate(group=[{}], REGR_COUNT(COMM, SAL)=[REGR_COUNT($6, $5)], REGR_COUNT(EMPNO, SAL)=[REGR_COUNT($5)]) + EnumerableTableScan(table=[[scott, EMP]]) +!plan + +# [CALCITE-1776, CALCITE-2402] REGR_SXX, REGR_SXY, REGR_SYY +SELECT + regr_sxx(COMM, SAL) as "REGR_SXX(COMM, SAL)", + regr_syy(COMM, SAL) as "REGR_SYY(COMM, SAL)", + regr_sxx(SAL, COMM) as "REGR_SXX(SAL, COMM)", + regr_syy(SAL, COMM) as "REGR_SYY(SAL, COMM)" +from "scott".emp; ++---------------------+---------------------+---------------------+---------------------+ +| REGR_SXX(COMM, SAL) | REGR_SYY(COMM, SAL) | REGR_SXX(SAL, COMM) | REGR_SYY(SAL, COMM) | ++---------------------+---------------------+---------------------+---------------------+ +| 95000.0000 | 1090000.0000 | 1090000.0000 | 95000.0000 | ++---------------------+---------------------+---------------------+---------------------+ +(1 row) + +!ok + +# [CALCITE-1776, CALCITE-2402] COVAR_POP, COVAR_SAMP, VAR_SAMP, VAR_POP +SELECT + covar_pop(COMM, COMM) as "COVAR_POP(COMM, COMM)", + covar_samp(SAL, SAL) as "COVAR_SAMP(SAL, SAL)", + var_pop(COMM) as "VAR_POP(COMM)", + var_samp(SAL) as "VAR_SAMP(SAL)" +from "scott".emp; ++-----------------------+----------------------+---------------+-------------------+ +| COVAR_POP(COMM, COMM) | COVAR_SAMP(SAL, SAL) | VAR_POP(COMM) | VAR_SAMP(SAL) | ++-----------------------+----------------------+---------------+-------------------+ +| 272500.0000 | 1398313.873626374 | 272500.0000 | 1398313.873626374 | ++-----------------------+----------------------+---------------+-------------------+ +(1 row) + +!ok + +# [CALCITE-1776, CALCITE-2402] REGR_COUNT with group by +SELECT SAL, regr_count(COMM, SAL) as "REGR_COUNT(COMM, SAL)", + regr_count(EMPNO, SAL) as "REGR_COUNT(EMPNO, SAL)" +from "scott".emp group by SAL; ++---------+-----------------------+------------------------+ +| SAL | REGR_COUNT(COMM, SAL) | REGR_COUNT(EMPNO, SAL) | ++---------+-----------------------+------------------------+ +| 1100.00 | 0 | 1 | +| 1250.00 | 2 | 2 | +| 1300.00 | 0 | 1 | +| 1500.00 | 1 | 1 | +| 1600.00 | 1 | 1 | +| 2450.00 | 0 | 1 | +| 2850.00 | 0 | 1 | +| 2975.00 | 0 | 1 | +| 3000.00 | 0 | 2 | +| 5000.00 | 0 | 1 | +| 800.00 | 0 | 1 | +| 950.00 | 0 | 1 | ++---------+-----------------------+------------------------+ +(12 rows) + +!ok + +# [CALCITE-1776, CALCITE-2402] COVAR_POP, COVAR_SAMP, VAR_SAMP, VAR_POP with group by +SELECT + MONTH(HIREDATE) as "MONTH", + covar_samp(SAL, COMM) as "COVAR_SAMP(SAL, COMM)", + var_pop(COMM) as "VAR_POP(COMM)", + var_samp(SAL) as "VAR_SAMP(SAL)" +from "scott".emp +group by MONTH(HIREDATE); ++-------+-----------------------+---------------+-------------------+ +| MONTH | COVAR_SAMP(SAL, COMM) | VAR_POP(COMM) | VAR_SAMP(SAL) | ++-------+-----------------------+---------------+-------------------+ +| 1 | | | 1201250.0000 | +| 11 | | | | +| 12 | | | 1510833.333333334 | +| 2 | -35000.0000 | 10000.0000 | 831458.333333335 | +| 4 | | | | +| 5 | | | | +| 6 | | | | +| 9 | -175000.0000 | 490000.0000 | 31250.0000 | ++-------+-----------------------+---------------+-------------------+ +(8 rows) + +!ok + # End agg.iq http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/test/resources/sql/winagg.iq ---------------------------------------------------------------------- diff --git a/core/src/test/resources/sql/winagg.iq b/core/src/test/resources/sql/winagg.iq index eac5822..fbd0dde 100644 --- a/core/src/test/resources/sql/winagg.iq +++ b/core/src/test/resources/sql/winagg.iq @@ -455,4 +455,137 @@ from emp order by emp."ENAME"; !ok +# [CALCITE-2402] COVAR_POP, REGR_COUNT functions +# SUM(x, y) = SUM(x) WHERE y IS NOT NULL +# COVAR_POP(x, y) = (SUM(x * y) - SUM(x, y) * SUM(y, x) / REGR_COUNT(x, y)) / REGR_COUNT(x, y) +select emps."AGE", emps."DEPTNO", + sum(emps."AGE" * emps."DEPTNO") over() as "sum(age * deptno)", + regr_count(emps."AGE", emps."DEPTNO") over() as "regr_count(age, deptno)", + covar_pop(emps."DEPTNO", emps."AGE") over() as "covar_pop" +from emps order by emps."AGE"; ++-----+--------+-------------------+-------------------------+-----------+ +| AGE | DEPTNO | sum(age * deptno) | regr_count(age, deptno) | covar_pop | ++-----+--------+-------------------+-------------------------+-----------+ +| 5 | 20 | 1950 | 3 | 39 | +| 25 | 10 | 1950 | 3 | 39 | +| 80 | 20 | 1950 | 3 | 39 | +| | 40 | 1950 | 3 | 39 | +| | 40 | 1950 | 3 | 39 | ++-----+--------+-------------------+-------------------------+-----------+ +(5 rows) + +!ok + +# [CALCITE-2402] COVAR_POP, REGR_COUNT functions +# SUM(x, y) = SUM(x) WHERE y IS NOT NULL +# COVAR_POP(x, y) = (SUM(x * y) - SUM(x, y) * SUM(y, x) / REGR_COUNT(x, y)) / REGR_COUNT(x, y) +select emps."AGE", emps."DEPTNO", emps."GENDER", + sum(emps."AGE" * emps."DEPTNO") over(partition by emps."GENDER") as "sum(age * deptno)", + regr_count(emps."AGE", emps."DEPTNO") over(partition by emps."GENDER") as "regr_count(age, deptno)", + covar_pop(emps."DEPTNO", emps."AGE") over(partition by emps."GENDER") as "covar_pop" +from emps order by emps."GENDER"; ++-----+--------+--------+-------------------+-------------------------+-----------+ +| AGE | DEPTNO | GENDER | sum(age * deptno) | regr_count(age, deptno) | covar_pop | ++-----+--------+--------+-------------------+-------------------------+-----------+ +| 5 | 20 | F | 100 | 1 | 0 | +| | 40 | F | 100 | 1 | 0 | +| 80 | 20 | M | 1600 | 1 | 0 | +| | 40 | M | 1600 | 1 | 0 | +| 25 | 10 | | 250 | 1 | 0 | ++-----+--------+--------+-------------------+-------------------------+-----------+ +(5 rows) + +!ok + +# [CALCITE-2402] COVAR_SAMP functions +# SUM(x, y) = SUM(x) WHERE y IS NOT NULL +# COVAR_SAMP(x, y) = (SUM(x * y) - SUM(x, y) * SUM(y, x) / REGR_COUNT(x, y)) / (REGR_COUNT(x, y) - 1) +select emps."AGE", emps."DEPTNO", emps."GENDER", + covar_samp(emps."AGE", emps."AGE") over() as "var_samp", + covar_samp(emps."DEPTNO", emps."AGE") over() as "covar_samp", + covar_samp(emps."EMPNO", emps."DEPTNO") over(partition by emps."MANAGER") as "covar_samp partitioned" +from emps order by emps."AGE"; ++-----+--------+--------+----------+------------+------------------------+ +| AGE | DEPTNO | GENDER | var_samp | covar_samp | covar_samp partitioned | ++-----+--------+--------+----------+------------+------------------------+ +| 5 | 20 | F | 1508 | 58 | 0 | +| 25 | 10 | | 1508 | 58 | 50 | +| 80 | 20 | M | 1508 | 58 | 50 | +| | 40 | M | 1508 | 58 | 0 | +| | 40 | F | 1508 | 58 | 0 | ++-----+--------+--------+----------+------------+------------------------+ +(5 rows) + +!ok + +# [CALCITE-2402] VAR_POP, VAR_SAMP functions +# VAR_POP(x) = (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / COUNT(x) +# VAR_SAMP(x) = (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / (COUNT(x) - 1) +select emps."AGE", emps."DEPTNO", emps."GENDER", + var_pop(emps."AGE") over() as "var_pop", + var_pop(emps."AGE") over(partition by emps."AGE") as "var_pop by age", + var_samp(emps."AGE") over() as "var_samp", + var_samp(emps."AGE") over(partition by emps."GENDER") as "var_samp by gender" +from emps order by emps."AGE"; ++-----+--------+--------+---------+----------------+----------+--------------------+ +| AGE | DEPTNO | GENDER | var_pop | var_pop by age | var_samp | var_samp by gender | ++-----+--------+--------+---------+----------------+----------+--------------------+ +| 5 | 20 | F | 1005 | 0 | 1508 | | +| 25 | 10 | | 1005 | 0 | 1508 | | +| 80 | 20 | M | 1005 | 0 | 1508 | | +| | 40 | F | 1005 | | 1508 | | +| | 40 | M | 1005 | | 1508 | | ++-----+--------+--------+---------+----------------+----------+--------------------+ +(5 rows) + +!ok + +# [CALCITE-2402] REGR_SXX, REGR_SXY, REGR_SYY functions +# SUM(x, y) = SUM(x) WHERE y IS NOT NULL +# REGR_SXX(x, y) = REGR_COUNT(x, y) * VAR_POP(y, y) +# REGR_SXY(x, y) = REGR_COUNT(x, y) * COVAR_POP(x, y) +# REGR_SYY(x, y) = REGR_COUNT(x, y) * VAR_POP(x, x) +## COVAR_POP(x, y) = (SUM(x * y) - SUM(x, y) * SUM(y, x) / REGR_COUNT(x, y)) / REGR_COUNT(x, y) +## VAR_POP(y, y) = (SUM(y * y, x) - SUM(y, x) * SUM(y, x) / REGR_COUNT(x, y)) / REGR_COUNT(x, y) +select emps."AGE", emps."DEPTNO", + regr_sxx(emps."AGE", emps."DEPTNO") over() as "regr_sxx(age, deptno)", + regr_syy(emps."AGE", emps."DEPTNO") over() as "regr_syy(age, deptno)" +from emps order by emps."AGE"; ++-----+--------+-----------------------+-----------------------+ +| AGE | DEPTNO | regr_sxx(age, deptno) | regr_syy(age, deptno) | ++-----+--------+-----------------------+-----------------------+ +| 5 | 20 | 66 | 3015 | +| 25 | 10 | 66 | 3015 | +| 80 | 20 | 66 | 3015 | +| | 40 | 66 | 3015 | +| | 40 | 66 | 3015 | ++-----+--------+-----------------------+-----------------------+ +(5 rows) + +!ok + +# [CALCITE-2402] REGR_SXX, REGR_SXY, REGR_SYY functions +# SUM(x, y) = SUM(x) WHERE y IS NOT NULL +# REGR_SXX(x, y) = REGR_COUNT(x, y) * COVAR_POP(y, y) +# REGR_SXY(x, y) = REGR_COUNT(x, y) * COVAR_POP(x, y) +# REGR_SYY(x, y) = REGR_COUNT(x, y) * COVAR_POP(x, x) +## COVAR_POP(x, y) = (SUM(x * y) - SUM(x, y) * SUM(y, x) / REGR_COUNT(x, y)) / REGR_COUNT(x, y) +## COVAR_POP(y, y) = (SUM(y * y, x) - SUM(y, x) * SUM(y, x) / REGR_COUNT(x, y)) / REGR_COUNT(x, y) +select emps."AGE", emps."DEPTNO", emps."GENDER", + regr_sxx(emps."AGE", emps."DEPTNO") over(partition by emps."GENDER") as "regr_sxx(age, deptno)", + regr_syy(emps."AGE", emps."DEPTNO") over(partition by emps."GENDER") as "regr_syy(age, deptno)" +from emps order by emps."GENDER"; ++-----+--------+--------+-----------------------+-----------------------+ +| AGE | DEPTNO | GENDER | regr_sxx(age, deptno) | regr_syy(age, deptno) | ++-----+--------+--------+-----------------------+-----------------------+ +| 5 | 20 | F | 0 | 0 | +| | 40 | F | 0 | 0 | +| 80 | 20 | M | 0 | 0 | +| | 40 | M | 0 | 0 | +| 25 | 10 | | 0 | 0 | ++-----+--------+--------+-----------------------+-----------------------+ +(5 rows) + +!ok + # End winagg.iq http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/site/_docs/reference.md ---------------------------------------------------------------------- diff --git a/site/_docs/reference.md b/site/_docs/reference.md index acf8f29..e76e7d6 100644 --- a/site/_docs/reference.md +++ b/site/_docs/reference.md @@ -1510,6 +1510,7 @@ passed to the aggregate function. | VAR_SAMP( [ ALL | DISTINCT ] numeric) | Returns the sample variance (square of the sample standard deviation) of *numeric* across all input values | COVAR_POP(numeric1, numeric2) | Returns the population covariance of the pair (*numeric1*, *numeric2*) across all input values | COVAR_SAMP(numeric1, numeric2) | Returns the sample covariance of the pair (*numeric1*, *numeric2*) across all input values +| REGR_COUNT(numeric1, numeric2) | Returns the number of rows where both dependent and independent expressions are not null | REGR_SXX(numeric1, numeric2) | Returns the sum of squares of the dependent expression in a linear regression model | REGR_SYY(numeric1, numeric2) | Returns the sum of squares of the independent expression in a linear regression model @@ -1517,7 +1518,6 @@ Not implemented: * REGR_AVGX(numeric1, numeric2) * REGR_AVGY(numeric1, numeric2) -* REGR_COUNT(numeric1, numeric2) * REGR_INTERCEPT(numeric1, numeric2) * REGR_R2(numeric1, numeric2) * REGR_SLOPE(numeric1, numeric2)
