[CALCITE-1909] Output rowType of Match should include PARTITION BY and ORDER BY columns
Add method RelDataTypeFactory.Builder.nameExists(String). Close apache/calcite#508 Project: http://git-wip-us.apache.org/repos/asf/calcite/repo Commit: http://git-wip-us.apache.org/repos/asf/calcite/commit/9efefbc8 Tree: http://git-wip-us.apache.org/repos/asf/calcite/tree/9efefbc8 Diff: http://git-wip-us.apache.org/repos/asf/calcite/diff/9efefbc8 Branch: refs/heads/master Commit: 9efefbc8cad3a965e2a22730dd06015c0b2b4ff7 Parents: a98b21e Author: Zhiqiang-He <absolute...@qq.com> Authored: Sun Aug 6 10:54:49 2017 +0800 Committer: Julian Hyde <jh...@apache.org> Committed: Mon Aug 7 09:15:11 2017 -0700 ---------------------------------------------------------------------- .../calcite/rel/type/RelDataTypeFactory.java | 5 ++ .../calcite/sql/validate/SqlValidatorImpl.java | 62 ++++++++++++++------ .../calcite/test/SqlToRelConverterTest.java | 43 ++++++++++++++ .../calcite/test/SqlToRelConverterTest.xml | 48 ++++++++++++++- 4 files changed, 138 insertions(+), 20 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/calcite/blob/9efefbc8/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeFactory.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeFactory.java b/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeFactory.java index 67e19b9..559853e 100644 --- a/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeFactory.java +++ b/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeFactory.java @@ -503,6 +503,11 @@ public interface RelDataTypeFactory { public RelDataType build() { return typeFactory.createStructType(kind, types, names); } + + /** Returns whether a field exists with the given name. */ + public boolean nameExists(String name) { + return names.contains(name); + } } } http://git-wip-us.apache.org/repos/asf/calcite/blob/9efefbc8/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java b/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java index c2dc999..ce640a2 100644 --- a/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java +++ b/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java @@ -4506,12 +4506,7 @@ public class SqlValidatorImpl implements SqlValidatorWithHints { && rowsPerMatch.getValue() == SqlMatchRecognize.RowsPerMatchOption.ALL_ROWS; - final List<Map.Entry<String, RelDataType>> fields = new ArrayList<>(); - if (allRows) { - final SqlValidatorNamespace sqlNs = getNamespace(matchRecognize.getTableRef()); - final RelDataType inputDataType = sqlNs.getRowType(); - fields.addAll(inputDataType.getFieldList()); - } + final RelDataTypeFactory.Builder typeBuilder = typeFactory.builder(); // parse PARTITION BY column SqlNodeList partitionBy = matchRecognize.getPartitionList(); @@ -4519,11 +4514,9 @@ public class SqlValidatorImpl implements SqlValidatorWithHints { for (SqlNode node : partitionBy) { SqlIdentifier identifier = (SqlIdentifier) node; identifier.validate(this, scope); - if (allRows) { - RelDataType type = deriveType(scope, identifier); - String name = identifier.names.get(1); - fields.add(Pair.of(name, type)); - } + RelDataType type = deriveType(scope, identifier); + String name = identifier.names.get(1); + typeBuilder.add(name, type); } } @@ -4532,6 +4525,31 @@ public class SqlValidatorImpl implements SqlValidatorWithHints { if (orderBy != null) { for (SqlNode node : orderBy) { node.validate(this, scope); + SqlIdentifier identifier = null; + if (node instanceof SqlBasicCall) { + identifier = (SqlIdentifier) ((SqlBasicCall) node).getOperands()[0]; + } else { + identifier = (SqlIdentifier) node; + } + + if (allRows) { + RelDataType type = deriveType(scope, identifier); + String name = identifier.names.get(1); + if (!typeBuilder.nameExists(name)) { + typeBuilder.add(name, type); + } + } + } + } + + if (allRows) { + final SqlValidatorNamespace sqlNs = + getNamespace(matchRecognize.getTableRef()); + final RelDataType inputDataType = sqlNs.getRowType(); + for (RelDataTypeField fs : inputDataType.getFieldList()) { + if (!typeBuilder.nameExists(fs.getName())) { + typeBuilder.add(fs); + } } } @@ -4549,14 +4567,14 @@ public class SqlValidatorImpl implements SqlValidatorWithHints { String leftString = ((SqlIdentifier) operands.get(0)).getSimple(); if (scope.getPatternVars().contains(leftString)) { throw newValidationError(operands.get(0), - RESOURCE.patternVarAlreadyDefined(leftString)); + RESOURCE.patternVarAlreadyDefined(leftString)); } scope.addPatternVar(leftString); for (SqlNode right : (SqlNodeList) operands.get(1)) { SqlIdentifier id = (SqlIdentifier) right; if (!scope.getPatternVars().contains(id.getSimple())) { throw newValidationError(id, - RESOURCE.unknownPattern(id.getSimple())); + RESOURCE.unknownPattern(id.getSimple())); } scope.addPatternVar(id.getSimple()); } @@ -4574,8 +4592,15 @@ public class SqlValidatorImpl implements SqlValidatorWithHints { } } - fields.addAll(validateMeasure(matchRecognize, scope, allRows)); - final RelDataType rowType = typeFactory.createStructType(fields); + List<Map.Entry<String, RelDataType>> measureColumns = + validateMeasure(matchRecognize, scope, allRows); + for (Map.Entry<String, RelDataType> c : measureColumns) { + if (!typeBuilder.nameExists(c.getKey())) { + typeBuilder.add(c.getKey(), c.getValue()); + } + } + + final RelDataType rowType = typeBuilder.build(); if (matchRecognize.getMeasureList().size() == 0) { ns.setType(getNamespace(matchRecognize.getTableRef()).getRowType()); } else { @@ -4776,9 +4801,10 @@ public class SqlValidatorImpl implements SqlValidatorWithHints { return newExpr; } - public SqlNode expandGroupByOrHavingExpr(SqlNode expr, SqlValidatorScope scope, SqlSelect select, - boolean havingExpression) { - final Expander expander = new ExtendedExpander(this, scope, select, expr, havingExpression); + public SqlNode expandGroupByOrHavingExpr(SqlNode expr, + SqlValidatorScope scope, SqlSelect select, boolean havingExpression) { + final Expander expander = new ExtendedExpander(this, scope, select, expr, + havingExpression); SqlNode newExpr = expr.accept(expander); if (expr != newExpr) { setOriginal(newExpr, expr); http://git-wip-us.apache.org/repos/asf/calcite/blob/9efefbc8/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java b/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java index d921f6a..9275cfa 100644 --- a/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java +++ b/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java @@ -2542,6 +2542,49 @@ public class SqlToRelConverterTest extends SqlToRelTestBase { sql(sql).ok(); } + /** Test case for + * <a href="https://issues.apache.org/jira/browse/CALCITE-1909">[CALCITE-1909] + * Output rowType of Match should include PARTITION BY and ORDER BY + * columns</a>. */ + @Test public void testMatchRecognizeMeasures2() { + final String sql = "select *\n" + + " from emp match_recognize\n" + + " (\n" + + " partition by job\n" + + " order by sal\n" + + " measures MATCH_NUMBER() as match_num, " + + " CLASSIFIER() as var_match, " + + " STRT.mgr as start_nw," + + " LAST(DOWN.mgr) as bottom_nw," + + " LAST(up.mgr) as end_nw" + + " pattern (strt down+ up+)\n" + + " define\n" + + " down as down.mgr < PREV(down.mgr),\n" + + " up as up.mgr > prev(up.mgr)\n" + + " ) mr"; + sql(sql).ok(); + } + + @Test public void testMatchRecognizeMeasures3() { + final String sql = "select *\n" + + " from emp match_recognize\n" + + " (\n" + + " partition by job\n" + + " order by sal\n" + + " measures MATCH_NUMBER() as match_num, " + + " CLASSIFIER() as var_match, " + + " STRT.mgr as start_nw," + + " LAST(DOWN.mgr) as bottom_nw," + + " LAST(up.mgr) as end_nw" + + " ALL ROWS PER MATCH" + + " pattern (strt down+ up+)\n" + + " define\n" + + " down as down.mgr < PREV(down.mgr),\n" + + " up as up.mgr > prev(up.mgr)\n" + + " ) mr"; + sql(sql).ok(); + } + @Test public void testMatchRecognizePatternSkip1() { final String sql = "select *\n" + " from emp match_recognize\n" http://git-wip-us.apache.org/repos/asf/calcite/blob/9efefbc8/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml ---------------------------------------------------------------------- diff --git a/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml b/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml index c14c85a..0ce33b6 100644 --- a/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml +++ b/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml @@ -4763,8 +4763,52 @@ LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$ </Resource> <Resource name="plan"> <![CDATA[ -LogicalProject(MATCH_NUM=[$0], VAR_MATCH=[$1], START_NW=[$2], BOTTOM_NW=[$3], END_NW=[$4]) - LogicalMatch(partition=[[$2, $5]], order=[[2, 5 DESC]], outputFields=[[MATCH_NUM, VAR_MATCH, START_NW, BOTTOM_NW, END_NW]], allRows=[false], after=[FLAG(SKIP TO NEXT ROW)], pattern=[(('STRT', PATTERN_QUANTIFIER('DOWN', 1, -1, false)), PATTERN_QUANTIFIER('UP', 1, -1, false))], isStrictStarts=[false], isStrictEnds=[false], subsets=[[]], patternDefinitions=[[<(DOWN.$3, PREV(DOWN.$3, 1)), >(UP.$3, PREV(UP.$3, 1))]], inputFields=[[EMPNO, ENAME, JOB, MGR, HIREDATE, SAL, COMM, DEPTNO, SLACKER]]) +LogicalProject(JOB=[$0], SAL=[$1], MATCH_NUM=[$2], VAR_MATCH=[$3], START_NW=[$4], BOTTOM_NW=[$5], END_NW=[$6]) + LogicalMatch(partition=[[$2, $5]], order=[[2, 5 DESC]], outputFields=[[JOB, SAL, MATCH_NUM, VAR_MATCH, START_NW, BOTTOM_NW, END_NW]], allRows=[false], after=[FLAG(SKIP TO NEXT ROW)], pattern=[(('STRT', PATTERN_QUANTIFIER('DOWN', 1, -1, false)), PATTERN_QUANTIFIER('UP', 1, -1, false))], isStrictStarts=[false], isStrictEnds=[false], subsets=[[]], patternDefinitions=[[<(DOWN.$3, PREV(DOWN.$3, 1)), >(UP.$3, PREV(UP.$3, 1))]], inputFields=[[EMPNO, ENAME, JOB, MGR, HIREDATE, SAL, COMM, DEPTNO, SLACKER]]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + </Resource> + </TestCase> + <TestCase name="testMatchRecognizeMeasures2"> + <Resource name="sql"> + <![CDATA[select * + from emp match_recognize + ( + partition by job + order by sal + measures STRT.mgr as start_nw, LAST(DOWN.mgr) as bottom_nw, LAST(up.mgr) as end_nw pattern (strt down+ up+) + define + down as down.mgr < PREV(down.mgr), + up as up.mgr > prev(up.mgr) + ) mr]]> + </Resource> + <Resource name="plan"> + <![CDATA[ +LogicalProject(JOB=[$0], MATCH_NUM=[$1], VAR_MATCH=[$2], START_NW=[$3], BOTTOM_NW=[$4], END_NW=[$5]) + LogicalMatch(partition=[[$2]], order=[[5]], outputFields=[[JOB, MATCH_NUM, VAR_MATCH, START_NW, BOTTOM_NW, END_NW]], allRows=[false], after=[FLAG(SKIP TO NEXT ROW)], pattern=[(('STRT', PATTERN_QUANTIFIER('DOWN', 1, -1, false)), PATTERN_QUANTIFIER('UP', 1, -1, false))], isStrictStarts=[false], isStrictEnds=[false], subsets=[[]], patternDefinitions=[[<(DOWN.$3, PREV(DOWN.$3, 1)), >(UP.$3, PREV(UP.$3, 1))]], inputFields=[[EMPNO, ENAME, JOB, MGR, HIREDATE, SAL, COMM, DEPTNO, SLACKER]]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + </Resource> + </TestCase> + <TestCase name="testMatchRecognizeMeasures3"> + <Resource name="sql"> + <![CDATA[select * + from emp match_recognize + ( + partition by job + order by sal + measures STRT.mgr as start_nw, LAST(DOWN.mgr) as bottom_nw, LAST(up.mgr) as end_nw + ALL ROWS PER MATCH + pattern (strt down+ up+) + define + down as down.mgr < PREV(down.mgr), + up as up.mgr > prev(up.mgr) + ) mr]]> + </Resource> + <Resource name="plan"> + <![CDATA[ +LogicalProject(JOB=[$0], SAL=[$1], EMPNO=[$2], ENAME=[$3], MGR=[$4], HIREDATE=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], MATCH_NUM=[$9], VAR_MATCH=[$10], START_NW=[$11], BOTTOM_NW=[$12], END_NW=[$13]) + LogicalMatch(partition=[[$2]], order=[[5]], outputFields=[[JOB, SAL, EMPNO, ENAME, MGR, HIREDATE, COMM, DEPTNO, SLACKER, MATCH_NUM, VAR_MATCH, START_NW, BOTTOM_NW, END_NW]], allRows=[true], after=[FLAG(SKIP TO NEXT ROW)], pattern=[(('STRT', PATTERN_QUANTIFIER('DOWN', 1, -1, false)), PATTERN_QUANTIFIER('UP', 1, -1, false))], isStrictStarts=[false], isStrictEnds=[false], subsets=[[]], patternDefinitions=[[<(DOWN.$3, PREV(DOWN.$3, 1)), >(UP.$3, PREV(UP.$3, 1))]], inputFields=[[EMPNO, ENAME, JOB, MGR, HIREDATE, SAL, COMM, DEPTNO, SLACKER]]) LogicalTableScan(table=[[CATALOG, SALES, EMP]]) ]]> </Resource>