This is an automated email from the ASF dual-hosted git repository. danny0405 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/calcite.git
commit d8f4cc4aa4f6c8b49bd76d5043e30270b7ff2546 Author: godfreyhe <[email protected]> AuthorDate: Tue Jun 4 16:37:54 2019 +0800 [CALCITE-2744] RelDecorrelator use wrong output map for LogicalAggregate decorrelate (godfreyhe and Danny Chan) godfreyhe started the work by applying a new map for the LogicalAggregate decorrelate register, and Danny Chan added shifts for the constant key mappings. Also fixes the test case name and comments. close apache/calcite#1254 --- .../apache/calcite/sql2rel/RelDecorrelator.java | 38 ++++++-- .../apache/calcite/test/MockSqlOperatorTable.java | 24 +++++ .../org/apache/calcite/test/RelOptRulesTest.java | 38 ++++++++ .../org/apache/calcite/test/RelOptRulesTest.xml | 100 +++++++++++++++++++++ 4 files changed, 195 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java b/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java index 5e9a1c4..a11cb9e 100644 --- a/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java +++ b/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java @@ -467,8 +467,11 @@ public class RelDecorrelator implements ReflectiveVisitor { } final RelNode newInput = frame.r; + // aggregate outputs mapping: group keys and aggregates + final Map<Integer, Integer> outputMap = new HashMap<>(); + // map from newInput - Map<Integer, Integer> mapNewInputToProjOutputs = new HashMap<>(); + final Map<Integer, Integer> mapNewInputToProjOutputs = new HashMap<>(); final int oldGroupKeyCount = rel.getGroupSet().cardinality(); // Project projects the original expressions, @@ -490,6 +493,9 @@ public class RelDecorrelator implements ReflectiveVisitor { omittedConstants.put(i, constant); continue; } + + // add mapping of group keys. 
+ outputMap.put(i, newPos); int newInputPos = frame.oldToNewOutputs.get(i); projects.add(RexInputRef.of2(newInputPos, newInputOutput)); mapNewInputToProjOutputs.put(newInputPos, newPos); @@ -593,7 +599,7 @@ public class RelDecorrelator implements ReflectiveVisitor { // The old to new output position mapping will be the same as that // of newProject, plus any aggregates that the oldAgg produces. - combinedMap.put( + outputMap.put( oldInputOutputFieldCount + i, newInputOutputFieldCount + i); } @@ -605,15 +611,37 @@ public class RelDecorrelator implements ReflectiveVisitor { final List<RexNode> postProjects = new ArrayList<>(relBuilder.fields()); for (Map.Entry<Integer, RexLiteral> entry : omittedConstants.descendingMap().entrySet()) { - postProjects.add(entry.getKey() + frame.corDefOutputs.size(), - entry.getValue()); + int index = entry.getKey() + frame.corDefOutputs.size(); + postProjects.add(index, entry.getValue()); + // Shift the outputs whose index equals with or bigger than the added index + // with 1 offset. + shiftMapping(outputMap, index, 1); + // Then add the constant key mapping. + outputMap.put(entry.getKey(), index); } relBuilder.project(postProjects); } // Aggregate does not change input ordering so corVars will be // located at the same position as the input newProject. - return register(rel, relBuilder.build(), combinedMap, corDefOutputs); + return register(rel, relBuilder.build(), outputMap, corDefOutputs); + } + + /** + * Shift the mapping to fixed offset from the {@code startIndex}. 
+ * @param mapping the original mapping + * @param startIndex any output whose index equals with or bigger than the starting index + * would be shift + * @param offset shift offset + */ + private static void shiftMapping(Map<Integer, Integer> mapping, int startIndex, int offset) { + for (Map.Entry<Integer, Integer> entry : mapping.entrySet()) { + if (entry.getValue() >= startIndex) { + mapping.put(entry.getKey(), entry.getValue() + offset); + } else { + mapping.put(entry.getKey(), entry.getValue()); + } + } } public Frame getInvoke(RelNode r, RelNode parent) { diff --git a/core/src/test/java/org/apache/calcite/test/MockSqlOperatorTable.java b/core/src/test/java/org/apache/calcite/test/MockSqlOperatorTable.java index 78b842d..c128540 100644 --- a/core/src/test/java/org/apache/calcite/test/MockSqlOperatorTable.java +++ b/core/src/test/java/org/apache/calcite/test/MockSqlOperatorTable.java @@ -18,6 +18,7 @@ package org.apache.calcite.test; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlFunction; import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlIdentifier; @@ -27,6 +28,8 @@ import org.apache.calcite.sql.SqlOperatorBinding; import org.apache.calcite.sql.SqlOperatorTable; import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.sql.type.OperandTypes; +import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.util.ChainedSqlOperatorTable; import org.apache.calcite.sql.util.ListSqlOperatorTable; @@ -64,6 +67,7 @@ public class MockSqlOperatorTable extends ChainedSqlOperatorTable { opTab.addOperator(new RampFunction()); opTab.addOperator(new DedupFunction()); opTab.addOperator(new MyFunction()); + opTab.addOperator(new MyAvgAggFunction()); } /** "RAMP" user-defined 
function. */ @@ -125,6 +129,26 @@ public class MockSqlOperatorTable extends ChainedSqlOperatorTable { return typeFactory.createSqlType(SqlTypeName.BIGINT); } } + + /** "MY_AVG" user-defined aggregate function. */ + public static class MyAvgAggFunction extends SqlAggFunction { + public MyAvgAggFunction() { + super("MY_AVG", + null, + SqlKind.AVG, + ReturnTypes.AVG_AGG_FUNCTION, + null, + OperandTypes.family(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC), + SqlFunctionCategory.NUMERIC, + false, + false); + } + + @Override public boolean isDeterministic() { + return false; + } + } + } // End MockSqlOperatorTable.java diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java index d70b320..38719ce 100644 --- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java +++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java @@ -5278,6 +5278,44 @@ public class RelOptRulesTest extends RelOptTestBase { } /** Test case for + * <a href="https://issues.apache.org/jira/browse/CALCITE-2744">[CALCITE-2744] + * RelDecorrelator use wrong output map for LogicalAggregate decorrelate</a>. */ + @Test public void testDecorrelateAggWithConstantGroupKey() { + final String sql = "SELECT * FROM emp A where sal in \n" + + "(SELECT max(sal) FROM emp B where A.mgr = B.empno group by deptno, 'abc')"; + sql(sql) + .withLateDecorrelation(true) + .withTrim(true) + .with(HepProgram.builder().build()) + .check(); + } + + /** Test case for CALCITE-2744 for aggregate decorrelate with multi-param agg call + * but without group key. 
*/ + @Test public void testDecorrelateAggWithMultiParamsAggCall() { + final String sql = "SELECT * FROM (SELECT MY_AVG(sal, 1) AS c FROM emp) as m,\n" + + " LATERAL TABLE(ramp(m.c)) AS T(s)"; + sql(sql) + .withLateDecorrelation(true) + .withTrim(true) + .with(HepProgram.builder().build()) + .checkUnchanged(); + } + + /** Same as {@link #testDecorrelateAggWithMultiParamsAggCall} + * but with constant grouping key. */ + @Test public void testDecorrelateAggWithMultiParamsAggCall2() { + final String sql = "SELECT * FROM " + + "(SELECT MY_AVG(sal, 1) AS c FROM emp group by empno, 'abc') as m,\n" + + " LATERAL TABLE(ramp(m.c)) AS T(s)"; + sql(sql) + .withLateDecorrelation(true) + .withTrim(true) + .with(HepProgram.builder().build()) + .checkUnchanged(); + } + + /** Test case for * <a href="https://issues.apache.org/jira/browse/CALCITE-434">[CALCITE-434] * Converting predicates on date dimension columns into date ranges</a>, * specifically a rule that converts {@code EXTRACT(YEAR FROM ...) = constant} diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml index a359d64..21f99d2 100644 --- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml +++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml @@ -2684,6 +2684,106 @@ LogicalProject("K0"=[$0], "C1"=[$1], "F1"."A0"=[$2], "F2"."A0"=[$3], "F0"."C0"=[ ]]> </Resource> </TestCase> + <TestCase name="testDecorrelateAggWithConstantGroupKey"> + <Resource name="sql"> + <![CDATA[SELECT * FROM emp A where sal in +(SELECT max(sal) FROM emp B where A.mgr = B.empno group by deptno, 'abc')]]> + </Resource> + <Resource name="planBefore"> + <![CDATA[ +LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8]) + LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{3, 5}]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) + 
LogicalFilter(condition=[=($cor0.SAL, $0)]) + LogicalAggregate(group=[{0}]) + LogicalProject(EXPR$0=[$2]) + LogicalAggregate(group=[{0, 1}], EXPR$0=[MAX($2)]) + LogicalProject(DEPTNO=[$7], $f1=['abc'], SAL=[$5]) + LogicalFilter(condition=[=($cor0.MGR, $0)]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + </Resource> + <Resource name="planMid"> + <![CDATA[ +LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8]) + LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{3, 5}]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) + LogicalFilter(condition=[=($cor0.SAL, $0)]) + LogicalAggregate(group=[{0}]) + LogicalProject(EXPR$0=[$2]) + LogicalAggregate(group=[{0, 1}], EXPR$0=[MAX($2)]) + LogicalProject(DEPTNO=[$7], $f1=['abc'], SAL=[$5]) + LogicalFilter(condition=[=($cor0.MGR, $0)]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + </Resource> + <Resource name="planAfter"> + <![CDATA[ +LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8]) + LogicalJoin(condition=[AND(=($3, $10), =($5, $9))], joinType=[inner]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) + LogicalAggregate(group=[{0, 1}]) + LogicalProject(EXPR$0=[$2], EMPNO=[$1]) + LogicalAggregate(group=[{0, 1}], EXPR$0=[MAX($3)]) + LogicalProject(DEPTNO=[$7], EMPNO=[$0], $f1=['abc'], SAL=[$5]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + </Resource> + </TestCase> + <TestCase name="testDecorrelateAggWithMultiParamsAggCall"> + <Resource name="sql"> + <![CDATA[SELECT * FROM (SELECT MY_AVG(sal, 1) AS c FROM emp) as m, + LATERAL TABLE(ramp(m.c)) AS T(s)]]> + </Resource> + <Resource name="planBefore"> + <![CDATA[ +LogicalProject(C=[$0], S=[$1]) + LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}]) + LogicalAggregate(group=[{}], C=[MY_AVG($0, $1)]) + LogicalProject(SAL=[$5], $f1=[1]) + 
LogicalTableScan(table=[[CATALOG, SALES, EMP]]) + LogicalTableFunctionScan(invocation=[RAMP($cor0.C)], rowType=[RecordType(INTEGER I)]) +]]> + </Resource> + <Resource name="planMid"> + <![CDATA[ +LogicalProject(C=[$0], S=[$1]) + LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}]) + LogicalAggregate(group=[{}], C=[MY_AVG($0, $1)]) + LogicalProject(SAL=[$5], $f1=[1]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) + LogicalTableFunctionScan(invocation=[RAMP($cor0.C)], rowType=[RecordType(INTEGER I)]) +]]> + </Resource> + </TestCase> + <TestCase name="testDecorrelateAggWithMultiParamsAggCall2"> + <Resource name="sql"> + <![CDATA[SELECT * FROM (SELECT MY_AVG(sal, 1) AS c FROM emp group by empno, 'abc') as m, + LATERAL TABLE(ramp(m.c)) AS T(s)]]> + </Resource> + <Resource name="planBefore"> + <![CDATA[ +LogicalProject(C=[$0], S=[$1]) + LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}]) + LogicalProject(C=[$2]) + LogicalAggregate(group=[{0, 1}], C=[MY_AVG($2, $3)]) + LogicalProject(EMPNO=[$0], $f1=['abc'], SAL=[$5], $f3=[1]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) + LogicalTableFunctionScan(invocation=[RAMP($cor0.C)], rowType=[RecordType(INTEGER I)]) +]]> + </Resource> + <Resource name="planMid"> + <![CDATA[ +LogicalProject(C=[$0], S=[$1]) + LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}]) + LogicalProject(C=[$2]) + LogicalAggregate(group=[{0, 1}], C=[MY_AVG($2, $3)]) + LogicalProject(EMPNO=[$0], $f1=['abc'], SAL=[$5], $f3=[1]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) + LogicalTableFunctionScan(invocation=[RAMP($cor0.C)], rowType=[RecordType(INTEGER I)]) +]]> + </Resource> + </TestCase> <TestCase name="testExtractYearMonthToRange"> <Resource name="sql"> <