This is an automated email from the ASF dual-hosted git repository. rubenql pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/calcite.git
The following commit(s) were added to refs/heads/main by this push: new dc2e1af5c1 [CALCITE-5061] Improve recursive application of the field trimming dc2e1af5c1 is described below commit dc2e1af5c168f01a045edc3fb3f1970418b2a03a Author: rubenada <rube...@gmail.com> AuthorDate: Fri May 13 15:23:34 2022 +0100 [CALCITE-5061] Improve recursive application of the field trimming --- .../apache/calcite/sql2rel/RelFieldTrimmer.java | 24 +++++++-- .../calcite/sql2rel/RelFieldTrimmerTest.java | 29 +++++++++++ .../java/org/apache/calcite/test/JdbcTest.java | 10 ++-- .../java/org/apache/calcite/test/StreamTest.java | 21 ++++---- .../test/enumerable/EnumerableJoinTest.java | 16 +++--- .../test/enumerable/EnumerableRepeatUnionTest.java | 6 +-- .../org/apache/calcite/test/RelOptRulesTest.xml | 58 +++++++++++----------- 7 files changed, 103 insertions(+), 61 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java b/core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java index fff44d9694..2a53e9427c 100644 --- a/core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java +++ b/core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java @@ -352,11 +352,25 @@ public class RelFieldTrimmer implements ReflectiveVisitor { RelNode rel, ImmutableBitSet fieldsUsed, Set<RelDataTypeField> extraFields) { - // We don't know how to trim this kind of relational expression, so give - // it back intact. + // We don't know how to trim this kind of relational expression Util.discard(fieldsUsed); - return result(rel, - Mappings.createIdentity(rel.getRowType().getFieldCount())); + if (rel.getInputs().isEmpty()) { + return result(rel, Mappings.createIdentity(rel.getRowType().getFieldCount())); + } + + // We don't know how to trim this RelNode, but we can try to trim inside its inputs + List<RelNode> newInputs = new ArrayList<>(rel.getInputs().size()); + for (RelNode input : rel.getInputs()) { + ImmutableBitSet inputFieldsUsed = ImmutableBitSet.range(input.getRowType().getFieldCount()); + TrimResult trimResult = dispatchTrimFields(input, inputFieldsUsed, extraFields); + if (!trimResult.right.isIdentity()) { + throw new IllegalArgumentException("Expected identity mapping after processing RelNode " + + input + "; but got " + trimResult.right); + } + newInputs.add(trimResult.left); + } + RelNode newRel = rel.copy(rel.getTraitSet(), newInputs); + return result(newRel, Mappings.createIdentity(newRel.getRowType().getFieldCount())); } /** @@ -914,7 +928,7 @@ public class RelFieldTrimmer implements ReflectiveVisitor { // They can not be trimmed because the comparison needs // complete fields. if (!(setOp.kind == SqlKind.UNION && setOp.all)) { - return result(setOp, Mappings.createIdentity(fieldCount)); + return trimFields((RelNode) setOp, fieldsUsed, extraFields); } int changeCount = 0; diff --git a/core/src/test/java/org/apache/calcite/sql2rel/RelFieldTrimmerTest.java b/core/src/test/java/org/apache/calcite/sql2rel/RelFieldTrimmerTest.java index 81d3620ce6..8f10a2cef7 100644 --- a/core/src/test/java/org/apache/calcite/sql2rel/RelFieldTrimmerTest.java +++ b/core/src/test/java/org/apache/calcite/sql2rel/RelFieldTrimmerTest.java @@ -516,4 +516,33 @@ class RelFieldTrimmerTest { assertThat(trimmed, hasTree(expected)); } } + + @Test void testUnionFieldTrimmer() { + final RelBuilder builder = RelBuilder.create(config().build()); + final RelNode root = + builder.scan("EMP").as("t1") + .project(builder.field("EMPNO")) + .scan("EMP").as("t2") + .scan("EMP").as("t3") + .join(JoinRelType.INNER, + builder.equals( + builder.field(2, "t2", "EMPNO"), + builder.field(2, "t3", "EMPNO"))) + .project(builder.field("t2", "EMPNO")) + .union(false) + .build(); + final RelFieldTrimmer fieldTrimmer = new RelFieldTrimmer(null, builder); + final RelNode trimmed = fieldTrimmer.trim(root); + final String expected = "" + + "LogicalUnion(all=[false])\n" + + " LogicalProject(EMPNO=[$0])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(EMPNO=[$0])\n" + + " LogicalJoin(condition=[=($0, $1)], joinType=[inner])\n" + + " LogicalProject(EMPNO=[$0])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(EMPNO=[$0])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + assertThat(trimmed, hasTree(expected)); + } } diff --git a/core/src/test/java/org/apache/calcite/test/JdbcTest.java b/core/src/test/java/org/apache/calcite/test/JdbcTest.java index ddd35fa1a3..daf7484fed 100644 --- a/core/src/test/java/org/apache/calcite/test/JdbcTest.java +++ b/core/src/test/java/org/apache/calcite/test/JdbcTest.java @@ -7749,14 +7749,13 @@ public class JdbcTest { + " pattern (up s)\n" + " define up as up.\"empid\" = 100)"; final String convert = "" - + "LogicalProject(C=[$0], EMPID=[$1], TWO=[$2])\n" - + " LogicalMatch(partition=[[]], order=[[0 DESC]], " + + "LogicalMatch(partition=[[]], order=[[0 DESC]], " + "outputFields=[[C, EMPID, TWO]], allRows=[false], " + "after=[FLAG(SKIP TO NEXT ROW)], pattern=[('UP', 'S')], " + "isStrictStarts=[false], isStrictEnds=[false], subsets=[[]], " + "patternDefinitions=[[=(CAST(PREV(UP.$0, 0)):INTEGER NOT NULL, 100)]], " + "inputFields=[[empid, deptno, name, salary, commission]])\n" - + " LogicalTableScan(table=[[hr, emps]])\n"; + + " LogicalTableScan(table=[[hr, emps]])\n"; final String plan = "PLAN=" + "EnumerableMatch(partition=[[]], order=[[0 DESC]], " + "outputFields=[[C, EMPID, TWO]], allRows=[false], " @@ -7782,14 +7781,13 @@ public class JdbcTest { + " pattern (s up)\n" + " define up as up.\"commission\" < prev(up.\"commission\"))"; final String convert = "" - + "LogicalProject(C=[$0], EMPID=[$1])\n" - + " LogicalMatch(partition=[[]], order=[[0 DESC]], " + + "LogicalMatch(partition=[[]], order=[[0 DESC]], " + "outputFields=[[C, EMPID]], allRows=[false], " + "after=[FLAG(SKIP TO NEXT ROW)], pattern=[('S', 'UP')], " + "isStrictStarts=[false], isStrictEnds=[false], subsets=[[]], " + "patternDefinitions=[[<(PREV(UP.$4, 0), PREV(UP.$4, 1))]], " + "inputFields=[[empid, deptno, name, salary, commission]])\n" - + " LogicalTableScan(table=[[hr, emps]])\n"; + + " LogicalTableScan(table=[[hr, emps]])\n"; final String plan = "PLAN=" + "EnumerableMatch(partition=[[]], order=[[0 DESC]], " + "outputFields=[[C, EMPID]], allRows=[false], " diff --git a/core/src/test/java/org/apache/calcite/test/StreamTest.java b/core/src/test/java/org/apache/calcite/test/StreamTest.java index 6192657bb4..d2bf27f372 100644 --- a/core/src/test/java/org/apache/calcite/test/StreamTest.java +++ b/core/src/test/java/org/apache/calcite/test/StreamTest.java @@ -114,9 +114,10 @@ public class StreamTest { .query("select stream product from orders where units > 6") .convertContains( "LogicalDelta\n" - + " LogicalProject(PRODUCT=[$2])\n" - + " LogicalFilter(condition=[>($3, 6)])\n" - + " LogicalTableScan(table=[[STREAMS, ORDERS]])\n") + + " LogicalProject(PRODUCT=[$1])\n" + + " LogicalFilter(condition=[>($2, 6)])\n" + + " LogicalProject(ROWTIME=[$0], PRODUCT=[$2], UNITS=[$3])\n" + + " LogicalTableScan(table=[[STREAMS, ORDERS]])\n") .explainContains( "EnumerableCalc(expr#0..3=[{inputs}], expr#4=[6], expr#5=[>($t3, $t4)], PRODUCT=[$t2], $condition=[$t5])\n" + " EnumerableInterpreter\n" @@ -260,16 +261,16 @@ public class StreamTest { + "orders.rowtime as rowtime, orders.id as orderId, products.supplier as supplierId " + "from orders join products on orders.product = products.id") .convertContains("LogicalDelta\n" - + " LogicalProject(ROWTIME=[$0], ORDERID=[$1], SUPPLIERID=[$6])\n" - + " LogicalJoin(condition=[=($4, $5)], joinType=[inner])\n" - + " LogicalProject(ROWTIME=[$0], ID=[$1], PRODUCT=[$2], UNITS=[$3], PRODUCT0=[CAST($2):VARCHAR(32) NOT NULL])\n" + + " LogicalProject(ROWTIME=[$0], ORDERID=[$1], SUPPLIERID=[$4])\n" + + " LogicalJoin(condition=[=($2, $3)], joinType=[inner])\n" + + " LogicalProject(ROWTIME=[$0], ID=[$1], PRODUCT0=[CAST($2):VARCHAR(32) NOT NULL])\n" + " LogicalTableScan(table=[[STREAM_JOINS, ORDERS]])\n" + " LogicalTableScan(table=[[STREAM_JOINS, PRODUCTS]])\n") .explainContains("" - + "EnumerableCalc(expr#0..6=[{inputs}], proj#0..1=[{exprs}], SUPPLIERID=[$t6])\n" - + " EnumerableMergeJoin(condition=[=($4, $5)], joinType=[inner])\n" - + " EnumerableSort(sort0=[$4], dir0=[ASC])\n" - + " EnumerableCalc(expr#0..3=[{inputs}], expr#4=[CAST($t2):VARCHAR(32) NOT NULL], proj#0..4=[{exprs}])\n" + + "EnumerableCalc(expr#0..4=[{inputs}], proj#0..1=[{exprs}], SUPPLIERID=[$t4])\n" + + " EnumerableMergeJoin(condition=[=($2, $3)], joinType=[inner])\n" + + " EnumerableSort(sort0=[$2], dir0=[ASC])\n" + + " EnumerableCalc(expr#0..3=[{inputs}], expr#4=[CAST($t2):VARCHAR(32) NOT NULL], proj#0..1=[{exprs}], PRODUCT0=[$t4])\n" + " EnumerableInterpreter\n" + " BindableTableScan(table=[[STREAM_JOINS, ORDERS, (STREAM)]])\n" + " EnumerableSort(sort0=[$0], dir0=[ASC])\n" diff --git a/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableJoinTest.java b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableJoinTest.java index e3a2c6d388..ccd311c831 100644 --- a/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableJoinTest.java +++ b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableJoinTest.java @@ -313,17 +313,19 @@ class EnumerableJoinTest { + " EnumerableCalc(expr#0..4=[{inputs}], expr#5=[2], expr#6=[=($t0, $t5)], empid=[$t0], name=[$t2], $condition=[$t6])\n" + " EnumerableTableScan(table=[[s, emps]])\n" + " EnumerableTableSpool(readType=[LAZY], writeType=[LAZY], table=[[#DELTA#]])\n" - + " EnumerableCalc(expr#0..8=[{inputs}], empid=[$t4], name=[$t6])\n" - + " EnumerableMergeJoin(condition=[=($3, $4)], joinType=[inner])\n" - + " EnumerableSort(sort0=[$3], dir0=[ASC])\n" - + " EnumerableMergeJoin(condition=[=($0, $2)], joinType=[inner])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], empid=[$t3], name=[$t4])\n" + + " EnumerableMergeJoin(condition=[=($2, $3)], joinType=[inner])\n" + + " EnumerableSort(sort0=[$2], dir0=[ASC])\n" + + " EnumerableMergeJoin(condition=[=($0, $1)], joinType=[inner])\n" + " EnumerableSort(sort0=[$0], dir0=[ASC])\n" - + " EnumerableInterpreter\n" - + " BindableTableScan(table=[[#DELTA#]])\n" + + " EnumerableCalc(expr#0..1=[{inputs}], empid=[$t0])\n" + + " EnumerableInterpreter\n" + + " BindableTableScan(table=[[#DELTA#]])\n" + " EnumerableSort(sort0=[$0], dir0=[ASC])\n" + " EnumerableTableScan(table=[[s, hierarchies]])\n" + " EnumerableSort(sort0=[$0], dir0=[ASC])\n" - + " EnumerableTableScan(table=[[s, emps]])\n") + + " EnumerableCalc(expr#0..4=[{inputs}], empid=[$t0], name=[$t2])\n" + + " EnumerableTableScan(table=[[s, emps]])\n") .returnsUnordered("" + "empid=2; name=Emp2\n" + "empid=3; name=Emp3\n" diff --git a/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableRepeatUnionTest.java b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableRepeatUnionTest.java index f4a38781aa..b7c18f4a33 100644 --- a/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableRepeatUnionTest.java +++ b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableRepeatUnionTest.java @@ -302,16 +302,16 @@ class EnumerableRepeatUnionTest { + " EnumerableCalc(expr#0..4=[{inputs}], expr#5=[2], expr#6=[=($t0, $t5)], empid=[$t0], name=[$t2], $condition=[$t6])\n" + " EnumerableTableScan(table=[[s, emps]])\n" + " EnumerableTableSpool(readType=[LAZY], writeType=[LAZY], table=[[#DELTA#]])\n" - + " EnumerableCalc(expr#0..8=[{inputs}], empid=[$t4], name=[$t6])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], empid=[$t3], name=[$t4])\n" + " EnumerableCorrelate(correlation=[$cor1], joinType=[inner], requiredColumns=[{1}])\n" // It is important to have EnumerableCorrelate + #DELTA# table scan on its right // to reproduce the issue CALCITE-4054 + " EnumerableCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}])\n" + " EnumerableTableScan(table=[[s, hierarchies]])\n" - + " EnumerableCalc(expr#0..1=[{inputs}], expr#2=[$cor0], expr#3=[$t2.managerid], expr#4=[=($t0, $t3)], proj#0..1=[{exprs}], $condition=[$t4])\n" + + " EnumerableCalc(expr#0..1=[{inputs}], expr#2=[$cor0], expr#3=[$t2.managerid], expr#4=[=($t0, $t3)], empid=[$t0], $condition=[$t4])\n" + " EnumerableInterpreter\n" + " BindableTableScan(table=[[#DELTA#]])\n" - + " EnumerableCalc(expr#0..4=[{inputs}], expr#5=[$cor1], expr#6=[$t5.subordinateid], expr#7=[=($t6, $t0)], proj#0..4=[{exprs}], $condition=[$t7])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], expr#5=[$cor1], expr#6=[$t5.subordinateid], expr#7=[=($t6, $t0)], empid=[$t0], name=[$t2], $condition=[$t7])\n" + " EnumerableTableScan(table=[[s, emps]])\n") .returnsUnordered("" + "empid=2; name=Emp2\n" diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml index 908829af8c..d37e1c019e 100644 --- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml +++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml @@ -1811,9 +1811,10 @@ LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$ LogicalAggregate(group=[{0}]) LogicalProject(EXPR$0=[$2]) LogicalAggregate(group=[{0, 1}], EXPR$0=[MAX($2)]) - LogicalProject(DEPTNO=[$7], $f1=['abc'], SAL=[$5]) + LogicalProject(DEPTNO=[$2], $f1=['abc'], SAL=[$1]) LogicalFilter(condition=[=($cor0.MGR, $0)]) - LogicalTableScan(table=[[CATALOG, SALES, EMP]]) + LogicalProject(EMPNO=[$0], SAL=[$5], DEPTNO=[$7]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) ]]> </Resource> <Resource name="planMid"> @@ -1825,9 +1826,10 @@ LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$ LogicalAggregate(group=[{0}]) LogicalProject(EXPR$0=[$2]) LogicalAggregate(group=[{0, 1}], EXPR$0=[MAX($2)]) - LogicalProject(DEPTNO=[$7], $f1=['abc'], SAL=[$5]) + LogicalProject(DEPTNO=[$2], $f1=['abc'], SAL=[$1]) LogicalFilter(condition=[=($cor0.MGR, $0)]) - LogicalTableScan(table=[[CATALOG, SALES, EMP]]) + LogicalProject(EMPNO=[$0], SAL=[$5], DEPTNO=[$7]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) ]]> </Resource> <Resource name="planAfter"> @@ -1849,22 +1851,20 @@ LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$ </Resource> <Resource name="planBefore"> <![CDATA[ -LogicalProject(C=[$0], S=[$1]) - LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}]) - LogicalAggregate(group=[{}], C=[MYAGG($0, $1)]) - LogicalProject(SAL=[$5], $f1=[1]) - LogicalTableScan(table=[[CATALOG, SALES, EMP]]) - LogicalTableFunctionScan(invocation=[RAMP($cor0.C)], rowType=[RecordType(INTEGER I)]) +LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}]) + LogicalAggregate(group=[{}], C=[MYAGG($0, $1)]) + LogicalProject(SAL=[$5], $f1=[1]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) + LogicalTableFunctionScan(invocation=[RAMP($cor0.C)], rowType=[RecordType(INTEGER I)]) ]]> </Resource> <Resource name="planMid"> <![CDATA[ -LogicalProject(C=[$0], S=[$1]) - LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}]) - LogicalAggregate(group=[{}], C=[MYAGG($0, $1)]) - LogicalProject(SAL=[$5], $f1=[1]) - LogicalTableScan(table=[[CATALOG, SALES, EMP]]) - LogicalTableFunctionScan(invocation=[RAMP($cor0.C)], rowType=[RecordType(INTEGER I)]) +LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}]) + LogicalAggregate(group=[{}], C=[MYAGG($0, $1)]) + LogicalProject(SAL=[$5], $f1=[1]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) + LogicalTableFunctionScan(invocation=[RAMP($cor0.C)], rowType=[RecordType(INTEGER I)]) ]]> </Resource> </TestCase> @@ -1875,24 +1875,22 @@ LogicalProject(C=[$0], S=[$1]) </Resource> <Resource name="planBefore"> <![CDATA[ -LogicalProject(C=[$0], S=[$1]) - LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}]) - LogicalProject(C=[$2]) - LogicalAggregate(group=[{0, 1}], C=[MYAGG($2, $3)]) - LogicalProject(EMPNO=[$0], $f1=['abc'], SAL=[$5], $f3=[1]) - LogicalTableScan(table=[[CATALOG, SALES, EMP]]) - LogicalTableFunctionScan(invocation=[RAMP($cor0.C)], rowType=[RecordType(INTEGER I)]) +LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}]) + LogicalProject(C=[$2]) + LogicalAggregate(group=[{0, 1}], C=[MYAGG($2, $3)]) + LogicalProject(EMPNO=[$0], $f1=['abc'], SAL=[$5], $f3=[1]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) + LogicalTableFunctionScan(invocation=[RAMP($cor0.C)], rowType=[RecordType(INTEGER I)]) ]]> </Resource> <Resource name="planMid"> <![CDATA[ -LogicalProject(C=[$0], S=[$1]) - LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}]) - LogicalProject(C=[$2]) - LogicalAggregate(group=[{0, 1}], C=[MYAGG($2, $3)]) - LogicalProject(EMPNO=[$0], $f1=['abc'], SAL=[$5], $f3=[1]) - LogicalTableScan(table=[[CATALOG, SALES, EMP]]) - LogicalTableFunctionScan(invocation=[RAMP($cor0.C)], rowType=[RecordType(INTEGER I)]) +LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}]) + LogicalProject(C=[$2]) + LogicalAggregate(group=[{0, 1}], C=[MYAGG($2, $3)]) + LogicalProject(EMPNO=[$0], $f1=['abc'], SAL=[$5], $f3=[1]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) + LogicalTableFunctionScan(invocation=[RAMP($cor0.C)], rowType=[RecordType(INTEGER I)]) ]]> </Resource> </TestCase>