iwanttobepowerful commented on code in PR #4637:
URL: https://github.com/apache/calcite/pull/4637#discussion_r2575428901
##########
core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java:
##########
@@ -1184,6 +1190,147 @@ private static void shiftMapping(Map<Integer, Integer>
mapping, int startIndex,
return null;
}
+ /**
+ * Given the SQL:
+ * SELECT ename,
+ * (SELECT sum(c)
+ * FROM
+ * (SELECT deptno AS c
+ * FROM dept
+ * WHERE dept.deptno = emp.deptno
+ * UNION ALL
+ * SELECT 2 AS c
+ * FROM bonus
+ * WHERE bonus.job = emp.job) AS union_subquery
+ * ) AS correlated_sum
+ * FROM emp;
+ *
+ * <p>from:
+ * LogicalUnion(all=[true])
+ * LogicalProject(C=[CAST($0):INTEGER NOT NULL])
+ * LogicalFilter(condition=[=($0, $cor0.DEPTNO)])
+ * LogicalTableScan(table=[[scott, DEPT]])
+ * LogicalProject(C=[2])
+ * LogicalFilter(condition=[=($1, $cor0.JOB)])
+ * LogicalTableScan(table=[[scott, BONUS]])
+ *
+ * <p>to:
+ * LogicalUnion(all=[true])
+ * LogicalProject(JOB=[$0], DEPTNO=[$1], C=[$2])
+ * LogicalJoin(condition=[IS NOT DISTINCT FROM($1, $3)],
joinType=[inner])
+ * LogicalAggregate(group=[{2, 7}])
+ * LogicalTableScan(table=[[scott, EMP]])
+ * LogicalProject(C=[CAST($0):INTEGER NOT NULL], DEPTNO=[$0])
+ * LogicalTableScan(table=[[scott, DEPT]])
+ * LogicalProject(JOB=[$0], DEPTNO=[$1], C=[$2])
+ * LogicalJoin(condition=[IS NOT DISTINCT FROM($0, $3)],
joinType=[inner])
+ * LogicalAggregate(group=[{2, 7}])
+ * LogicalTableScan(table=[[scott, EMP]])
+ * LogicalProject(C=[2], JOB=[$1])
+ * LogicalFilter(condition=[IS NOT NULL($1)])
+ * LogicalTableScan(table=[[scott, BONUS]])
+ */
+ public @Nullable Frame decorrelateRel(SetOp rel, boolean isCorVarDefined,
+ boolean parentPropagatesNullValues) {
+ if (!isCorVarDefined) {
+ return decorrelateRelHelper(rel, false, parentPropagatesNullValues);
+ }
+
+ final Pair<CorrelationId, Frame> outerFramePair =
requireNonNull(this.frameStack.peek());
+ final CorrelationId outFrameCorrId = outerFramePair.left;
+ final Frame outFrame = outerFramePair.right;
+
+ // Collect CorDef from all inputs
+ ImmutableBitSet.Builder corFieldBuilder = ImmutableBitSet.builder();
+ List<Frame> frames = new ArrayList<>();
+ for (RelNode oldInput : rel.getInputs()) {
+ Frame frame = getInvoke(oldInput, true, rel, parentPropagatesNullValues);
+ if (frame == null) {
+ // If input has not been rewritten, do not rewrite this rel.
+ return null;
+ }
+ frames.add(frame);
+ for (Map.Entry<CorDef, Integer> corDefOutput :
frame.corDefOutputs.entrySet()) {
+ CorDef corDef = corDefOutput.getKey();
+ if (corDef.corr.equals(outFrameCorrId)) {
+ int newIdx =
requireNonNull(outFrame.oldToNewOutputs.get(corDef.field));
+ corFieldBuilder.set(newIdx);
+ }
+ }
+ }
+
+ ImmutableBitSet groupSet = corFieldBuilder.build();
+ List<RelNode> newInputs = new ArrayList<>();
+ final Map<Integer, Integer> mapOldToNewOutputs = new HashMap<>();
+ final NavigableMap<CorDef, Integer> corDefOutputs = new TreeMap<>();
+ for (int i = 0; i < rel.getInputs().size(); i++) {
+ Frame frame = frames.get(i);
+ final List<RexNode> conditions = new ArrayList<>();
+ int groupKeySize = groupSet.cardinality();
+ for (Map.Entry<CorDef, Integer> corDefOutput :
frame.corDefOutputs.entrySet()) {
+ CorDef corDef = corDefOutput.getKey();
+ Integer corIndex = corDefOutput.getValue();
+ if (corDef.corr.equals(outFrameCorrId)) {
+ int newIdx =
requireNonNull(outFrame.oldToNewOutputs.get(corDef.field));
+ int pos = groupSet.indexOf(newIdx);
+ RelDataType leftType =
outFrame.r.getRowType().getFieldList().get(newIdx).getType();
+ RexNode left = new RexInputRef(pos, leftType);
+ RelDataType rightType =
frame.r.getRowType().getFieldList().get(corIndex).getType();
+ RexNode right = new RexInputRef(groupKeySize + corIndex, rightType);
+ conditions.add(relBuilder.isNotDistinctFrom(left, right));
+ corDefOutputs.put(corDef, pos);
+ }
+ }
+
+ // Build LogicalAggregate to obtain the distinct set of corVar from
outFrame.
+ relBuilder.push(outFrame.r).aggregate(relBuilder.groupKey(groupSet));
+
+ // Build LogicalJoin(condition=[IS NOT DISTINCT FROM($0, $1)],
joinType=[inner])
+ // to ensure each corVar's aggregate result is output.
+ final RelNode join = relBuilder.push(frame.r)
+ .join(JoinRelType.INNER, conditions).build();
+ final List<RelDataTypeField> joinOutput =
join.getRowType().getFieldList();
+
+ final PairList<RexNode, String> projects = PairList.of();
+ Project oldProj = (Project) rel.getInputs().get(0);
Review Comment:
@mihaibudiu @suibianwanwank
Changed this to use the SetOp's first input.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]