This is an automated email from the ASF dual-hosted git repository.
morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new d23646793c [fix](nereids) binding group by key on agg.output if output
is slot (#15623)
d23646793c is described below
commit d23646793c71ea44e8e9f6b5f28acd9b5b51e237
Author: minghong <[email protected]>
AuthorDate: Thu Jan 12 16:34:56 2023 +0800
[fix](nereids) binding group by key on agg.output if output is slot (#15623)
case 1
`select count(1) from t1 join t2 on t1.a = t2.a group by a`
`group by a` is ambiguous
case 2
`select t1.a from t1 join t2 on t1.a = t2.a group by a`
`group by a` is bound on t1.a
---
.../nereids/rules/analysis/BindSlotReference.java | 40 +++++++++++++++-
.../rules/analysis/BindSlotReferenceTest.java | 55 ++++++++++++++++++++++
2 files changed, 93 insertions(+), 2 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java
index bc0c9325ae..7fbcb9fde0 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java
@@ -74,6 +74,7 @@ import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
import org.apache.commons.lang.StringUtils;
import java.util.ArrayList;
@@ -227,8 +228,20 @@ public class BindSlotReference implements
AnalysisRuleFactory {
group by key cannot bind with agg func
plan:
agg(group_by v, output sum(k) as v)
-
throw AnalysisException
+
+ CASE 4
+ sql:
+ `select count(1) from t1 join t2 group by a`
+ we cannot bind `group by a`, because it is ambiguous
(t1.a and t2.a)
+
+ CASE 5
+ following case 4, if t1.a is in agg.output, we can bind
`group by a` to t1.a
+ sql:
+ select t1.a
+ from t1 join t2 on t1.a = t2.a
+ group by a
+ group_by_key is bound on t1.a
*/
duplicatedSlotNames.stream().forEach(dup ->
childOutputsToExpr.remove(dup));
Map<String, Expression> aliasNameToExpr = output.stream()
@@ -261,8 +274,31 @@ public class BindSlotReference implements
AnalysisRuleFactory {
}
return groupBy;
}).collect(Collectors.toList());
+ /*
+ according to case 4 and case 5, we construct boundSlots
+ */
+ Set<String> outputSlotNames = Sets.newHashSet();
+ Set<Slot> outputSlots = output.stream()
+ .filter(SlotReference.class::isInstance)
+ .peek(slot -> outputSlotNames.add(slot.getName()))
+ .map(NamedExpression::toSlot).collect(
+ Collectors.toSet());
+ //suppose group by key is a.
+ // if both t1.a and t2.a are in agg.child.output, and t1.a
in agg.output,
+ // bind group_by_key a with t1.a
+ // ` .filter(slot ->
!outputSlotNames.contains(slot.getName()))`
+ // is used to avoid adding t2.a into boundSlots
+ Set<Slot> boundSlots = agg.child().getOutputSet().stream()
+ .filter(slot ->
!outputSlotNames.contains(slot.getName()))
+ .collect(Collectors.toSet());
+
+ boundSlots.addAll(outputSlots);
+ SlotBinder binder = new
SlotBinder(toScope(Lists.newArrayList(boundSlots)), ctx.cascadesContext);
+
+ List<Expression> groupBy = replacedGroupBy.stream()
+ .map(expression -> binder.bind(expression))
+ .collect(Collectors.toList());
- List<Expression> groupBy = bind(replacedGroupBy,
agg.children(), agg, ctx.cascadesContext);
List<Expression> unboundGroupBys = Lists.newArrayList();
boolean hasUnbound = groupBy.stream().anyMatch(
expression -> {
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindSlotReferenceTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindSlotReferenceTest.java
index e63618cdd7..abc18a70c7 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindSlotReferenceTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindSlotReferenceTest.java
@@ -19,17 +19,24 @@ package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.nereids.analyzer.UnboundSlot;
import org.apache.doris.nereids.exceptions.AnalysisException;
+import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.NamedExpressionUtil;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalSubQueryAlias;
import org.apache.doris.nereids.trees.plans.logical.RelationUtil;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -65,4 +72,52 @@ class BindSlotReferenceTest {
Assertions.assertTrue(exception.getMessage().contains("id#4"));
Assertions.assertTrue(exception.getMessage().contains("id#0"));
}
+
+ /*
+ select t1.id from student t1 join student t2 on t1.id=t2.id group by id;
+ group_by_key is bound to t1.id, not t2.id
+ */
+ @Test
+ public void testGroupByOnJoin() {
+ LogicalOlapScan scan1 = new
LogicalOlapScan(RelationUtil.newRelationId(), PlanConstructor.student);
+ LogicalSubQueryAlias sub1 = new LogicalSubQueryAlias("t1", scan1);
+ LogicalOlapScan scan2 = new
LogicalOlapScan(RelationUtil.newRelationId(), PlanConstructor.student);
+ LogicalSubQueryAlias sub2 = new LogicalSubQueryAlias("t2", scan2);
+ LogicalJoin<LogicalSubQueryAlias<LogicalOlapScan>,
LogicalSubQueryAlias<LogicalOlapScan>> join =
+ new LogicalJoin<>(JoinType.CROSS_JOIN, sub1, sub2);
+ LogicalAggregate<LogicalJoin> aggregate = new LogicalAggregate<>(
+ Lists.newArrayList(new UnboundSlot("id")), //group by
+ Lists.newArrayList(new UnboundSlot("t1", "id")), //output
+ join
+ );
+ PlanChecker checker =
PlanChecker.from(MemoTestUtils.createConnectContext()).analyze(aggregate);
+ LogicalAggregate plan = (LogicalAggregate)
checker.getCascadesContext().getMemo().copyOut();
+ SlotReference groupByKey = (SlotReference)
plan.getGroupByExpressions().get(0);
+ SlotReference t1id = (SlotReference) ((LogicalJoin)
plan.child()).left().getOutput().get(0);
+ SlotReference t2id = (SlotReference) ((LogicalJoin)
plan.child()).right().getOutput().get(0);
+ Assertions.assertEquals(groupByKey.getExprId(), t1id.getExprId());
+ Assertions.assertNotEquals(t1id.getExprId(), t2id.getExprId());
+ }
+
+ /*
+ select count(1) from student t1 join student t2 on t1.id=t2.id group by
id;
+ group by key is ambiguous
+ */
+ @Test
+ public void testGroupByOnJoinAmbiguous() {
+ LogicalOlapScan scan1 = new
LogicalOlapScan(RelationUtil.newRelationId(), PlanConstructor.student);
+ LogicalSubQueryAlias sub1 = new LogicalSubQueryAlias("t1", scan1);
+ LogicalOlapScan scan2 = new
LogicalOlapScan(RelationUtil.newRelationId(), PlanConstructor.student);
+ LogicalSubQueryAlias sub2 = new LogicalSubQueryAlias("t2", scan2);
+ LogicalJoin<LogicalSubQueryAlias<LogicalOlapScan>,
LogicalSubQueryAlias<LogicalOlapScan>> join =
+ new LogicalJoin<>(JoinType.CROSS_JOIN, sub1, sub2);
+ LogicalAggregate<LogicalJoin> aggregate = new LogicalAggregate<>(
+ Lists.newArrayList(new UnboundSlot("id")), //group by
+ Lists.newArrayList(new Alias(new Count(new IntegerLiteral(1)),
"count(1)")), //output
+ join
+ );
+ AnalysisException exception =
Assertions.assertThrows(AnalysisException.class,
+ () ->
PlanChecker.from(MemoTestUtils.createConnectContext()).analyze(aggregate));
+ Assertions.assertTrue(exception.getMessage().contains("id is
ambiguous: "));
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]